diff --git a/bot.go b/bot.go index 900145d..6abec8f 100644 --- a/bot.go +++ b/bot.go @@ -11,9 +11,9 @@ import ( "git.nix13.pw/scuroneko/extypes" "git.nix13.pw/scuroneko/laniakea/tgapi" + "git.nix13.pw/scuroneko/laniakea/utils" "git.nix13.pw/scuroneko/slog" "github.com/alitto/pond/v2" - "golang.org/x/time/rate" ) type BotOpts struct { @@ -99,9 +99,9 @@ type Bot[T DbContext] struct { func NewBot[T any](opts *BotOpts) *Bot[T] { updateQueue := make(chan *tgapi.Update, 512) - var limiter *rate.Limiter + var limiter *utils.RateLimiter if opts.RateLimit > 0 { - limiter = rate.NewLimiter(rate.Limit(opts.RateLimit), opts.RateLimit) + limiter = utils.NewRateLimiter() } apiOpts := tgapi.NewAPIOpts(opts.Token).SetAPIUrl(opts.APIUrl).UseTestServer(opts.UseTestServer).SetLimiter(limiter) diff --git a/drafts.go b/drafts.go index 9b95411..fb37088 100644 --- a/drafts.go +++ b/drafts.go @@ -32,7 +32,7 @@ func (g *LinearDraftIdGenerator) Next() uint64 { type DraftProvider struct { api *tgapi.API - chatID int + chatID int64 messageThreadID int parseMode tgapi.ParseMode entities []tgapi.MessageEntity @@ -43,7 +43,7 @@ type DraftProvider struct { type Draft struct { api *tgapi.API - chatID int + chatID int64 messageThreadID int parseMode tgapi.ParseMode entities []tgapi.MessageEntity diff --git a/msg_context.go b/msg_context.go index 1d2eb19..817a630 100644 --- a/msg_context.go +++ b/msg_context.go @@ -1,6 +1,7 @@ package laniakea import ( + "context" "fmt" "git.nix13.pw/scuroneko/laniakea/tgapi" @@ -114,6 +115,11 @@ func (ctx *MsgContext) answer(text string, keyboard *InlineKeyboard) *AnswerMess params.DirectMessagesTopicID = ctx.Msg.DirectMessageTopic.TopicID } + cont := context.Background() + if err := ctx.Api.Limiter.Wait(cont, ctx.Msg.Chat.ID); err != nil { + ctx.botLogger.Errorln(err) + return nil + } msg, err := ctx.Api.SendMessage(params) if err != nil { ctx.botLogger.Errorln(err) @@ -220,6 +226,12 @@ func (ctx *MsgContext) error(err error) { func (ctx *MsgContext) Error(err error) { ctx.error(err) } func (ctx *MsgContext) NewDraft() *Draft { + c := context.Background() + if err := ctx.Api.Limiter.Wait(c, ctx.Msg.Chat.ID); err != nil { + ctx.botLogger.Errorln(err) + return nil + } + draft := ctx.draftProvider.NewDraft() draft.chatID = ctx.Msg.Chat.ID draft.messageThreadID = ctx.Msg.MessageThreadID diff --git a/tgapi/api.go b/tgapi/api.go index 8422b20..0004e6c 100644 --- a/tgapi/api.go +++ b/tgapi/api.go @@ -12,7 +12,6 @@ import ( "git.nix13.pw/scuroneko/laniakea/utils" "git.nix13.pw/scuroneko/slog" - "golang.org/x/time/rate" ) type APIOpts struct { @@ -21,7 +20,7 @@ type APIOpts struct { useTestServer bool apiUrl string - limiter *rate.Limiter + limiter *utils.RateLimiter dropOverflowLimit bool } @@ -46,7 +45,7 @@ func (opts *APIOpts) SetAPIUrl(apiUrl string) *APIOpts { } return opts } -func (opts *APIOpts) SetLimiter(limiter *rate.Limiter) *APIOpts { +func (opts *APIOpts) SetLimiter(limiter *utils.RateLimiter) *APIOpts { opts.limiter = limiter return opts } @@ -63,7 +62,7 @@ type API struct { apiUrl string pool *WorkerPool - limiter *rate.Limiter + Limiter *utils.RateLimiter dropOverflowLimit bool } @@ -88,11 +87,17 @@ func (api *API) CloseApi() error { } func (api *API) GetLogger() *slog.Logger { return api.logger } +type ResponseParameters struct { + MigrateToChatID *int64 `json:"migrate_to_chat_id,omitempty"` + RetryAfter *int `json:"retry_after,omitempty"` +} type ApiResponse[R any] struct { Ok bool `json:"ok"` Description string `json:"description,omitempty"` Result R `json:"result,omitempty"` ErrorCode int `json:"error_code,omitempty"` + + Parameters *ResponseParameters `json:"parameters,omitempty"` } type TelegramRequest[R, P any] struct { method string @@ -104,13 +109,13 @@ func NewRequest[R, P any](method string, params P) TelegramRequest[R, P] { } func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, error) { var zero R - if api.limiter != nil { + if api.Limiter != nil { if api.dropOverflowLimit { - if !api.limiter.Allow() { + if !api.Limiter.GlobalAllow() { return zero, errors.New("rate limited") } } else { - if err := api.limiter.Wait(ctx); err != nil { + if err := api.Limiter.GlobalWait(ctx); err != nil { return zero, err } } @@ -149,10 +154,23 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro return zero, err } api.logger.Debugln("RES", r.method, string(data)) - if res.StatusCode != http.StatusOK { + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusTooManyRequests { return zero, fmt.Errorf("unexpected status code: %d, %s", res.StatusCode, string(data)) } - return parseBody[R](data) + + responseData, err := parseBody[R](data) + if errors.Is(err, ErrRateLimit) { + if responseData.Parameters != nil { + after := 0 + if responseData.Parameters.RetryAfter != nil { + after = *responseData.Parameters.RetryAfter + } + api.Limiter.SetGlobalLock(after) + return r.doRequest(ctx, api) + } + return zero, ErrRateLimit + } + return responseData.Result, err } func (r TelegramRequest[R, P]) DoWithContext(ctx context.Context, api *API) (R, error) { var zero R @@ -184,15 +202,17 @@ func readBody(body io.ReadCloser) ([]byte, error) { reader := io.LimitReader(body, 10<<20) return io.ReadAll(reader) } -func parseBody[R any](data []byte) (R, error) { - var zero R +func parseBody[R any](data []byte) (ApiResponse[R], error) { var resp ApiResponse[R] err := json.Unmarshal(data, &resp) if err != nil { - return zero, err + return resp, err } if !resp.Ok { - return zero, fmt.Errorf("[%d] %s", resp.ErrorCode, resp.Description) + if resp.ErrorCode == 429 { + return resp, ErrRateLimit + } + return resp, fmt.Errorf("[%d] %s", resp.ErrorCode, resp.Description) } - return resp.Result, nil + return resp, nil } diff --git a/tgapi/attachments_methods.go b/tgapi/attachments_methods.go index 7545e33..455cd8c 100644 --- a/tgapi/attachments_methods.go +++ b/tgapi/attachments_methods.go @@ -2,7 +2,7 @@ package tgapi type SendPhotoP struct { BusinessConnectionID string `json:"business_connection_id,omitempty"` - ChatID int `json:"chat_id"` + ChatID int64 `json:"chat_id"` MessageThreadID int `json:"message_thread_id,omitempty"` DirectMessagesTopicID int `json:"direct_messages_topic_id,omitempty"` diff --git a/tgapi/chat_types.go b/tgapi/chat_types.go index 96f5ac7..42bf188 100644 --- a/tgapi/chat_types.go +++ b/tgapi/chat_types.go @@ -1,7 +1,7 @@ package tgapi type Chat struct { - ID int `json:"id"` + ID int64 `json:"id"` Type string `json:"type"` Title *string `json:"title,omitempty"` Username *string `json:"username,omitempty"` diff --git a/tgapi/errors.go b/tgapi/errors.go new file mode 100644 index 0000000..48a368b --- /dev/null +++ b/tgapi/errors.go @@ -0,0 +1,5 @@ +package tgapi + +import "errors" + +var ErrRateLimit = errors.New("rate limit exceeded") diff --git a/tgapi/messages_methods.go b/tgapi/messages_methods.go index 7ec41e5..2ef6bc2 100644 --- a/tgapi/messages_methods.go +++ b/tgapi/messages_methods.go @@ -2,7 +2,7 @@ package tgapi type SendMessageP struct { BusinessConnectionID string `json:"business_connection_id,omitempty"` - ChatID int `json:"chat_id"` + ChatID int64 `json:"chat_id"` MessageThreadID int `json:"message_thread_id,omitempty"` DirectMessagesTopicID int64 `json:"direct_messages_topic_id,omitempty"` @@ -266,7 +266,7 @@ func (api *API) SendDice(params SendDiceP) (Message, error) { } type SendMessageDraftP struct { - ChatID int `json:"chat_id"` + ChatID int64 `json:"chat_id"` MessageThreadID int `json:"message_thread_id,omitempty"` DraftID uint64 `json:"draft_id"` Text string `json:"text"` @@ -281,7 +281,7 @@ func (api *API) SendMessageDraft(params SendMessageDraftP) (bool, error) { type SendChatActionP struct { BusinessConnectionID string `json:"business_connection_id,omitempty"` - ChatID int `json:"chat_id"` + ChatID int64 `json:"chat_id"` MessageThreadID int `json:"message_thread_id,omitempty"` Action ChatActionType `json:"action"` } @@ -307,7 +307,7 @@ func (api *API) SetMessageReaction(params SetMessageReactionP) (bool, error) { type EditMessageTextP struct { BusinessConnectionID string `json:"business_connection_id,omitempty"` - ChatID int `json:"chat_id,omitempty"` + ChatID int64 `json:"chat_id,omitempty"` MessageID int `json:"message_id,omitempty"` InlineMessageID string `json:"inline_message_id,omitempty"` Text string `json:"text"` @@ -331,7 +331,7 @@ func (api *API) EditMessageText(params EditMessageTextP) (Message, bool, error) type EditMessageCaptionP struct { BusinessConnectionID string `json:"business_connection_id,omitempty"` - ChatID int `json:"chat_id,omitempty"` + ChatID int64 `json:"chat_id,omitempty"` MessageID int `json:"message_id,omitempty"` InlineMessageID string `json:"inline_message_id,omitempty"` Caption string `json:"caption"` @@ -495,8 +495,8 @@ func (api *API) DeclineSuggestedPost(params DeclineSuggestedPostP) (bool, error) } type DeleteMessageP struct { - ChatID int `json:"chat_id"` - MessageID int `json:"message_id"` + ChatID int64 `json:"chat_id"` + MessageID int `json:"message_id"` } func (api *API) DeleteMessage(params DeleteMessageP) (bool, error) { diff --git a/tgapi/uploader_api.go b/tgapi/uploader_api.go index 1641162..0b4d2a5 100644 --- a/tgapi/uploader_api.go +++ b/tgapi/uploader_api.go @@ -66,13 +66,13 @@ func NewUploaderRequest[R, P any](method string, params P, files ...UploaderFile } func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R, error) { var zero R - if up.api.limiter != nil { + if up.api.Limiter != nil { if up.api.dropOverflowLimit { - if !up.api.limiter.Allow() { + if !up.api.Limiter.GlobalAllow() { return zero, errors.New("rate limited") } } else { - if err := up.api.limiter.Wait(ctx); err != nil { + if err := up.api.Limiter.GlobalWait(ctx); err != nil { return zero, err } } @@ -109,7 +109,11 @@ func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R, return zero, fmt.Errorf("unexpected status code: %d, %s", res.StatusCode, string(body)) } - return parseBody[R](body) + respBody, err := parseBody[R](body) + if err != nil { + return zero, err + } + return respBody.Result, nil } func (r UploaderRequest[R, P]) DoWithContext(ctx context.Context, up *Uploader) (R, error) { var zero R diff --git a/tgapi/uploader_methods.go b/tgapi/uploader_methods.go index 835faf4..8a9520e 100644 --- a/tgapi/uploader_methods.go +++ b/tgapi/uploader_methods.go @@ -2,7 +2,7 @@ package tgapi type UploadPhotoP struct { BusinessConnectionID string `json:"business_connection_id,omitempty"` - ChatID int `json:"chat_id"` + ChatID int64 `json:"chat_id"` MessageThreadID int `json:"message_thread_id,omitempty"` DirectMessagesTopicID int `json:"direct_messages_topic_id,omitempty"` diff --git a/utils/limiter.go b/utils/limiter.go new file mode 100644 index 0000000..23ed999 --- /dev/null +++ b/utils/limiter.go @@ -0,0 +1,129 @@ +package utils + +import ( + "context" + "sync" + "time" + + "golang.org/x/time/rate" +) + +type RateLimiter struct { + globalLockUntil time.Time + globalLimiter *rate.Limiter + globalMu sync.RWMutex + + chatLocks map[int64]time.Time + chatLimiters map[int64]*rate.Limiter + chatMu sync.Mutex +} + +func NewRateLimiter() *RateLimiter { + return &RateLimiter{ + // 30 запросов в секунду (burst=30) + globalLimiter: rate.NewLimiter(rate.Limit(30), 30), + chatLimiters: make(map[int64]*rate.Limiter), + } +} + +func (rl *RateLimiter) SetGlobalLock(retryAfter int) { + if retryAfter <= 0 { + return + } + rl.globalMu.Lock() + defer rl.globalMu.Unlock() + rl.globalLockUntil = time.Now().Add(time.Duration(retryAfter) * time.Second) +} +func (rl *RateLimiter) SetChatLock(chatID int64, retryAfter int) { + rl.chatMu.Lock() + defer rl.chatMu.Unlock() + rl.chatLocks[chatID] = time.Now().Add(time.Duration(retryAfter) * time.Second) +} + +func (rl *RateLimiter) GlobalWait(ctx context.Context) error { + rl.globalMu.RLock() + until := rl.globalLockUntil + rl.globalMu.RUnlock() + + if !until.IsZero() { + if time.Now().Before(until) { + // Ждём до окончания блокировки или отмены контекста + select { + case <-time.After(time.Until(until)): + // блокировка снята + case <-ctx.Done(): + return ctx.Err() + } + } + } + // Теперь ждём разрешения rate limiter'а + return rl.globalLimiter.Wait(ctx) +} +func (rl *RateLimiter) Wait(ctx context.Context, chatID int64) error { + rl.chatMu.Lock() + until, ok := rl.chatLocks[chatID] + rl.chatMu.Unlock() + if ok && !until.IsZero() { + if time.Now().Before(until) { + select { + case <-time.After(time.Until(until)): + // блокировка снята + case <-ctx.Done(): + return ctx.Err() + } + } + } + + if err := rl.GlobalWait(ctx); err != nil { + return err + } + rl.chatMu.Lock() + chatLimiter, ok := rl.chatLimiters[chatID] + if !ok { + chatLimiter = rate.NewLimiter(rate.Limit(1), 1) + rl.chatLimiters[chatID] = chatLimiter + } + rl.chatMu.Unlock() + return chatLimiter.Wait(ctx) +} + +func (rl *RateLimiter) GlobalAllow() bool { + rl.globalMu.RLock() + until := rl.globalLockUntil + rl.globalMu.RUnlock() + + if !until.IsZero() { + if time.Now().Before(until) { + // Ждём до окончания блокировки или отмены контекста + select { + case <-time.After(time.Until(until)): + rl.globalLimiter.Allow() + } + } + } + return rl.globalLimiter.Allow() +} +func (rl *RateLimiter) Allow(chatID int64) bool { + rl.chatMu.Lock() + until, ok := rl.chatLocks[chatID] + rl.chatMu.Unlock() + if ok && !until.IsZero() { + if time.Now().Before(until) { + select { + case <-time.After(time.Until(until)): + } + } + } + + if !rl.globalLimiter.Allow() { + return false + } + rl.chatMu.Lock() + chatLimiter, ok := rl.chatLimiters[chatID] + if !ok { + chatLimiter = rate.NewLimiter(rate.Limit(1), 1) + rl.chatLimiters[chatID] = chatLimiter + } + rl.chatMu.Unlock() + return chatLimiter.Allow() +} diff --git a/utils/version.go b/utils/version.go index 080dc3f..bf1a8b0 100644 --- a/utils/version.go +++ b/utils/version.go @@ -1,9 +1,9 @@ package utils const ( - VersionString = "1.0.0-beta.6" + VersionString = "1.0.0-beta.7" VersionMajor = 1 VersionMinor = 0 VersionPatch = 0 - Beta = 6 + Beta = 7 )