From 7f248fff6207443aa5d968748352c364d629ceb7 Mon Sep 17 00:00:00 2001 From: ScuroNeko Date: Wed, 5 Nov 2025 11:38:09 +0300 Subject: [PATCH] fix --- bot.go | 133 +++++++++++++++++++++++++++++++++++------------------ logger.go | 84 +++++++++++++++++++++++++++++---- plugins.go | 59 ++++++++++++++++++++++-- queue.go | 26 +++++++++-- utils.go | 10 ++-- version.go | 12 +++-- 6 files changed, 253 insertions(+), 71 deletions(-) diff --git a/bot.go b/bot.go index 3b9c700..57d96e2 100644 --- a/bot.go +++ b/bot.go @@ -7,7 +7,12 @@ import ( "io" "net/http" "os" + "sort" "strings" + + "github.com/redis/go-redis/v9" + "go.mongodb.org/mongo-driver/v2/mongo" + "gorm.io/gorm" ) type ParseMode string @@ -26,8 +31,11 @@ type Bot struct { logger *Logger requestLogger *Logger - plugins []*Plugin - prefixes []string + plugins []*Plugin + middlewares []*Middleware + prefixes []string + + dbContext *DatabaseContext updateOffset int updateTypes []string @@ -45,7 +53,14 @@ type BotSettings struct { } func LoadSettingsFromEnv() *BotSettings { - + return &BotSettings{ + Token: os.Getenv("TG_TOKEN"), + Debug: os.Getenv("DEBUG") == "true", + ErrorTemplate: os.Getenv("ERROR_TEMPLATE"), + Prefixes: LoadPrefixesFromEnv(), + UpdateTypes: strings.Split(os.Getenv("UPDATE_TYPES"), ";"), + UseRequestLogger: os.Getenv("USE_REQ_LOG") == "true", + } } type MsgContext struct { @@ -58,6 +73,12 @@ type MsgContext struct { Args []string } +type DatabaseContext struct { + PostgresSQL *gorm.DB + MongoDB *mongo.Client + Redis *redis.Client +} + func NewBot(settings *BotSettings) *Bot { updateQueue := CreateQueue[*Update](256) bot := &Bot{ @@ -79,6 +100,7 @@ func NewBot(settings *BotSettings) *Bot { level = DEBUG } bot.logger = CreateLogger().Level(level).OpenFile(fmt.Sprintf("%s/main.log", strings.TrimRight(settings.LoggerBasePath, "/"))) + bot.logger = bot.logger.PrintTraceback(true) if settings.UseRequestLogger { bot.requestLogger = CreateLogger().Level(level).Prefix("REQUESTS").OpenFile(fmt.Sprintf("%s/requests.log", strings.TrimRight(settings.LoggerBasePath, "/"))) } @@ -95,6 +117,19 @@ func (b *Bot) Close() { } } +func (b *Bot) InitDatabaseContext(ctx *DatabaseContext) *Bot { + b.dbContext = ctx + return b +} +func (b *Bot) AddDatabaseLogger(writer func(db *DatabaseContext) LoggerWriter) *Bot { + w := []LoggerWriter{writer(b.dbContext)} + b.logger.AddWriters(w) + if b.requestLogger != nil { + b.requestLogger.AddWriters(w) + } + return b +} + func (b *Bot) UpdateTypes(t ...string) *Bot { b.updateTypes = make([]string, 0) b.updateTypes = append(b.updateTypes, t...) @@ -111,11 +146,11 @@ func (b *Bot) AddPrefixes(prefixes ...string) *Bot { } func LoadPrefixesFromEnv() []string { - prefixesS, exists := os.LookupEnv("PREFIXES") + prefixes, exists := os.LookupEnv("PREFIXES") if !exists { return []string{"!"} } - return strings.Split(prefixesS, ";") + return strings.Split(prefixes, ";") } func (b *Bot) ErrorTemplate(s string) *Bot { @@ -136,6 +171,23 @@ func (b *Bot) AddPlugins(plugin ...*Plugin) *Bot { return b } +func (b *Bot) AddMiddleware(middleware ...*Middleware) *Bot { + sort.Slice(middleware, func(a, b int) bool { + first := middleware[a] + second := middleware[b] + if first.Order == second.Order { + return first.Name < second.Name + } + return middleware[a].Order < middleware[b].Order + }) + + b.middlewares = append(b.middlewares, middleware...) + for _, m := range middleware { + b.logger.Debug(fmt.Sprintf("middleware with name \"%s\" was registered", m.Name)) + } + return b +} + func (b *Bot) Run() { if len(b.prefixes) == 0 { b.logger.Fatal("no prefixes defined") @@ -165,28 +217,31 @@ func (b *Bot) Run() { } u := queue.Dequeue() + ctx := &MsgContext{ + Bot: b, + Update: u, + } + for _, middleware := range b.middlewares { + middleware.Execute(ctx, b.dbContext) + } + + for _, plugin := range b.plugins { + if plugin.UpdateListener != nil { + (*plugin.UpdateListener)(ctx, b.dbContext) + } + } + if u.CallbackQuery != nil { - b.handleCallback(u) + b.handleCallback(u, ctx) } else { - b.handleMessage(u) + b.handleMessage(u, ctx) } } } // {"callback_query":{"chat_instance":"6202057960757700762","data":"aboba","from":{"first_name":"scuroneko","id":314834933,"is_bot":false,"language_code":"ru","username":"scuroneko"},"id":"1352205741990111553","message":{"chat":{"first_name":"scuroneko","id":314834933,"type":"private","username":"scuroneko"},"date":1734338107,"from":{"first_name":"Kurumi","id":7718900880,"is_bot":true,"username":"kurumi_game_bot"},"message_id":19,"reply_markup":{"inline_keyboard":[[{"callback_data":"aboba","text":"Test"},{"callback_data":"another","text":"Another"}]]},"text":"Aboba"}},"update_id":350979488} -func (b *Bot) handleMessage(update *Update) { - ctx := &MsgContext{ - Bot: b, - Update: update, - } - - for _, plugin := range b.plugins { - if plugin.UpdateListener != nil { - (*plugin.UpdateListener)(ctx) - } - } - +func (b *Bot) handleMessage(update *Update, ctx *MsgContext) { var text string if update.Message == nil { return @@ -209,7 +264,6 @@ func (b *Bot) handleMessage(update *Update) { text = strings.TrimSpace(text[len(prefix):]) for _, plugin := range b.plugins { - // Check every command for cmd := range plugin.Commands { if !strings.HasPrefix(text, cmd) { @@ -219,29 +273,18 @@ func (b *Bot) handleMessage(update *Update) { ctx.Text = strings.TrimSpace(text[len(cmd):]) ctx.Args = strings.Split(ctx.Text, " ") - go plugin.Execute(cmd, ctx) + go plugin.Execute(cmd, ctx, b.dbContext) } } } -func (b *Bot) handleCallback(update *Update) { - ctx := &MsgContext{ - Bot: b, - Update: update, - } - - for _, plugin := range b.plugins { - if plugin.UpdateListener != nil { - (*plugin.UpdateListener)(ctx) - } - } - +func (b *Bot) handleCallback(update *Update, ctx *MsgContext) { for _, plugin := range b.plugins { for payload := range plugin.Payloads { if !strings.HasPrefix(update.CallbackQuery.Data, payload) { continue } - go plugin.ExecutePayload(payload, ctx) + go plugin.ExecutePayload(payload, ctx, b.dbContext) } } } @@ -259,7 +302,7 @@ func (ctx *MsgContext) Answer(text string) { _, err := ctx.Bot.SendMessage(&SendMessageP{ ChatID: ctx.Msg.Chat.ID, Text: text, - ParseMode: "markdown", + ParseMode: ParseMD, }) if err != nil { ctx.Bot.logger.Error(err) @@ -294,21 +337,21 @@ func (b *Bot) Logger() *Logger { } type ApiResponse struct { - Ok bool `json:"ok"` - Result map[string]interface{} `json:"result,omitempty"` - Description string `json:"description,omitempty"` - ErrorCode int `json:"error_code,omitempty"` + Ok bool `json:"ok"` + Result map[string]any `json:"result,omitempty"` + Description string `json:"description,omitempty"` + ErrorCode int `json:"error_code,omitempty"` } type ApiResponseA struct { - Ok bool `json:"ok"` - Result []interface{} `json:"result,omitempty"` - Description string `json:"description,omitempty"` - ErrorCode int `json:"error_code,omitempty"` + Ok bool `json:"ok"` + Result []any `json:"result,omitempty"` + Description string `json:"description,omitempty"` + ErrorCode int `json:"error_code,omitempty"` } // request is a low-level call to api. -func (b *Bot) request(methodName string, params map[string]interface{}) (map[string]interface{}, error) { +func (b *Bot) request(methodName string, params map[string]any) (map[string]any, error) { var buf bytes.Buffer err := json.NewEncoder(&buf).Encode(params) if err != nil { @@ -333,7 +376,7 @@ func (b *Bot) request(methodName string, params map[string]interface{}) (map[str } response := new(ApiResponse) - var result map[string]interface{} + var result map[string]any err = json.Unmarshal(data, &response) if err != nil { diff --git a/logger.go b/logger.go index 2960fde..feb00bd 100644 --- a/logger.go +++ b/logger.go @@ -5,17 +5,21 @@ import ( "os" "path/filepath" "runtime" + "sort" "strings" "time" "github.com/fatih/color" ) +type LoggerWriter func(level LogLevel, prefix, traceback string, message []any) + type Logger struct { prefix string level LogLevel printTraceback bool printTime bool + writers []LoggerWriter f *os.File } @@ -26,6 +30,10 @@ type LogLevel struct { c color.Attribute } +func (l *LogLevel) GetName() string { + return l.t +} + type MethodTraceback struct { Package string Method string @@ -36,11 +44,11 @@ type MethodTraceback struct { } var ( - INFO LogLevel = LogLevel{n: 0, t: "info", c: color.FgWhite} - WARN LogLevel = LogLevel{n: 1, t: "warn", c: color.FgHiYellow} - ERROR LogLevel = LogLevel{n: 2, t: "error", c: color.FgHiRed} - FATAL LogLevel = LogLevel{n: 3, t: "fatal", c: color.FgRed} - DEBUG LogLevel = LogLevel{n: 4, t: "debug", c: color.FgGreen} + INFO = LogLevel{n: 0, t: "info", c: color.FgWhite} + WARN = LogLevel{n: 1, t: "warn", c: color.FgHiYellow} + ERROR = LogLevel{n: 2, t: "error", c: color.FgHiRed} + FATAL = LogLevel{n: 3, t: "fatal", c: color.FgRed} + DEBUG = LogLevel{n: 4, t: "debug", c: color.FgGreen} ) func CreateLogger() *Logger { @@ -76,6 +84,14 @@ func (l *Logger) PrintTraceback(b bool) *Logger { l.printTraceback = b return l } +func (l *Logger) PrintTime(b bool) *Logger { + l.printTime = b + return l +} +func (l *Logger) AddWriters(writers []LoggerWriter) *Logger { + l.writers = append(l.writers, writers...) + return l +} func (l *Logger) Info(m ...any) { l.print(INFO, m) @@ -127,6 +143,47 @@ func (l *Logger) formatTraceback(mt *MethodTraceback) string { return fmt.Sprintf("%s:%s:%d", mt.filename, mt.Method, mt.line) } +func (l *Logger) getFullTraceback(skip int) []*MethodTraceback { + pc := make([]uintptr, 15) + runtime.Callers(skip, pc) + list := make([]*MethodTraceback, 0) + frames := runtime.CallersFrames(pc) + for { + frame, more := frames.Next() + if !more { + break + } + details := runtime.FuncForPC(frame.PC) + signature := details.Name() + path, line := details.FileLine(frame.PC) + splitPath := strings.Split(path, "/") + + splitSignature := strings.Split(signature, ".") + pkg, method := splitSignature[0], splitSignature[len(splitSignature)-1] + + tb := &MethodTraceback{ + filename: splitPath[len(splitPath)-1], + fullPath: path, + line: line, + signature: signature, + Package: pkg, + Method: method, + } + list = append(list, tb) + } + sort.Slice(list, func(i, j int) bool { + return j < i + }) + return list +} +func (l *Logger) formatFullTraceback(tracebacks []*MethodTraceback) string { + formatted := make([]string, 0) + for _, tb := range tracebacks { + formatted = append(formatted, l.formatTraceback(tb)) + } + return strings.Join(formatted, "->") +} + func (l *Logger) buildString(level LogLevel, m []any) string { args := []string{ fmt.Sprintf("[%s]", l.prefix), @@ -152,11 +209,22 @@ func (l *Logger) print(level LogLevel, m []any) { if l.level.n < level.n { return } - color.New(level.c).Println(l.buildString(level, m)) + _, err := color.New(level.c).Println(l.buildString(level, m)) + if err != nil { + l.Fatal(err) + return + } + + for _, writer := range l.writers { + writer(level, l.prefix, l.formatFullTraceback(l.getFullTraceback(4)), m) + } if l.f != nil { - if _, err := l.f.Write([]byte(l.buildString(level, m) + "\n")); err != nil { - l.Fatal(err) + writeToFiles := os.Getenv("WRITE_TO_FILE") + if writeToFiles != "false" { + if _, err := l.f.Write([]byte(l.buildString(level, m) + "\n")); err != nil { + l.Fatal(err) + } } } } diff --git a/plugins.go b/plugins.go index 0d3dc00..29f5c27 100644 --- a/plugins.go +++ b/plugins.go @@ -1,6 +1,6 @@ package laniakea -type CommandExecutor func(ctx *MsgContext) +type CommandExecutor func(ctx *MsgContext, dbContext *DatabaseContext) type PluginBuilder struct { name string @@ -56,10 +56,59 @@ func (p *PluginBuilder) Build() *Plugin { return plugin } -func (p *Plugin) Execute(cmd string, ctx *MsgContext) { - (*p.Commands[cmd])(ctx) +func (p *Plugin) Execute(cmd string, ctx *MsgContext, dbContext *DatabaseContext) { + (*p.Commands[cmd])(ctx, dbContext) } -func (p *Plugin) ExecutePayload(payload string, ctx *MsgContext) { - (*p.Payloads[payload])(ctx) +func (p *Plugin) ExecutePayload(payload string, ctx *MsgContext, dbContext *DatabaseContext) { + (*p.Payloads[payload])(ctx, dbContext) +} + +type Middleware struct { + Name string + Executor *CommandExecutor + Order int + Async bool +} +type MiddlewareBuilder struct { + name string + executor *CommandExecutor + order int + async bool +} + +func NewMiddleware(name string) *MiddlewareBuilder { + return &MiddlewareBuilder{name: name, async: false} +} +func (m *MiddlewareBuilder) SetName(name string) *MiddlewareBuilder { + m.name = name + return m +} +func (m *MiddlewareBuilder) SetExecutor(executor CommandExecutor) *MiddlewareBuilder { + m.executor = &executor + return m +} +func (m *MiddlewareBuilder) SetOrder(order int) *MiddlewareBuilder { + m.order = order + return m +} +func (m *MiddlewareBuilder) SetAsync(async bool) *MiddlewareBuilder { + m.async = async + return m +} +func (m *MiddlewareBuilder) Build() *Middleware { + return &Middleware{ + Name: m.name, + Executor: m.executor, + Order: m.order, + Async: m.async, + } +} +func (m *Middleware) Execute(ctx *MsgContext, db *DatabaseContext) { + exec := *m.Executor + if m.Async { + go exec(ctx, db) + } else { + exec(ctx, db) + } } diff --git a/queue.go b/queue.go index 79ebeca..69f6020 100644 --- a/queue.go +++ b/queue.go @@ -1,12 +1,18 @@ package laniakea -import "fmt" +import ( + "errors" + "sync" +) type Queue[T any] struct { - queue []T size uint64 + mu sync.RWMutex + queue []T } +var QueueFullError = errors.New("queue full") + func CreateQueue[T any](size uint64) *Queue[T] { return &Queue[T]{ queue: make([]T, 0), @@ -16,18 +22,20 @@ func CreateQueue[T any](size uint64) *Queue[T] { func (q *Queue[T]) Enqueue(el T) error { if q.IsFull() { - return fmt.Errorf("queue full") + return QueueFullError } q.queue = append(q.queue, el) return nil } func (q *Queue[T]) Peak() T { + q.mu.RLock() + defer q.mu.RUnlock() return q.queue[0] } func (q *Queue[T]) IsEmpty() bool { - return len(q.queue) == 0 + return q.Length() == 0 } func (q *Queue[T]) IsFull() bool { @@ -35,16 +43,26 @@ func (q *Queue[T]) IsFull() bool { } func (q *Queue[T]) Length() uint64 { + q.mu.RLock() + defer q.mu.RUnlock() return uint64(len(q.queue)) } func (q *Queue[T]) Dequeue() T { + q.mu.RLock() el := q.queue[0] + q.mu.RUnlock() + if q.Length() == 1 { + q.mu.Lock() q.queue = make([]T, 0) + q.mu.Unlock() return el } + + q.mu.Lock() q.queue = q.queue[1:] + q.mu.Unlock() return el } diff --git a/utils.go b/utils.go index e8c6385..c16334f 100644 --- a/utils.go +++ b/utils.go @@ -2,26 +2,26 @@ package laniakea import "encoding/json" -func MapToStruct(m map[string]interface{}, s interface{}) error { +func MapToStruct(m map[string]any, dst interface{}) error { data, err := json.Marshal(m) if err != nil { return err } - err = json.Unmarshal(data, s) + err = json.Unmarshal(data, dst) return err } -func MapToJson(m map[string]interface{}) (string, error) { +func MapToJson(m map[string]any) (string, error) { data, err := json.Marshal(m) return string(data), err } -func StructToMap(s interface{}) (map[string]interface{}, error) { +func StructToMap(s interface{}) (map[string]any, error) { data, err := json.Marshal(s) if err != nil { return nil, err } - m := make(map[string]interface{}) + m := make(map[string]any) err = json.Unmarshal(data, &m) return m, err } diff --git a/version.go b/version.go index fe6e598..d47af22 100644 --- a/version.go +++ b/version.go @@ -1,8 +1,12 @@ package laniakea +import "os" + const ( - VERSION_STRING = "0.1.4" - VERSION_MAJOR = 0 - VERSION_MINOR = 1 - VERSION_PATCH = 4 + VersionString = "0.1.4" + VersionMajor = 0 + VersionMinor = 1 + VersionPatch = 4 ) + +var GoVersion = os.Getenv("GoV")