From c6b47d18f6d0b00bd061893b180e675d7b0f7824 Mon Sep 17 00:00:00 2001 From: ScuroNeko Date: Mon, 29 Sep 2025 09:18:48 +0300 Subject: [PATCH] middlewares --- bot.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++------ plugins.go | 59 ++++++++++++++++++++++++++++++++++++++++++++----- queue.go | 22 +++++++++++++++--- version.go | 12 ++++++---- 4 files changed, 139 insertions(+), 19 deletions(-) diff --git a/bot.go b/bot.go index 3b9c700..79087da 100644 --- a/bot.go +++ b/bot.go @@ -7,7 +7,12 @@ import ( "io" "net/http" "os" + "sort" "strings" + + "github.com/redis/go-redis/v9" + "go.mongodb.org/mongo-driver/v2/mongo" + "gorm.io/gorm" ) type ParseMode string @@ -26,8 +31,11 @@ type Bot struct { logger *Logger requestLogger *Logger - plugins []*Plugin - prefixes []string + plugins []*Plugin + middlewares []*Middleware + prefixes []string + + dbContext *DatabaseContext updateOffset int updateTypes []string @@ -45,7 +53,14 @@ type BotSettings struct { } func LoadSettingsFromEnv() *BotSettings { - + return &BotSettings{ + Token: os.Getenv("TG_TOKEN"), + Debug: os.Getenv("DEBUG") == "true", + ErrorTemplate: os.Getenv("ERROR_TEMPLATE"), + Prefixes: LoadPrefixesFromEnv(), + UpdateTypes: strings.Split(os.Getenv("UPDATE_TYPES"), ";"), + UseRequestLogger: os.Getenv("USE_REQ_LOG") == "true", + } } type MsgContext struct { @@ -58,6 +73,12 @@ type MsgContext struct { Args []string } +type DatabaseContext struct { + PostgresSQL *gorm.DB + MongoDB *mongo.Client + Redis *redis.Client +} + func NewBot(settings *BotSettings) *Bot { updateQueue := CreateQueue[*Update](256) bot := &Bot{ @@ -95,6 +116,11 @@ func (b *Bot) Close() { } } +func (b *Bot) InitDatabaseContext(ctx *DatabaseContext) *Bot { + b.dbContext = ctx + return b +} + func (b *Bot) UpdateTypes(t ...string) *Bot { b.updateTypes = make([]string, 0) b.updateTypes = append(b.updateTypes, t...) @@ -136,6 +162,23 @@ func (b *Bot) AddPlugins(plugin ...*Plugin) *Bot { return b } +func (b *Bot) AddMiddleware(middleware ...*Middleware) *Bot { + sort.Slice(middleware, func(a, b int) bool { + first := middleware[a] + second := middleware[b] + if first.Order == second.Order { + return first.Name < second.Name + } + return middleware[a].Order < middleware[b].Order + }) + + b.middlewares = append(b.middlewares, middleware...) + for _, m := range middleware { + b.logger.Debug(fmt.Sprintf("middleware with name \"%s\" was registered", m.Name)) + } + return b +} + func (b *Bot) Run() { if len(b.prefixes) == 0 { b.logger.Fatal("no prefixes defined") @@ -181,9 +224,13 @@ func (b *Bot) handleMessage(update *Update) { Update: update, } + for _, middleware := range b.middlewares { + middleware.Execute(ctx, b.dbContext) + } + for _, plugin := range b.plugins { if plugin.UpdateListener != nil { - (*plugin.UpdateListener)(ctx) + (*plugin.UpdateListener)(ctx, b.dbContext) } } @@ -219,7 +266,7 @@ func (b *Bot) handleMessage(update *Update) { ctx.Text = strings.TrimSpace(text[len(cmd):]) ctx.Args = strings.Split(ctx.Text, " ") - go plugin.Execute(cmd, ctx) + go plugin.Execute(cmd, ctx, b.dbContext) } } } @@ -230,9 +277,13 @@ func (b *Bot) handleCallback(update *Update) { Update: update, } + for _, m := range b.middlewares { + m.Execute(ctx, b.dbContext) + } + for _, plugin := range b.plugins { if plugin.UpdateListener != nil { - (*plugin.UpdateListener)(ctx) + (*plugin.UpdateListener)(ctx, b.dbContext) } } @@ -241,7 +292,7 @@ func (b *Bot) handleCallback(update *Update) { if !strings.HasPrefix(update.CallbackQuery.Data, payload) { continue } - go plugin.ExecutePayload(payload, ctx) + go plugin.ExecutePayload(payload, ctx, b.dbContext) } } } diff --git a/plugins.go b/plugins.go index 0d3dc00..29f5c27 100644 --- a/plugins.go +++ b/plugins.go @@ -1,6 +1,6 @@ package laniakea -type CommandExecutor func(ctx *MsgContext) +type CommandExecutor func(ctx *MsgContext, dbContext *DatabaseContext) type PluginBuilder struct { name string @@ -56,10 +56,59 @@ func (p *PluginBuilder) Build() *Plugin { return plugin } -func (p *Plugin) Execute(cmd string, ctx *MsgContext) { - (*p.Commands[cmd])(ctx) +func (p *Plugin) Execute(cmd string, ctx *MsgContext, dbContext *DatabaseContext) { + (*p.Commands[cmd])(ctx, dbContext) } -func (p *Plugin) ExecutePayload(payload string, ctx *MsgContext) { - (*p.Payloads[payload])(ctx) +func (p *Plugin) ExecutePayload(payload string, ctx *MsgContext, dbContext *DatabaseContext) { + (*p.Payloads[payload])(ctx, dbContext) +} + +type Middleware struct { + Name string + Executor *CommandExecutor + Order int + Async bool +} +type MiddlewareBuilder struct { + name string + executor *CommandExecutor + order int + async bool +} + +func NewMiddleware(name string) *MiddlewareBuilder { + return &MiddlewareBuilder{name: name, async: false} +} +func (m *MiddlewareBuilder) SetName(name string) *MiddlewareBuilder { + m.name = name + return m +} +func (m *MiddlewareBuilder) SetExecutor(executor CommandExecutor) *MiddlewareBuilder { + m.executor = &executor + return m +} +func (m *MiddlewareBuilder) SetOrder(order int) *MiddlewareBuilder { + m.order = order + return m +} +func (m *MiddlewareBuilder) SetAsync(async bool) *MiddlewareBuilder { + m.async = async + return m +} +func (m *MiddlewareBuilder) Build() *Middleware { + return &Middleware{ + Name: m.name, + Executor: m.executor, + Order: m.order, + Async: m.async, + } +} +func (m *Middleware) Execute(ctx *MsgContext, db *DatabaseContext) { + exec := *m.Executor + if m.Async { + go exec(ctx, db) + } else { + exec(ctx, db) + } } diff --git a/queue.go b/queue.go index 79ebeca..e35028d 100644 --- a/queue.go +++ b/queue.go @@ -1,10 +1,14 @@ package laniakea -import "fmt" +import ( + "fmt" + "sync" +) type Queue[T any] struct { - queue []T size uint64 + mu sync.RWMutex + queue []T } func CreateQueue[T any](size uint64) *Queue[T] { @@ -23,11 +27,13 @@ func (q *Queue[T]) Enqueue(el T) error { } func (q *Queue[T]) Peak() T { + q.mu.RLock() + defer q.mu.RUnlock() return q.queue[0] } func (q *Queue[T]) IsEmpty() bool { - return len(q.queue) == 0 + return q.Length() == 0 } func (q *Queue[T]) IsFull() bool { @@ -35,16 +41,26 @@ func (q *Queue[T]) IsFull() bool { } func (q *Queue[T]) Length() uint64 { + q.mu.RLock() + defer q.mu.RUnlock() return uint64(len(q.queue)) } func (q *Queue[T]) Dequeue() T { + q.mu.RLock() el := q.queue[0] + q.mu.RUnlock() + if q.Length() == 1 { + q.mu.Lock() q.queue = make([]T, 0) + q.mu.Unlock() return el } + + q.mu.Lock() q.queue = q.queue[1:] + q.mu.Unlock() return el } diff --git a/version.go b/version.go index fe6e598..d47af22 100644 --- a/version.go +++ b/version.go @@ -1,8 +1,12 @@ package laniakea +import "os" + const ( - VERSION_STRING = "0.1.4" - VERSION_MAJOR = 0 - VERSION_MINOR = 1 - VERSION_PATCH = 4 + VersionString = "0.1.4" + VersionMajor = 0 + VersionMinor = 1 + VersionPatch = 4 ) + +var GoVersion = os.Getenv("GoV")