From 6ba8520bb7ca36130b9f3f5f25e5b994043cbf3d Mon Sep 17 00:00:00 2001 From: ScuroNeko Date: Fri, 13 Mar 2026 11:24:13 +0300 Subject: [PATCH] v1.0.0 beta 18 --- bot.go | 16 ++++++++++++---- cmd_generator.go | 9 +++++---- drafts.go | 17 ++++++++++++++++- handler.go | 11 +++++++++-- msg_context.go | 19 ++++++++++++++----- plugins.go | 31 ++++++++++++++++++++++--------- runners.go | 27 ++++++++++++++++++--------- utils/limiter.go | 4 +++- utils/version.go | 4 ++-- 9 files changed, 101 insertions(+), 37 deletions(-) diff --git a/bot.go b/bot.go index 2a0daa5..69f60a1 100644 --- a/bot.go +++ b/bot.go @@ -163,7 +163,7 @@ func LoadPrefixesFromEnv() []string { // bot := NewBot[MyDB](opts).DatabaseContext(&myDB) // // Use NoDB if no database is needed. -type DbContext interface{} +type DbContext any // NoDB is a placeholder type for bots that do not use a database. // Use Bot[NoDB] to indicate no dependency injection is required. @@ -599,12 +599,18 @@ func (bot *Bot[T]) RunWithContext(ctx context.Context) { return } - bot.ExecRunners() + bot.ExecRunners(ctx) bot.logger.Infoln("Bot running. Press CTRL+C to exit.") // Start update polling in a goroutine go func() { + defer func() { + if r := recover(); r != nil { + bot.logger.Errorln(fmt.Sprintf("panic in update polling: %v", r)) + } + close(bot.updateQueue) + }() for { select { case <-ctx.Done(): @@ -618,6 +624,7 @@ func (bot *Bot[T]) RunWithContext(ctx context.Context) { } for _, u := range updates { + u := u // copy loop variable to avoid race condition select { case bot.updateQueue <- &u: case <-ctx.Done(): @@ -631,11 +638,12 @@ func (bot *Bot[T]) RunWithContext(ctx context.Context) { // Start worker pool for concurrent update handling pool := pond.NewPool(16) for update := range bot.updateQueue { - update := update // capture loop variable + u := update // capture loop variable pool.Submit(func() { - bot.handle(update) + bot.handle(u) }) } + pool.Stop() // Wait for all tasks to complete and stop the pool } // Run starts the bot using a background context. diff --git a/cmd_generator.go b/cmd_generator.go index 3f797ad..3f2e0a8 100644 --- a/cmd_generator.go +++ b/cmd_generator.go @@ -41,6 +41,8 @@ var ErrTooManyCommands = errors.New("too many commands. max 100") // // Command{command: "start", description: "Start the bot", args: []Arg{{text: "name", required: false}}} // → Description: "Start the bot. Usage: /start [name]" +// Command{command: "echo", description: "Echo user input", args: []Arg{{text: "name", required: true}}} +// → Description: "Echo user input. Usage: /echo " func generateBotCommand[T any](cmd *Command[T]) tgapi.BotCommand { desc := "" if len(cmd.description) > 0 { @@ -50,16 +52,15 @@ func generateBotCommand[T any](cmd *Command[T]) tgapi.BotCommand { var descArgs []string for _, a := range cmd.args { if a.required { - descArgs = append(descArgs, a.text) + descArgs = append(descArgs, fmt.Sprintf("<%s>", a.text)) } else { descArgs = append(descArgs, fmt.Sprintf("[%s]", a.text)) } } + usage := fmt.Sprintf("Usage: /%s %s", cmd.command, strings.Join(descArgs, " ")) if desc != "" { - desc = fmt.Sprintf("%s. Usage: /%s %s", desc, cmd.command, strings.Join(descArgs, " ")) - } else { - desc = fmt.Sprintf("Usage: /%s %s", cmd.command, strings.Join(descArgs, " ")) + desc = fmt.Sprintf("%s. %s", desc, usage) } return tgapi.BotCommand{Command: cmd.command, Description: desc} } diff --git a/drafts.go b/drafts.go index e9d93f2..5db19fe 100644 --- a/drafts.go +++ b/drafts.go @@ -30,6 +30,7 @@ package laniakea import ( "math/rand/v2" + "sync" "sync/atomic" "git.nix13.pw/scuroneko/laniakea/tgapi" @@ -68,6 +69,7 @@ func (g *LinearDraftIdGenerator) Next() uint64 { // DraftProvider is NOT thread-safe. Concurrent access from multiple goroutines // requires external synchronization. type DraftProvider struct { + mu sync.RWMutex api *tgapi.API drafts map[uint64]*Draft generator draftIdGenerator @@ -139,6 +141,8 @@ func (p *DraftProvider) SetEntities(entities []tgapi.MessageEntity) *DraftProvid // // Returns the draft and true if found, or nil and false if not found. func (p *DraftProvider) GetDraft(id uint64) (*Draft, bool) { + p.mu.RLock() + defer p.mu.RUnlock() draft, ok := p.drafts[id] return draft, ok } @@ -150,8 +154,15 @@ func (p *DraftProvider) GetDraft(id uint64) (*Draft, bool) { // // After successful flush, each draft is removed from the provider and cleared. func (p *DraftProvider) FlushAll() error { - var lastErr error + p.mu.RLock() + drafts := make([]*Draft, 0, len(p.drafts)) for _, draft := range p.drafts { + drafts = append(drafts, draft) + } + p.mu.RUnlock() + + var lastErr error + for _, draft := range drafts { if err := draft.Flush(); err != nil { lastErr = err break // Stop on first error to avoid partial state @@ -201,7 +212,9 @@ func (p *DraftProvider) NewDraft(parseMode tgapi.ParseMode) *Draft { ID: id, Message: "", } + p.mu.Lock() p.drafts[id] = draft + p.mu.Unlock() return draft } @@ -253,7 +266,9 @@ func (d *Draft) Clear() { // want to cancel a draft without sending it. func (d *Draft) Delete() { if d.provider != nil { + d.provider.mu.Lock() delete(d.provider.drafts, d.ID) + d.provider.mu.Unlock() } d.Clear() } diff --git a/handler.go b/handler.go index a1ea281..23ad408 100644 --- a/handler.go +++ b/handler.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "strings" "git.nix13.pw/scuroneko/laniakea/tgapi" @@ -12,6 +13,12 @@ import ( var ErrInvalidPayloadType = errors.New("invalid payload type") func (bot *Bot[T]) handle(u *tgapi.Update) { + defer func() { + if r := recover(); r != nil { + bot.logger.Errorln(fmt.Sprintf("panic in handle: %v", r)) + } + }() + ctx := &MsgContext{ Update: *u, Api: bot.api, botLogger: bot.logger, @@ -84,7 +91,7 @@ func (bot *Bot[T]) handleMessage(update *tgapi.Update, ctx *MsgContext) { if !plugin.executeMiddlewares(ctx, bot.dbContext) { return } - go plugin.executeCmd(cmd, ctx, bot.dbContext) + plugin.executeCmd(cmd, ctx, bot.dbContext) return } } @@ -113,7 +120,7 @@ func (bot *Bot[T]) handleCallback(update *tgapi.Update, ctx *MsgContext) { if !plugin.executeMiddlewares(ctx, bot.dbContext) { return } - go plugin.executePayload(data.Command, ctx, bot.dbContext) + plugin.executePayload(data.Command, ctx, bot.dbContext) return } } diff --git a/msg_context.go b/msg_context.go index 276df29..5c2785e 100644 --- a/msg_context.go +++ b/msg_context.go @@ -22,6 +22,7 @@ package laniakea import ( "context" "fmt" + "time" "git.nix13.pw/scuroneko/laniakea/tgapi" "git.nix13.pw/scuroneko/slog" @@ -31,10 +32,12 @@ import ( // It provides methods to respond, edit, delete, and translate messages, as well as // manage inline keyboards and message drafts. type MsgContext struct { - Api *tgapi.API - Msg *tgapi.Message - Update tgapi.Update - From *tgapi.User + Api *tgapi.API + Update tgapi.Update + + Msg *tgapi.Message + From *tgapi.User + CallbackMsgId int CallbackQueryId string FromID int @@ -385,7 +388,13 @@ func (ctx *MsgContext) error(err error) { func (ctx *MsgContext) Error(err error) { ctx.error(err) } func (ctx *MsgContext) newDraft(parseMode tgapi.ParseMode) *Draft { - c := context.Background() + if ctx.Msg == nil { + ctx.botLogger.Errorln("can't create draft: ctx.Msg is nil") + return nil + } + + c, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() if err := ctx.Api.Limiter.Wait(c, ctx.Msg.Chat.ID); err != nil { ctx.botLogger.Errorln(err) return nil diff --git a/plugins.go b/plugins.go index 6255fe2..6f783f7 100644 --- a/plugins.go +++ b/plugins.go @@ -33,11 +33,14 @@ const ( CommandValueAnyType CommandValueType = "any" ) -// CommandRegexInt matches one or more digits. -var CommandRegexInt = regexp.MustCompile(`\d+`) - -// CommandRegexString matches any non-empty string. -var CommandRegexString = regexp.MustCompile(".+") +var ( + // CommandRegexInt matches one or more digits. + CommandRegexInt = regexp.MustCompile(`\d+`) + // CommandRegexString matches any non-empty string. + CommandRegexString = regexp.MustCompile(`.+`) + // CommandRegexBool matches true or false + CommandRegexBool = regexp.MustCompile(`true|false`) +) // ErrCmdArgCountMismatch is returned when the number of provided arguments // is less than the number of required arguments. @@ -58,15 +61,22 @@ type CommandArg struct { // NewCommandArg creates a new CommandArg with the given text and type. // Uses a default regex based on the type (string or int). // For CommandValueAnyType, no validation is performed. -func NewCommandArg(text string, valueType CommandValueType) *CommandArg { +func NewCommandArg(text string) *CommandArg { + return &CommandArg{CommandValueAnyType, text, CommandRegexString, false} +} + +func (c *CommandArg) SetValueType(t CommandValueType) *CommandArg { regex := CommandRegexString - switch valueType { + switch t { case CommandValueIntType: regex = CommandRegexInt + case CommandValueBoolType: + regex = CommandRegexBool case CommandValueAnyType: regex = nil // Skip validation } - return &CommandArg{valueType, text, regex, false} + c.regex = regex + return c } // SetRequired marks this argument as required. @@ -320,7 +330,10 @@ func (m *Middleware[T]) SetAsync(async bool) *Middleware[T] { // Otherwise, returns the result of the executor. func (m *Middleware[T]) Execute(ctx *MsgContext, db *T) bool { if m.async { - go m.executor(ctx, db) + ctx := *ctx // copy context to avoid race condition + go func(ctx MsgContext) { + m.executor(&ctx, db) + }(ctx) return true } return m.executor(ctx, db) diff --git a/runners.go b/runners.go index 0774a37..98bc359 100644 --- a/runners.go +++ b/runners.go @@ -11,6 +11,7 @@ package laniakea import ( + "context" "time" ) @@ -83,7 +84,7 @@ func (r *Runner[T]) Timeout(timeout time.Duration) *Runner[T] { return r } -// ExecRunners executes all runners registered on the Bot. +// ExecRunners executes all runners registered on the Bot with context-based lifecycle management. // // It logs warnings for misconfigured runners: // - Sync, non-onetime runners are skipped (invalid configuration). @@ -92,11 +93,13 @@ func (r *Runner[T]) Timeout(timeout time.Duration) *Runner[T] { // Execution logic: // - onetime + async: Runs once in a goroutine. // - onetime + sync: Runs once synchronously; warns if slower than 2 seconds. -// - !onetime + async: Runs in an infinite loop with timeout between iterations. +// - !onetime + async: Runs in a loop with timeout between iterations until ctx.Done(). // - !onetime + sync: Skipped with warning. // -// This method is typically called once during bot startup. -func (bot *Bot[T]) ExecRunners() { +// Background runners listen for ctx.Done() and gracefully shut down when the context is canceled. +// +// This method is typically called once during bot startup in RunWithContext. +func (bot *Bot[T]) ExecRunners(ctx context.Context) { bot.logger.Infoln("Executing runners...") for _, runner := range bot.runners { // Validate configuration @@ -128,14 +131,20 @@ func (bot *Bot[T]) ExecRunners() { bot.logger.Warnf("Runner %s too slow. Elapsed time %v >= 2s\n", runner.name, elapsed) } } else if !runner.onetime && runner.async { - // Background loop: periodic execution + // Background loop: periodic execution with graceful shutdown go func(r Runner[T]) { + ticker := time.NewTicker(r.timeout) + defer ticker.Stop() for { - err := r.fn(bot) - if err != nil { - bot.logger.Warnf("Runner %s failed: %s\n", r.name, err) + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := r.fn(bot) + if err != nil { + bot.logger.Warnf("Runner %s failed: %s\n", r.name, err) + } } - time.Sleep(r.timeout) } }(runner) } diff --git a/utils/limiter.go b/utils/limiter.go index 65ce2c6..02c7330 100644 --- a/utils/limiter.go +++ b/utils/limiter.go @@ -193,8 +193,10 @@ func (rl *RateLimiter) waitForChatUnlock(ctx context.Context, chatID int64) erro // getChatLimiter returns the rate limiter for the given chat, creating it if needed. // Uses 1 request per second with burst of 1 — conservative for per-user limits. -// Must be called with rl.chatMu held. func (rl *RateLimiter) getChatLimiter(chatID int64) *rate.Limiter { + rl.chatMu.Lock() + defer rl.chatMu.Unlock() + if lim, ok := rl.chatLimiters[chatID]; ok { return lim } diff --git a/utils/version.go b/utils/version.go index 477125d..3bf3187 100644 --- a/utils/version.go +++ b/utils/version.go @@ -1,9 +1,9 @@ package utils const ( - VersionString = "1.0.0-beta.17" + VersionString = "1.0.0-beta.18" VersionMajor = 1 VersionMinor = 0 VersionPatch = 0 - VersionBeta = 17 + VersionBeta = 18 )