diff --git a/handler.go b/handler.go index 23ad408..7157e1a 100644 --- a/handler.go +++ b/handler.go @@ -151,8 +151,8 @@ func encodeBase64Payload(d CallbackData) (string, error) { if err != nil { return "", err } - dst := make([]byte, base64.StdEncoding.EncodedLen(len([]byte(data)))) - base64.StdEncoding.Encode(dst, []byte(data)) + dst := make([]byte, base64.RawURLEncoding.EncodedLen(len([]byte(data)))) + base64.RawURLEncoding.Encode(dst, []byte(data)) return string(dst), nil } @@ -166,7 +166,7 @@ func encodeBase64Payload(d CallbackData) (string, error) { // return "", ErrInvalidPayloadType // } func decodeBase64Payload(s string) (CallbackData, error) { - b, err := base64.StdEncoding.DecodeString(s) + b, err := base64.RawURLEncoding.DecodeString(s) if err != nil { return CallbackData{}, err } diff --git a/keyboard.go b/keyboard.go index 5a0e218..b1daafd 100644 --- a/keyboard.go +++ b/keyboard.go @@ -109,16 +109,32 @@ type InlineKeyboard struct { payloadType BotPayloadType // Serialization format for callback data (JSON or Base64) } -// NewInlineKeyboard creates a new keyboard builder with the specified maximum +// NewInlineKeyboardJson creates a new keyboard builder with the specified maximum // number of buttons per row. // -// Example: NewInlineKeyboard(3) creates a keyboard with at most 3 buttons per line. -func NewInlineKeyboard(maxRow int) *InlineKeyboard { +// Example: NewInlineKeyboardJson(3) creates a keyboard with at most 3 buttons per line. +func NewInlineKeyboardJson(maxRow int) *InlineKeyboard { + return NewInlineKeyboard(BotPayloadJson, maxRow) +} + +// NewInlineKeyboardBase64 creates a new keyboard builder with the specified maximum +// number of buttons per row, using Base64 encoding for button payloads. +// +// Example: NewInlineKeyboardBase64(3) creates a keyboard with at most 3 buttons per line. +func NewInlineKeyboardBase64(maxRow int) *InlineKeyboard { + return NewInlineKeyboard(BotPayloadBase64, maxRow) +} + +// NewInlineKeyboard creates a new keyboard builder with the specified payload encoding +// type and maximum number of buttons per row. +// +// Use NewInlineKeyboardJson or NewInlineKeyboardBase64 for the common cases. +func NewInlineKeyboard(payloadType BotPayloadType, maxRow int) *InlineKeyboard { return &InlineKeyboard{ CurrentLine: make(extypes.Slice[tgapi.InlineKeyboardButton], 0), Lines: make([][]tgapi.InlineKeyboardButton, 0), maxRow: maxRow, - payloadType: BotPayloadBase64, + payloadType: payloadType, } } diff --git a/msg_context.go b/msg_context.go index 519ef6a..222923f 100644 --- a/msg_context.go +++ b/msg_context.go @@ -391,6 +391,9 @@ func (ctx *MsgContext) NewDraft() *Draft { return ctx.newDraft(tgapi.ParseNone) } +// NewDraftMarkdown creates a new message draft associated with the current chat, +// with Markdown V2 parse mode enabled. +// Uses the API limiter to avoid rate limiting. func (ctx *MsgContext) NewDraftMarkdown() *Draft { return ctx.newDraft(tgapi.ParseMDV2) } @@ -404,3 +407,9 @@ func (ctx *MsgContext) Translate(key string) string { lang := Val(ctx.From.LanguageCode, ctx.l10n.GetFallbackLanguage()) return ctx.l10n.Translate(lang, key) } + +// NewInlineKeyboard creates a new keyboard builder with the context's payload +// encoding type and the specified maximum number of buttons per row. +func (ctx *MsgContext) NewInlineKeyboard(maxRow int) *InlineKeyboard { + return NewInlineKeyboard(ctx.payloadType, maxRow) +} diff --git a/tgapi/api.go b/tgapi/api.go index 5c624fd..228e417 100644 --- a/tgapi/api.go +++ b/tgapi/api.go @@ -161,27 +161,13 @@ type TelegramRequest[R, P any] struct { chatId int64 } -// NewRequest and NewRequestWithChatID are DEPRECATED. -// They encourage unsafe, untyped usage and bypass Go's type safety. -// Instead, define explicit, type-safe methods for each Telegram API endpoint. -// -// Example: -// -// func (api *API) SendMessage(ctx context.Context, chatID int64, text string) (Message, error) { ... } -// -// This provides: -// -// ✅ Compile-time validation -// ✅ IDE autocompletion -// ✅ Clear API surface -// ✅ Better error messages -// -// DO NOT use these constructors in production code. -// This can be used ONLY for testing or if you NEED method, that wasn't added as function. +// NewRequest creates an untyped TelegramRequest for the given method and params with no chat ID. func NewRequest[R, P any](method string, params P) TelegramRequest[R, P] { return TelegramRequest[R, P]{method, params, 0} } +// NewRequestWithChatID creates an untyped TelegramRequest with an associated chat ID. +// The chat ID is used for per-chat rate limiting. func NewRequestWithChatID[R, P any](method string, params P, chatId int64) TelegramRequest[R, P] { return TelegramRequest[R, P]{method, params, chatId} } @@ -191,12 +177,10 @@ func NewRequestWithChatID[R, P any](method string, params P, chatId int64) Teleg // Must be called within a worker pool context if using DoWithContext. func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, error) { var zero R - - data, err := json.Marshal(r.params) + reqData, err := json.Marshal(r.params) if err != nil { return zero, fmt.Errorf("failed to marshal request: %w", err) } - buf := bytes.NewBuffer(data) methodPrefix := "" if api.useTestServer { @@ -204,7 +188,7 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro } url := fmt.Sprintf("%s/bot%s%s/%s", api.apiUrl, api.token, methodPrefix, r.method) - req, err := http.NewRequestWithContext(ctx, "POST", url, buf) + req, err := http.NewRequestWithContext(ctx, "POST", url, nil) if err != nil { return zero, fmt.Errorf("failed to create request: %w", err) } @@ -213,7 +197,6 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) req.Header.Set("Accept-Encoding", "gzip") - req.ContentLength = int64(len(data)) for { // Apply rate limiting before making the request @@ -222,22 +205,25 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro return zero, err } } + buf := bytes.NewBuffer(reqData) + req.Body = io.NopCloser(buf) + req.ContentLength = int64(len(reqData)) - api.logger.Debugln("REQ", url, string(data)) + api.logger.Debugln("REQ", url, string(reqData)) resp, err := api.client.Do(req) if err != nil { return zero, fmt.Errorf("HTTP request failed: %w", err) } - data, err = readBody(resp.Body) + respData, err := readBody(resp.Body) _ = resp.Body.Close() // ensure body is closed if err != nil { return zero, fmt.Errorf("failed to read response body: %w", err) } - api.logger.Debugln("RES", r.method, string(data)) + api.logger.Debugln("RES", r.method, string(respData)) - response, err := parseBody[R](data) + response, err := parseBody[R](respData) if err != nil { return zero, fmt.Errorf("failed to parse response: %w", err) } @@ -249,10 +235,12 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro api.logger.Warnf("Rate limited by Telegram, retry after %d seconds (chat: %d)", after, r.chatId) // Apply cooldown to global or chat-specific limiter - if r.chatId > 0 { - api.Limiter.SetChatLock(r.chatId, after) - } else { - api.Limiter.SetGlobalLock(after) + if api.Limiter != nil { + if r.chatId > 0 { + api.Limiter.SetChatLock(r.chatId, after) + } else { + api.Limiter.SetGlobalLock(after) + } } // Wait and retry @@ -311,21 +299,13 @@ func readBody(body io.ReadCloser) ([]byte, error) { return io.ReadAll(reader) } -// parseBody unmarshals Telegram API response and returns structured result. -// Returns ErrRateLimit internally if error_code == 429 — caller must handle via response.Ok check. +// parseBody unmarshals a Telegram API response into a typed ApiResponse. +// Only returns an error on malformed JSON; non-OK responses are left for the caller to handle. func parseBody[R any](data []byte) (ApiResponse[R], error) { var resp ApiResponse[R] err := json.Unmarshal(data, &resp) if err != nil { return resp, fmt.Errorf("failed to unmarshal JSON: %w", err) } - - if !resp.Ok { - if resp.ErrorCode == 429 { - return resp, ErrRateLimit // internal use only - } - return resp, fmt.Errorf("[%d] %s", resp.ErrorCode, resp.Description) - } - return resp, nil } diff --git a/tgapi/uploader_api.go b/tgapi/uploader_api.go index 3d0a60f..efa0451 100644 --- a/tgapi/uploader_api.go +++ b/tgapi/uploader_api.go @@ -3,7 +3,6 @@ package tgapi import ( "bytes" "context" - "errors" "fmt" "mime/multipart" "net/http" @@ -24,13 +23,18 @@ const ( UploaderThumbnailType UploaderFileType = "thumbnail" ) +// UploaderFileType represents the Telegram form field name for a file upload. type UploaderFileType string + +// UploaderFile holds the data and metadata for a single file to be uploaded. type UploaderFile struct { filename string data []byte field UploaderFileType } +// NewUploaderFile creates a new UploaderFile, auto-detecting the field type from the file extension. +// If detection is incorrect, use SetType to override. func NewUploaderFile(name string, data []byte) UploaderFile { t := uploaderTypeByExt(name) return UploaderFile{filename: name, data: data, field: t} @@ -56,6 +60,8 @@ func NewUploader(api *API) *Uploader { func (u *Uploader) Close() error { return u.logger.Close() } func (u *Uploader) GetLogger() *slog.Logger { return u.logger } +// UploaderRequest is a multipart file upload request to the Telegram API. +// Use NewUploaderRequest or NewUploaderRequestWithChatID to construct one. type UploaderRequest[R, P any] struct { method string files []UploaderFile @@ -63,40 +69,30 @@ type UploaderRequest[R, P any] struct { chatId int64 } +// NewUploaderRequest creates a new multipart upload request with no associated chat ID. func NewUploaderRequest[R, P any](method string, params P, files ...UploaderFile) UploaderRequest[R, P] { return UploaderRequest[R, P]{method: method, files: files, params: params, chatId: 0} } + +// NewUploaderRequestWithChatID creates a new multipart upload request with an associated chat ID. +// The chat ID is used for per-chat rate limiting. func NewUploaderRequestWithChatID[R, P any](method string, params P, chatId int64, files ...UploaderFile) UploaderRequest[R, P] { return UploaderRequest[R, P]{method: method, files: files, params: params, chatId: chatId} } func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R, error) { var zero R - buf, contentType, err := prepareMultipart(r.files, r.params) - if err != nil { - return zero, err - } - methodPrefix := "" if up.api.useTestServer { methodPrefix = "/test" } url := fmt.Sprintf("%s/bot%s%s/%s", up.api.apiUrl, up.api.token, methodPrefix, r.method) - req, err := http.NewRequestWithContext(ctx, "POST", url, buf) - if err != nil { - return zero, err - } - req.Header.Set("Content-Type", contentType) - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) - req.Header.Set("Accept-Encoding", "gzip") - req.ContentLength = int64(buf.Len()) for { if up.api.Limiter != nil { if up.api.dropOverflowLimit { if !up.api.Limiter.GlobalAllow() { - return zero, errors.New("rate limited") + return zero, utils.ErrDropOverflow } } else { if err := up.api.Limiter.GlobalWait(ctx); err != nil { @@ -105,6 +101,20 @@ func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R, } } + buf, contentType, err := prepareMultipart(r.files, r.params) + if err != nil { + return zero, err + } + req, err := http.NewRequestWithContext(ctx, "POST", url, buf) + if err != nil { + return zero, err + } + req.Header.Set("Content-Type", contentType) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) + req.Header.Set("Accept-Encoding", "gzip") + req.ContentLength = int64(buf.Len()) + up.logger.Debugln("UPLOADER REQ", r.method) resp, err := up.api.client.Do(req) if err != nil { @@ -127,10 +137,12 @@ func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R, if response.ErrorCode == 429 && response.Parameters != nil && response.Parameters.RetryAfter != nil { after := *response.Parameters.RetryAfter up.logger.Warnf("Rate limited, retry after %d seconds (chat: %d)", after, r.chatId) - if r.chatId > 0 { - up.api.Limiter.SetChatLock(r.chatId, after) - } else { - up.api.Limiter.SetGlobalLock(after) + if up.api.Limiter != nil { + if r.chatId > 0 { + up.api.Limiter.SetChatLock(r.chatId, after) + } else { + up.api.Limiter.SetGlobalLock(after) + } } select { @@ -145,6 +157,9 @@ func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R, return response.Result, nil } } + +// DoWithContext executes the upload request asynchronously via the worker pool. +// Returns the result or error. Respects context cancellation. func (r UploaderRequest[R, P]) DoWithContext(ctx context.Context, up *Uploader) (R, error) { var zero R @@ -168,10 +183,15 @@ func (r UploaderRequest[R, P]) DoWithContext(ctx context.Context, up *Uploader) return zero, ErrPoolUnexpected } } + +// Do executes the upload request synchronously with a background context. +// Use only for simple, non-critical uploads. func (r UploaderRequest[R, P]) Do(up *Uploader) (R, error) { return r.DoWithContext(context.Background(), up) } +// prepareMultipart builds a multipart form body from the given files and params. +// Params are encoded via utils.Encode. The writer boundary is finalized before returning. func prepareMultipart[P any](files []UploaderFile, params P) (*bytes.Buffer, string, error) { buf := bytes.NewBuffer(nil) w := multipart.NewWriter(buf) @@ -204,6 +224,8 @@ func prepareMultipart[P any](files []UploaderFile, params P) (*bytes.Buffer, str return buf, w.FormDataContentType(), nil } +// uploaderTypeByExt infers the Telegram upload field name from a file extension. +// Falls back to UploaderDocumentType for unrecognized extensions. func uploaderTypeByExt(filename string) UploaderFileType { ext := filepath.Ext(filename) switch ext { diff --git a/utils/limiter.go b/utils/limiter.go index 02c7330..f15811c 100644 --- a/utils/limiter.go +++ b/utils/limiter.go @@ -22,7 +22,7 @@ type RateLimiter struct { chatLocks map[int64]time.Time // per-chat cooldown timestamps chatLimiters map[int64]*rate.Limiter // per-chat token buckets (1 req/sec) - chatMu sync.Mutex // protects chatLocks and chatLimiters + chatMu sync.RWMutex // protects chatLocks and chatLimiters } // NewRateLimiter creates a new RateLimiter with default limits. @@ -107,9 +107,9 @@ func (rl *RateLimiter) Allow(chatID int64) bool { } // Check chat cooldown - rl.chatMu.Lock() + rl.chatMu.RLock() chatUntil, ok := rl.chatLocks[chatID] - rl.chatMu.Unlock() + rl.chatMu.RUnlock() if ok && !chatUntil.IsZero() && time.Now().Before(chatUntil) { return false } @@ -135,11 +135,15 @@ func (rl *RateLimiter) Allow(chatID int64) bool { // chatID == 0 means no specific chat context (e.g., inline query, webhook without chat). func (rl *RateLimiter) Check(ctx context.Context, dropOverflow bool, chatID int64) error { if dropOverflow { - if chatID != 0 && !rl.Allow(chatID) { - return ErrDropOverflow - } - if !rl.GlobalAllow() { - return ErrDropOverflow + if chatID != 0 { + if !rl.Allow(chatID) { + + return ErrDropOverflow + } + } else { + if !rl.GlobalAllow() { + return ErrDropOverflow + } } } else if chatID != 0 { if err := rl.Wait(ctx, chatID); err != nil { @@ -175,9 +179,9 @@ func (rl *RateLimiter) waitForGlobalUnlock(ctx context.Context) error { // waitForChatUnlock blocks until the specified chat's cooldown expires or context is done. // Does not check token bucket — only cooldown. func (rl *RateLimiter) waitForChatUnlock(ctx context.Context, chatID int64) error { - rl.chatMu.Lock() + rl.chatMu.RLock() until, ok := rl.chatLocks[chatID] - rl.chatMu.Unlock() + rl.chatMu.RUnlock() if !ok || until.IsZero() || time.Now().After(until) { return nil diff --git a/utils/multipart.go b/utils/multipart.go index 0dfdac0..df6131f 100644 --- a/utils/multipart.go +++ b/utils/multipart.go @@ -49,11 +49,9 @@ func Encode[T any](w *multipart.Writer, req T) error { switch field.Kind() { case reflect.String: - if !isEmpty { - fw, err = w.CreateFormField(fieldName) - if err == nil { - _, err = fw.Write([]byte(field.String())) - } + fw, err = w.CreateFormField(fieldName) + if err == nil { + _, err = fw.Write([]byte(field.String())) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: fw, err = w.CreateFormField(fieldName) @@ -65,11 +63,17 @@ func Encode[T any](w *multipart.Writer, req T) error { if err == nil { _, err = fw.Write([]byte(strconv.FormatUint(field.Uint(), 10))) } - case reflect.Float32, reflect.Float64: + case reflect.Float32: + fw, err = w.CreateFormField(fieldName) + if err == nil { + _, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 32))) + } + case reflect.Float64: fw, err = w.CreateFormField(fieldName) if err == nil { _, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 64))) } + case reflect.Bool: fw, err = w.CreateFormField(fieldName) if err == nil { @@ -103,8 +107,12 @@ func Encode[T any](w *multipart.Writer, req T) error { _, err = fw.Write([]byte(strconv.FormatUint(elem.Uint(), 10))) case reflect.Bool: _, err = fw.Write([]byte(strconv.FormatBool(elem.Bool()))) - case reflect.Float32, reflect.Float64: + case reflect.Float32: + _, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 32))) + case reflect.Float64: _, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 64))) + default: + continue } if err != nil { break diff --git a/utils/version.go b/utils/version.go index 5e759e1..cc789d9 100644 --- a/utils/version.go +++ b/utils/version.go @@ -1,9 +1,9 @@ package utils const ( - VersionString = "1.0.0-beta.20" + VersionString = "1.0.0-beta.21" VersionMajor = 1 VersionMinor = 0 VersionPatch = 0 - VersionBeta = 20 + VersionBeta = 21 )