diff --git a/bot.go b/bot.go index 8a1a8e6..8864452 100644 --- a/bot.go +++ b/bot.go @@ -1,16 +1,19 @@ package laniakea import ( + "context" "fmt" + "log" "os" "sort" "strconv" "strings" - "time" + "sync" "git.nix13.pw/scuroneko/extypes" "git.nix13.pw/scuroneko/laniakea/tgapi" "git.nix13.pw/scuroneko/slog" + "github.com/alitto/pond/v2" "golang.org/x/time/rate" ) @@ -87,13 +90,14 @@ type Bot[T DbContext] struct { dbContext *T l10n *L10n - updateOffset int - updateTypes []tgapi.UpdateType - updateQueue *extypes.Queue[*tgapi.Update] + updateOffsetMu sync.Mutex + updateOffset int + updateTypes []tgapi.UpdateType + updateQueue chan *tgapi.Update } func NewBot[T any](opts *BotOpts) *Bot[T] { - updateQueue := extypes.CreateQueue[*tgapi.Update](512) + updateQueue := make(chan *tgapi.Update, 512) var limiter *rate.Limiter if opts.RateLimit > 0 { @@ -185,13 +189,20 @@ func (bot *Bot[T]) initLoggers(opts *BotOpts) { } } -func (bot *Bot[T]) GetUpdateOffset() int { return bot.updateOffset } -func (bot *Bot[T]) SetUpdateOffset(offset int) { bot.updateOffset = offset } -func (bot *Bot[T]) GetUpdateTypes() []tgapi.UpdateType { return bot.updateTypes } -func (bot *Bot[T]) GetQueue() *extypes.Queue[*tgapi.Update] { return bot.updateQueue } -func (bot *Bot[T]) GetLogger() *slog.Logger { return bot.logger } -func (bot *Bot[T]) GetDBContext() *T { return bot.dbContext } -func (bot *Bot[T]) L10n(lang, key string) string { return bot.l10n.Translate(lang, key) } +func (bot *Bot[T]) GetUpdateOffset() int { + bot.updateOffsetMu.Lock() + defer bot.updateOffsetMu.Unlock() + return bot.updateOffset +} +func (bot *Bot[T]) SetUpdateOffset(offset int) { + bot.updateOffsetMu.Lock() + defer bot.updateOffsetMu.Unlock() + bot.updateOffset = offset +} +func (bot *Bot[T]) GetUpdateTypes() []tgapi.UpdateType { return bot.updateTypes } +func (bot *Bot[T]) GetLogger() *slog.Logger { return bot.logger } +func (bot *Bot[T]) GetDBContext() *T { return bot.dbContext } +func (bot *Bot[T]) L10n(lang, key string) string { return bot.l10n.Translate(lang, key) } type DbLogger[T DbContext] func(db *T) slog.LoggerWriter @@ -235,7 +246,7 @@ func (bot *Bot[T]) Debug(debug bool) *Bot[T] { func (bot *Bot[T]) AddPlugins(plugin ...*Plugin[T]) *Bot[T] { for _, p := range plugin { bot.plugins = append(bot.plugins, *p) - bot.logger.Debugln(fmt.Sprintf("plugins with name \"%s\" registered", p.Name)) + bot.logger.Debugln(fmt.Sprintf("plugins with name \"%s\" registered", p.name)) } return bot } @@ -266,7 +277,15 @@ func (bot *Bot[T]) AddL10n(l *L10n) *Bot[T] { return bot } -func (bot *Bot[T]) Run() { +func (bot *Bot[T]) enqueueUpdate(u *tgapi.Update) error { + select { + case bot.updateQueue <- u: + return nil + default: + return extypes.QueueFullErr + } +} +func (bot *Bot[T]) RunWithContext(ctx context.Context) { if len(bot.prefixes) == 0 { bot.logger.Fatalln("no prefixes defined") return @@ -282,26 +301,36 @@ func (bot *Bot[T]) Run() { bot.logger.Infoln("Bot running. Press CTRL+C to exit.") go func() { for { - _, err := bot.Updates() - if err != nil { - bot.logger.Errorln(err) + select { + case <-ctx.Done(): + return + default: + updates, err := bot.Updates() + if err != nil { + bot.logger.Errorln(err) + continue + } + + for _, u := range updates { + select { + case bot.updateQueue <- new(u): + case <-ctx.Done(): + return + } + } } } }() - for { - queue := bot.updateQueue - if queue.IsEmpty() { - time.Sleep(time.Millisecond * 25) - continue - } - - u := queue.Dequeue() - if u == nil { - bot.logger.Errorln("update is nil") - continue - } - - bot.handle(u) + pool := pond.NewPool(16) + for update := range bot.updateQueue { + update := update + log.Println(update) + pool.Submit(func() { + bot.handle(update) + }) } } +func (bot *Bot[T]) Run() { + bot.RunWithContext(context.Background()) +} diff --git a/cmd_generator.go b/cmd_generator.go index 4e6153b..45da59d 100644 --- a/cmd_generator.go +++ b/cmd_generator.go @@ -27,7 +27,7 @@ func generateBotCommand[T any](cmd Command[T]) tgapi.BotCommand { func generateBotCommandForPlugin[T any](pl Plugin[T]) []tgapi.BotCommand { commands := make([]tgapi.BotCommand, 0) - for _, cmd := range pl.Commands { + for _, cmd := range pl.commands { if cmd.skipAutoCmd { continue } @@ -46,6 +46,10 @@ func (bot *Bot[T]) AutoGenerateCommands() error { commands := make([]tgapi.BotCommand, 0) for _, pl := range bot.plugins { + if pl.skipAutoCmd { + continue + } + commands = append(commands, generateBotCommandForPlugin(pl)...) } if len(commands) > 100 { diff --git a/go.mod b/go.mod index f6a0218..72a0376 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module git.nix13.pw/scuroneko/laniakea go 1.26 require ( - git.nix13.pw/scuroneko/extypes v1.2.0 + git.nix13.pw/scuroneko/extypes v1.2.1 git.nix13.pw/scuroneko/slog v1.0.2 golang.org/x/time v0.14.0 + github.com/alitto/pond/v2 v2.6.2 ) require ( - github.com/alitto/pond/v2 v2.6.2 // indirect github.com/fatih/color v1.18.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index ac69590..66fa4ef 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ git.nix13.pw/scuroneko/extypes v1.2.0 h1:2n2hD6KsMAted+6MGhAyeWyli2Qzc9G2y+pQNB7C1dM= git.nix13.pw/scuroneko/extypes v1.2.0/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw= +git.nix13.pw/scuroneko/extypes v1.2.1/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw= git.nix13.pw/scuroneko/slog v1.0.2 h1:vZyUROygxC2d5FJHUQM/30xFEHY1JT/aweDZXA4rm2g= git.nix13.pw/scuroneko/slog v1.0.2/go.mod h1:3Qm2wzkR5KjwOponMfG7TcGSDjmYaFqRAmLvSPTuWJI= github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw= diff --git a/handler.go b/handler.go index 7fadb3e..2d975c4 100644 --- a/handler.go +++ b/handler.go @@ -46,7 +46,7 @@ func (bot *Bot[T]) handleMessage(update *tgapi.Update, ctx *MsgContext) { text = strings.TrimSpace(text[len(prefix):]) for _, plugin := range bot.plugins { - for cmd := range plugin.Commands { + for cmd := range plugin.commands { if !strings.HasPrefix(text, cmd) { continue } @@ -96,7 +96,7 @@ func (bot *Bot[T]) handleCallback(update *tgapi.Update, ctx *MsgContext) { ctx.Args = data.Args for _, plugin := range bot.plugins { - _, ok := plugin.Payloads[data.Command] + _, ok := plugin.payloads[data.Command] if !ok { continue } diff --git a/methods.go b/methods.go index 80280b9..21f9e46 100644 --- a/methods.go +++ b/methods.go @@ -19,14 +19,8 @@ func (bot *Bot[T]) Updates() ([]tgapi.Update, error) { return nil, err } - for _, u := range updates { - bot.SetUpdateOffset(u.UpdateID + 1) - err = bot.GetQueue().Enqueue(&u) - if err != nil { - return nil, err - } - - if bot.RequestLogger != nil { + if bot.RequestLogger != nil { + for _, u := range updates { j, err := json.Marshal(u) if err != nil { bot.GetLogger().Error(err) @@ -34,5 +28,8 @@ func (bot *Bot[T]) Updates() ([]tgapi.Update, error) { bot.RequestLogger.Debugf("UPDATE %s\n", j) } } + if len(updates) > 0 { + bot.SetUpdateOffset(updates[len(updates)-1].UpdateID + 1) + } return updates, err } diff --git a/plugins.go b/plugins.go index 312c386..a30514b 100644 --- a/plugins.go +++ b/plugins.go @@ -93,37 +93,42 @@ func (c *Command[T]) validateArgs(args []string) error { } type Plugin[T DbContext] struct { - Name string - Commands map[string]Command[T] - Payloads map[string]Command[T] - Middlewares extypes.Slice[Middleware[T]] + name string + commands map[string]Command[T] + payloads map[string]Command[T] + middlewares extypes.Slice[Middleware[T]] + skipAutoCmd bool } func NewPlugin[T DbContext](name string) *Plugin[T] { return &Plugin[T]{ name, map[string]Command[T]{}, - map[string]Command[T]{}, extypes.Slice[Middleware[T]]{}, + map[string]Command[T]{}, extypes.Slice[Middleware[T]]{}, false, } } func (p *Plugin[T]) AddCommand(command *Command[T]) *Plugin[T] { - p.Commands[command.command] = *command + p.commands[command.command] = *command return p } func (p *Plugin[T]) NewCommand(exec CommandExecutor[T], command string, args ...CommandArg) *Command[T] { return NewCommand(exec, command, args...) } func (p *Plugin[T]) AddPayload(command *Command[T]) *Plugin[T] { - p.Payloads[command.command] = *command + p.payloads[command.command] = *command return p } func (p *Plugin[T]) AddMiddleware(middleware Middleware[T]) *Plugin[T] { - p.Middlewares = p.Middlewares.Push(middleware) + p.middlewares = p.middlewares.Push(middleware) + return p +} +func (p *Plugin[T]) SkipCommandAutoGen() *Plugin[T] { + p.skipAutoCmd = true return p } func (p *Plugin[T]) executeCmd(cmd string, ctx *MsgContext, dbContext *T) { - command := p.Commands[cmd] + command := p.commands[cmd] if err := command.validateArgs(ctx.Args); err != nil { ctx.error(err) return @@ -131,7 +136,7 @@ func (p *Plugin[T]) executeCmd(cmd string, ctx *MsgContext, dbContext *T) { command.exec(ctx, dbContext) } func (p *Plugin[T]) executePayload(payload string, ctx *MsgContext, dbContext *T) { - pl := p.Payloads[payload] + pl := p.payloads[payload] if err := pl.validateArgs(ctx.Args); err != nil { ctx.error(err) return @@ -139,7 +144,7 @@ func (p *Plugin[T]) executePayload(payload string, ctx *MsgContext, dbContext *T pl.exec(ctx, dbContext) } func (p *Plugin[T]) executeMiddlewares(ctx *MsgContext, db *T) bool { - for _, m := range p.Middlewares { + for _, m := range p.middlewares { if !m.Execute(ctx, db) { return false } diff --git a/utils/version.go b/utils/version.go index fb99be2..b825859 100644 --- a/utils/version.go +++ b/utils/version.go @@ -1,9 +1,9 @@ package utils const ( - VersionString = "0.8.0-beta.4" - VersionMajor = 0 - VersionMinor = 8 + VersionString = "1.0.0-beta.1" + VersionMajor = 1 + VersionMinor = 0 VersionPatch = 0 - Beta = 4 + Beta = 1 )