diff --git a/bot.go b/bot.go index bf077ba..a08ee4d 100644 --- a/bot.go +++ b/bot.go @@ -184,21 +184,34 @@ func NewBot[T any](opts *BotOpts) *Bot[T] { // // Returns the first error encountered, if any. func (bot *Bot[T]) Close() error { + var firstErr error + if err := bot.uploader.Close(); err != nil { bot.logger.Errorln(err) + if firstErr == nil { + firstErr = err + } } if err := bot.api.CloseApi(); err != nil { bot.logger.Errorln(err) + if firstErr == nil { + firstErr = err + } } if bot.RequestLogger != nil { if err := bot.RequestLogger.Close(); err != nil { bot.logger.Errorln(err) + if firstErr == nil { + firstErr = err + } } } if err := bot.logger.Close(); err != nil { - return err + if firstErr == nil { + firstErr = err + } } - return nil + return firstErr } // initLoggers configures the main and optional request loggers. diff --git a/bot_opts.go b/bot_opts.go index ef49264..ae04a19 100644 --- a/bot_opts.go +++ b/bot_opts.go @@ -85,10 +85,10 @@ func LoadOptsFromEnv() *BotOpts { } } - stringUpdateTypes := strings.Split(os.Getenv("UPDATE_TYPES"), ";") - updateTypes := make([]tgapi.UpdateType, len(stringUpdateTypes)) - for i, updateType := range stringUpdateTypes { - updateTypes[i] = tgapi.UpdateType(updateType) + stringUpdateTypes := splitEnvList(os.Getenv("UPDATE_TYPES")) + updateTypes := make([]tgapi.UpdateType, 0, len(stringUpdateTypes)) + for _, updateType := range stringUpdateTypes { + updateTypes = append(updateTypes, tgapi.UpdateType(updateType)) } return &BotOpts{ @@ -222,5 +222,25 @@ func LoadPrefixesFromEnv() []string { if !exists { return []string{"/"} } - return strings.Split(prefixesS, ";") + prefixes := splitEnvList(prefixesS) + if len(prefixes) == 0 { + return []string{"/"} + } + return prefixes +} + +func splitEnvList(value string) []string { + if value == "" { + return nil + } + parts := strings.Split(value, ";") + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + out = append(out, part) + } + return out } diff --git a/bot_opts_test.go b/bot_opts_test.go new file mode 100644 index 0000000..9f3c1be --- /dev/null +++ b/bot_opts_test.go @@ -0,0 +1,47 @@ +package laniakea + +import ( + "reflect" + "testing" + + "git.nix13.pw/scuroneko/laniakea/tgapi" +) + +func TestLoadOptsFromEnvIgnoresEmptyUpdateTypes(t *testing.T) { + t.Setenv("UPDATE_TYPES", "") + + opts := LoadOptsFromEnv() + if len(opts.UpdateTypes) != 0 { + t.Fatalf("expected no update types, got %v", opts.UpdateTypes) + } +} + +func TestLoadOptsFromEnvSplitsAndTrimsUpdateTypes(t *testing.T) { + t.Setenv("UPDATE_TYPES", "message; ; callback_query ") + + opts := LoadOptsFromEnv() + want := []tgapi.UpdateType{tgapi.UpdateTypeMessage, tgapi.UpdateTypeCallbackQuery} + if !reflect.DeepEqual(opts.UpdateTypes, want) { + t.Fatalf("unexpected update types: got %v want %v", opts.UpdateTypes, want) + } +} + +func TestLoadPrefixesFromEnvDefaultsOnEmptyValue(t *testing.T) { + t.Setenv("PREFIXES", "") + + got := LoadPrefixesFromEnv() + want := []string{"/"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected prefixes: got %v want %v", got, want) + } +} + +func TestLoadPrefixesFromEnvDropsEmptyValues(t *testing.T) { + t.Setenv("PREFIXES", "/; ; ! ") + + got := LoadPrefixesFromEnv() + want := []string{"/", "!"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected prefixes: got %v want %v", got, want) + } +} diff --git a/cmd_generator.go b/cmd_generator.go index 388dfb0..a5b0d40 100644 --- a/cmd_generator.go +++ b/cmd_generator.go @@ -100,17 +100,17 @@ func gatherCommands[T any](bot *Bot[T]) []tgapi.BotCommand { // log.Fatal(err) // } func (bot *Bot[T]) AutoGenerateCommands() error { + commands := gatherCommands(bot) + if len(commands) > 100 { + return ErrTooManyCommands + } + // Clear existing commands to avoid duplication or stale entries _, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{}) if err != nil { return fmt.Errorf("failed to delete existing commands: %w", err) } - commands := gatherCommands(bot) - if len(commands) > 100 { - return ErrTooManyCommands - } - // Register commands for each scope scopes := []*tgapi.BotCommandScope{ {Type: tgapi.BotCommandScopePrivateType}, @@ -148,15 +148,16 @@ func (bot *Bot[T]) AutoGenerateCommands() error { // log.Fatal(err) // } func (bot *Bot[T]) AutoGenerateCommandsForScope(scope *tgapi.BotCommandScope) error { - _, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{Scope: scope}) - if err != nil { - return fmt.Errorf("failed to delete existing commands: %w", err) - } commands := gatherCommands(bot) if len(commands) > 100 { return ErrTooManyCommands } + _, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{Scope: scope}) + if err != nil { + return fmt.Errorf("failed to delete existing commands: %w", err) + } + _, err = bot.api.SetMyCommands(tgapi.SetMyCommandsP{ Commands: commands, Scope: scope, diff --git a/cmd_generator_test.go b/cmd_generator_test.go new file mode 100644 index 0000000..2931170 --- /dev/null +++ b/cmd_generator_test.go @@ -0,0 +1,64 @@ +package laniakea + +import ( + "errors" + "io" + "net/http" + "strconv" + "strings" + "sync/atomic" + "testing" + + "git.nix13.pw/scuroneko/laniakea/tgapi" + "git.nix13.pw/scuroneko/slog" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestAutoGenerateCommandsChecksLimitBeforeDelete(t *testing.T) { + var calls atomic.Int64 + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls.Add(1) + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":true}`)), + }, nil + }), + } + api := tgapi.NewAPI( + tgapi.NewAPIOpts("token"). + SetAPIUrl("https://example.test"). + SetHTTPClient(client), + ) + defer func() { + if err := api.CloseApi(); err != nil { + t.Fatalf("CloseApi returned error: %v", err) + } + }() + + plugin := NewPlugin[NoDB]("overflow") + exec := func(ctx *MsgContext, db *NoDB) {} + for i := 0; i < 101; i++ { + plugin.AddCommand(NewCommand(exec, "cmd"+strconv.Itoa(i))) + } + + bot := &Bot[NoDB]{ + api: api, + logger: slog.CreateLogger(), + plugins: []Plugin[NoDB]{*plugin}, + } + + err := bot.AutoGenerateCommands() + if !errors.Is(err, ErrTooManyCommands) { + t.Fatalf("expected ErrTooManyCommands, got %v", err) + } + if calls.Load() != 0 { + t.Fatalf("expected no HTTP calls before limit validation, got %d", calls.Load()) + } +} diff --git a/go.mod b/go.mod index db3117a..b6c1042 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module git.nix13.pw/scuroneko/laniakea go 1.26 require ( - git.nix13.pw/scuroneko/extypes v1.2.1 - git.nix13.pw/scuroneko/slog v1.0.2 + git.nix13.pw/scuroneko/extypes v1.2.2 + git.nix13.pw/scuroneko/slog v1.1.2 github.com/alitto/pond/v2 v2.7.0 golang.org/x/time v0.15.0 ) diff --git a/go.sum b/go.sum index a1e3c21..72b92c1 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -git.nix13.pw/scuroneko/extypes v1.2.1 h1:IYrOjnWKL2EAuJYtYNa+luB1vBe6paE8VY/YD+5/RpQ= -git.nix13.pw/scuroneko/extypes v1.2.1/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw= -git.nix13.pw/scuroneko/slog v1.0.2 h1:vZyUROygxC2d5FJHUQM/30xFEHY1JT/aweDZXA4rm2g= -git.nix13.pw/scuroneko/slog v1.0.2/go.mod h1:3Qm2wzkR5KjwOponMfG7TcGSDjmYaFqRAmLvSPTuWJI= +git.nix13.pw/scuroneko/extypes v1.2.2 h1:N54c1ejrPs1yfIkvYuwqI7B1+8S9mDv2GqQA6sct4dk= +git.nix13.pw/scuroneko/extypes v1.2.2/go.mod h1:b4XYk1OW1dVSiE2MT/OMuX/K/UItf1swytX6eroVYnk= +git.nix13.pw/scuroneko/slog v1.1.2 h1:pl7tV5FN25Yso7sLYoOgBXi9+jLo5BDJHWmHlNPjpY0= +git.nix13.pw/scuroneko/slog v1.1.2/go.mod h1:UcfRIHDqpVQHahBGM93awLDK8//AsAvOqBwwbWqMkjM= github.com/alitto/pond/v2 v2.7.0 h1:c76L+yN916m/DRXjGCeUBHHu92uWnh/g1bwVk4zyyXg= github.com/alitto/pond/v2 v2.7.0/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= diff --git a/handler.go b/handler.go index 1afa576..b07228d 100644 --- a/handler.go +++ b/handler.go @@ -138,6 +138,9 @@ func (bot *Bot[T]) handleCallback(update *tgapi.Update, ctx *MsgContext) { func (bot *Bot[T]) checkPrefixes(text string) (string, bool) { for _, prefix := range bot.prefixes { + if prefix == "" { + continue + } if strings.HasPrefix(text, prefix) { return prefix, true } diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..789dae5 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,14 @@ +package laniakea + +import "testing" + +func TestCheckPrefixesSkipsEmptyPrefixes(t *testing.T) { + bot := &Bot[NoDB]{prefixes: []string{"", "/"}} + + if prefix, ok := bot.checkPrefixes("hello"); ok { + t.Fatalf("unexpected prefix match for plain text: %q", prefix) + } + if prefix, ok := bot.checkPrefixes("/start"); !ok || prefix != "/" { + t.Fatalf("unexpected prefix result: prefix=%q ok=%v", prefix, ok) + } +} diff --git a/msg_context.go b/msg_context.go index b287d95..725b59e 100644 --- a/msg_context.go +++ b/msg_context.go @@ -270,6 +270,9 @@ func (ctx *MsgContext) answerPhoto(photoId, text string, kb *InlineKeyboard, par if ctx.Msg.MessageThreadID > 0 { params.MessageThreadID = ctx.Msg.MessageThreadID } + if ctx.Msg.DirectMessageTopic != nil { + params.DirectMessagesTopicID = int(ctx.Msg.DirectMessageTopic.TopicID) + } msg, err := ctx.Api.SendPhoto(params) if err != nil { diff --git a/msg_context_test.go b/msg_context_test.go new file mode 100644 index 0000000..f6cb124 --- /dev/null +++ b/msg_context_test.go @@ -0,0 +1,64 @@ +package laniakea + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "git.nix13.pw/scuroneko/laniakea/tgapi" + "git.nix13.pw/scuroneko/slog" +) + +func TestAnswerPhotoIncludesDirectMessagesTopicID(t *testing.T) { + var gotBody map[string]any + + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + if err := json.Unmarshal(body, &gotBody); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"message_id":9,"date":1}}`)), + }, nil + }), + } + + api := tgapi.NewAPI( + tgapi.NewAPIOpts("token"). + SetAPIUrl("https://example.test"). + SetHTTPClient(client), + ) + defer func() { + if err := api.CloseApi(); err != nil { + t.Fatalf("CloseApi returned error: %v", err) + } + }() + + ctx := &MsgContext{ + Api: api, + Msg: &tgapi.Message{ + Chat: &tgapi.Chat{ID: 42, Type: string(tgapi.ChatTypePrivate)}, + DirectMessageTopic: &tgapi.DirectMessageTopic{TopicID: 77}, + }, + botLogger: slog.CreateLogger(), + } + + answer := ctx.AnswerPhoto("photo-id", "caption") + if answer == nil { + t.Fatal("expected answer message") + } + if answer.MessageID != 9 { + t.Fatalf("unexpected message id: %d", answer.MessageID) + } + if got := gotBody["direct_messages_topic_id"]; got != float64(77) { + t.Fatalf("unexpected direct_messages_topic_id: %v", got) + } +} diff --git a/plugins.go b/plugins.go index 091e449..134c900 100644 --- a/plugins.go +++ b/plugins.go @@ -23,11 +23,11 @@ const ( var ( // CommandRegexInt matches one or more digits. - CommandRegexInt = regexp.MustCompile(`\d+`) + CommandRegexInt = regexp.MustCompile(`^\d+$`) // CommandRegexString matches any non-empty string. - CommandRegexString = regexp.MustCompile(`.+`) + CommandRegexString = regexp.MustCompile(`^.+$`) // CommandRegexBool matches true or false. - CommandRegexBool = regexp.MustCompile(`true|false`) + CommandRegexBool = regexp.MustCompile(`^(true|false)$`) ) // ErrCmdArgCountMismatch is returned when the number of provided arguments @@ -64,6 +64,7 @@ func (c *CommandArg) SetValueType(t CommandValueType) *CommandArg { case CommandValueAnyType: regex = nil // Skip validation } + c.valueType = t c.regex = regex return c } diff --git a/plugins_test.go b/plugins_test.go new file mode 100644 index 0000000..a8fd856 --- /dev/null +++ b/plugins_test.go @@ -0,0 +1,24 @@ +package laniakea + +import ( + "errors" + "testing" +) + +func TestValidateArgsRequiresFullMatch(t *testing.T) { + intCmd := NewCommand[NoDB](func(ctx *MsgContext, db *NoDB) {}, "int", *NewCommandArg("n").SetValueType(CommandValueIntType).SetRequired()) + if err := intCmd.validateArgs([]string{"123"}); err != nil { + t.Fatalf("expected valid integer argument, got %v", err) + } + if err := intCmd.validateArgs([]string{"123abc"}); !errors.Is(err, ErrCmdArgRegexpMismatch) { + t.Fatalf("expected ErrCmdArgRegexpMismatch for partial int match, got %v", err) + } + + boolCmd := NewCommand[NoDB](func(ctx *MsgContext, db *NoDB) {}, "bool", *NewCommandArg("flag").SetValueType(CommandValueBoolType).SetRequired()) + if err := boolCmd.validateArgs([]string{"false"}); err != nil { + t.Fatalf("expected valid bool argument, got %v", err) + } + if err := boolCmd.validateArgs([]string{"falsey"}); !errors.Is(err, ErrCmdArgRegexpMismatch) { + t.Fatalf("expected ErrCmdArgRegexpMismatch for partial bool match, got %v", err) + } +} diff --git a/tgapi/api.go b/tgapi/api.go index 98e2c50..30a16c6 100644 --- a/tgapi/api.go +++ b/tgapi/api.go @@ -202,7 +202,6 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) - req.Header.Set("Accept-Encoding", "gzip") for { // Apply rate limiting before making the request diff --git a/tgapi/api_test.go b/tgapi/api_test.go new file mode 100644 index 0000000..2f3db3d --- /dev/null +++ b/tgapi/api_test.go @@ -0,0 +1,56 @@ +package tgapi + +import ( + "io" + "net/http" + "strings" + "testing" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestAPILeavesAcceptEncodingToHTTPTransport(t *testing.T) { + var gotPath string + var gotAcceptEncoding string + + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + gotPath = req.URL.Path + gotAcceptEncoding = req.Header.Get("Accept-Encoding") + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"Test"}}`)), + }, nil + }), + } + + api := NewAPI( + NewAPIOpts("token"). + SetAPIUrl("https://example.test"). + SetHTTPClient(client), + ) + defer func() { + if err := api.CloseApi(); err != nil { + t.Fatalf("CloseApi returned error: %v", err) + } + }() + + user, err := api.GetMe() + if err != nil { + t.Fatalf("GetMe returned error: %v", err) + } + if user.FirstName != "Test" { + t.Fatalf("unexpected first name: %q", user.FirstName) + } + if gotPath != "/bottoken/getMe" { + t.Fatalf("unexpected request path: %s", gotPath) + } + if gotAcceptEncoding != "" { + t.Fatalf("expected empty Accept-Encoding header, got %q", gotAcceptEncoding) + } +} diff --git a/tgapi/business_types.go b/tgapi/business_types.go index ca3c37a..56a5f9f 100644 --- a/tgapi/business_types.go +++ b/tgapi/business_types.go @@ -60,6 +60,14 @@ type BusinessConnection struct { IsEnabled bool `json:"is_enabled"` } +// BusinessMessagesDeleted is received when messages are deleted from a connected business account. +// See https://core.telegram.org/bots/api#businessmessagesdeleted +type BusinessMessagesDeleted struct { + BusinessConnectionID string `json:"business_connection_id"` + Chat Chat `json:"chat"` + MessageIDs []int `json:"message_ids"` +} + // InputStoryContentType indicates the type of input story content. type InputStoryContentType string diff --git a/tgapi/messages_types.go b/tgapi/messages_types.go index 642c66b..6ef5fe9 100644 --- a/tgapi/messages_types.go +++ b/tgapi/messages_types.go @@ -144,12 +144,12 @@ type LinkPreviewOptions struct { type ReplyMarkup struct { InlineKeyboard [][]InlineKeyboardButton `json:"inline_keyboard,omitempty"` - Keyboard [][]int `json:"keyboard,omitempty"` - IsPersistent bool `json:"is_persistent,omitempty"` - ResizeKeyboard bool `json:"resize_keyboard,omitempty"` - OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"` - InputFieldPlaceholder string `json:"input_field_placeholder,omitempty"` - Selective bool `json:"selective,omitempty"` + Keyboard [][]KeyboardButton `json:"keyboard,omitempty"` + IsPersistent bool `json:"is_persistent,omitempty"` + ResizeKeyboard bool `json:"resize_keyboard,omitempty"` + OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"` + InputFieldPlaceholder string `json:"input_field_placeholder,omitempty"` + Selective bool `json:"selective,omitempty"` RemoveKeyboard bool `json:"remove_keyboard,omitempty"` @@ -165,6 +165,60 @@ type InlineKeyboardMarkup struct { // KeyboardButtonStyle represents the style of a keyboard button. type KeyboardButtonStyle string +const ( + KeyboardButtonStyleDanger KeyboardButtonStyle = "danger" + KeyboardButtonStyleSuccess KeyboardButtonStyle = "success" + KeyboardButtonStylePrimary KeyboardButtonStyle = "primary" +) + +// KeyboardButton represents one button of the reply keyboard. +// See https://core.telegram.org/bots/api#keyboardbutton +type KeyboardButton struct { + Text string `json:"text"` + IconCustomEmojiID string `json:"icon_custom_emoji_id,omitempty"` + Style KeyboardButtonStyle `json:"style,omitempty"` + RequestUsers *KeyboardButtonRequestUsers `json:"request_users,omitempty"` + RequestChat *KeyboardButtonRequestChat `json:"request_chat,omitempty"` + RequestContact bool `json:"request_contact,omitempty"` + RequestLocation bool `json:"request_location,omitempty"` + RequestPoll *KeyboardButtonPollType `json:"request_poll,omitempty"` + WebApp *WebAppInfo `json:"web_app,omitempty"` +} + +// KeyboardButtonRequestUsers defines criteria used to request suitable users. +// See https://core.telegram.org/bots/api#keyboardbuttonrequestusers +type KeyboardButtonRequestUsers struct { + RequestID int `json:"request_id"` + UserIsBot *bool `json:"user_is_bot,omitempty"` + UserIsPremium *bool `json:"user_is_premium,omitempty"` + MaxQuantity int `json:"max_quantity,omitempty"` + RequestName bool `json:"request_name,omitempty"` + RequestUsername bool `json:"request_username,omitempty"` + RequestPhoto bool `json:"request_photo,omitempty"` +} + +// KeyboardButtonRequestChat defines criteria used to request a suitable chat. +// See https://core.telegram.org/bots/api#keyboardbuttonrequestchat +type KeyboardButtonRequestChat struct { + RequestID int `json:"request_id"` + ChatIsChannel bool `json:"chat_is_channel"` + ChatIsForum *bool `json:"chat_is_forum,omitempty"` + ChatHasUsername *bool `json:"chat_has_username,omitempty"` + ChatIsCreated *bool `json:"chat_is_created,omitempty"` + UserAdministratorRights *ChatAdministratorRights `json:"user_administrator_rights,omitempty"` + BotAdministratorRights *ChatAdministratorRights `json:"bot_administrator_rights,omitempty"` + BotIsMember bool `json:"bot_is_member,omitempty"` + RequestTitle bool `json:"request_title,omitempty"` + RequestUsername bool `json:"request_username,omitempty"` + RequestPhoto bool `json:"request_photo,omitempty"` +} + +// KeyboardButtonPollType represents the type of a poll that may be created from a keyboard button. +// See https://core.telegram.org/bots/api#keyboardbuttonpolltype +type KeyboardButtonPollType struct { + Type PollType `json:"type,omitempty"` +} + // InlineKeyboardButton represents one button of an inline keyboard. // See https://core.telegram.org/bots/api#inlinekeyboardbutton type InlineKeyboardButton struct { @@ -178,7 +232,12 @@ type InlineKeyboardButton struct { // ReplyKeyboardMarkup represents a custom keyboard with reply options. // See https://core.telegram.org/bots/api#replykeyboardmarkup type ReplyKeyboardMarkup struct { - Keyboard [][]int `json:"keyboard"` + Keyboard [][]KeyboardButton `json:"keyboard"` + IsPersistent bool `json:"is_persistent,omitempty"` + ResizeKeyboard bool `json:"resize_keyboard,omitempty"` + OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"` + InputFieldPlaceholder string `json:"input_field_placeholder,omitempty"` + Selective bool `json:"selective,omitempty"` } // CallbackQuery represents an incoming callback query from a callback button in an inline keyboard. @@ -238,7 +297,8 @@ const ( ChatActionUploadDocument ChatActionType = "upload_document" ChatActionChooseSticker ChatActionType = "choose_sticker" ChatActionFindLocation ChatActionType = "find_location" - ChatActionUploadVideoNone ChatActionType = "upload_video_none" + ChatActionUploadVideoNote ChatActionType = "upload_video_note" + ChatActionUploadVideoNone ChatActionType = ChatActionUploadVideoNote ) // MessageReactionUpdated represents a change of a reaction on a message. diff --git a/tgapi/messages_types_test.go b/tgapi/messages_types_test.go new file mode 100644 index 0000000..25c304a --- /dev/null +++ b/tgapi/messages_types_test.go @@ -0,0 +1,37 @@ +package tgapi + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestReplyKeyboardMarkupMarshalsKeyboardButtons(t *testing.T) { + markup := ReplyKeyboardMarkup{ + Keyboard: [][]KeyboardButton{{ + { + Text: "Create poll", + RequestPoll: &KeyboardButtonPollType{Type: PollTypeQuiz}, + }, + }}, + } + + data, err := json.Marshal(markup) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + + got := string(data) + if !strings.Contains(got, `"keyboard":[[{"text":"Create poll","request_poll":{"type":"quiz"}}]]`) { + t.Fatalf("unexpected reply keyboard JSON: %s", got) + } +} + +func TestChatActionUploadVideoNoteValue(t *testing.T) { + if ChatActionUploadVideoNote != "upload_video_note" { + t.Fatalf("unexpected chat action value: %q", ChatActionUploadVideoNote) + } + if ChatActionUploadVideoNone != ChatActionUploadVideoNote { + t.Fatalf("expected deprecated alias to match upload_video_note, got %q", ChatActionUploadVideoNone) + } +} diff --git a/tgapi/methods.go b/tgapi/methods.go index 37bc843..fb3edcb 100644 --- a/tgapi/methods.go +++ b/tgapi/methods.go @@ -1,9 +1,12 @@ package tgapi import ( + "context" "fmt" "io" "net/http" + + "git.nix13.pw/scuroneko/laniakea/utils" ) // UpdateParams holds parameters for the getUpdates method. @@ -12,7 +15,7 @@ type UpdateParams struct { Offset *int `json:"offset,omitempty"` Limit *int `json:"limit,omitempty"` Timeout *int `json:"timeout,omitempty"` - AllowedUpdates []UpdateType `json:"allowed_updates"` + AllowedUpdates []UpdateType `json:"allowed_updates,omitempty"` } // GetMe returns basic information about the bot. @@ -103,13 +106,31 @@ func (api *API) GetFile(params GetFileP) (File, error) { // The link is usually obtained from File.FilePath. // See https://core.telegram.org/bots/api#file func (api *API) GetFileByLink(link string) ([]byte, error) { - u := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", api.token, link) - res, err := http.Get(u) + methodPrefix := "" + if api.useTestServer { + methodPrefix = "/test" + } + u := fmt.Sprintf("%s/file/bot%s%s/%s", api.apiUrl, api.token, methodPrefix, link) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, u, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) + + res, err := api.client.Do(req) if err != nil { return nil, err } defer func() { _ = res.Body.Close() }() + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices { + body, readErr := io.ReadAll(io.LimitReader(res.Body, 4<<10)) + if readErr != nil { + return nil, fmt.Errorf("unexpected status %d", res.StatusCode) + } + return nil, fmt.Errorf("unexpected status %d: %s", res.StatusCode, string(body)) + } return io.ReadAll(res.Body) } diff --git a/tgapi/methods_test.go b/tgapi/methods_test.go new file mode 100644 index 0000000..649ff50 --- /dev/null +++ b/tgapi/methods_test.go @@ -0,0 +1,115 @@ +package tgapi + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "testing" +) + +func TestGetFileByLinkUsesConfiguredAPIURL(t *testing.T) { + var gotPath string + + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + gotPath = req.URL.Path + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("payload")), + }, nil + }), + } + + api := NewAPI( + NewAPIOpts("token"). + SetAPIUrl("https://example.test"). + SetHTTPClient(client), + ) + defer func() { + if err := api.CloseApi(); err != nil { + t.Fatalf("CloseApi returned error: %v", err) + } + }() + + data, err := api.GetFileByLink("files/report.txt") + if err != nil { + t.Fatalf("GetFileByLink returned error: %v", err) + } + if string(data) != "payload" { + t.Fatalf("unexpected payload: %q", string(data)) + } + if gotPath != "/file/bottoken/files/report.txt" { + t.Fatalf("unexpected request path: %s", gotPath) + } +} + +func TestGetFileByLinkReturnsHTTPStatusError(t *testing.T) { + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("missing\n")), + }, nil + }), + } + + api := NewAPI( + NewAPIOpts("token"). + SetAPIUrl("https://example.test"). + SetHTTPClient(client), + ) + defer func() { + if err := api.CloseApi(); err != nil { + t.Fatalf("CloseApi returned error: %v", err) + } + }() + + _, err := api.GetFileByLink("files/report.txt") + if err == nil { + t.Fatal("expected error for non-2xx response") + } +} + +func TestGetUpdatesOmitsAllowedUpdatesWhenEmpty(t *testing.T) { + var gotBody map[string]any + + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + if err := json.Unmarshal(body, &gotBody); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":[]}`)), + }, nil + }), + } + + api := NewAPI( + NewAPIOpts("token"). + SetAPIUrl("https://example.test"). + SetHTTPClient(client), + ) + defer func() { + if err := api.CloseApi(); err != nil { + t.Fatalf("CloseApi returned error: %v", err) + } + }() + + updates, err := api.GetUpdates(UpdateParams{}) + if err != nil { + t.Fatalf("GetUpdates returned error: %v", err) + } + if len(updates) != 0 { + t.Fatalf("expected no updates, got %d", len(updates)) + } + if _, exists := gotBody["allowed_updates"]; exists { + t.Fatalf("expected allowed_updates to be omitted, got %v", gotBody["allowed_updates"]) + } +} diff --git a/tgapi/types.go b/tgapi/types.go index ed3c3d7..d96632b 100644 --- a/tgapi/types.go +++ b/tgapi/types.go @@ -1,5 +1,7 @@ package tgapi +import "encoding/json" + // UpdateType represents the type of incoming update. type UpdateType string @@ -23,8 +25,10 @@ const ( UpdateTypeBusinessMessage UpdateType = "business_message" // UpdateTypeEditedBusinessMessage is an edited business message update. UpdateTypeEditedBusinessMessage UpdateType = "edited_business_message" - // UpdateTypeDeletedBusinessMessage is a deleted business message update. - UpdateTypeDeletedBusinessMessage UpdateType = "deleted_business_message" + // UpdateTypeDeletedBusinessMessages is a deleted business messages update. + UpdateTypeDeletedBusinessMessages UpdateType = "deleted_business_messages" + // UpdateTypeDeletedBusinessMessage is kept as a backward-compatible alias. + UpdateTypeDeletedBusinessMessage UpdateType = UpdateTypeDeletedBusinessMessages // UpdateTypeInlineQuery is an inline query update. UpdateTypeInlineQuery UpdateType = "inline_query" @@ -63,17 +67,18 @@ type Update struct { ChannelPost *Message `json:"channel_post,omitempty"` EditedChannelPost *Message `json:"edited_channel_post,omitempty"` - BusinessConnection *BusinessConnection `json:"business_connection,omitempty"` - BusinessMessage *Message `json:"business_message,omitempty"` - EditedBusinessMessage *Message `json:"edited_business_message,omitempty"` - DeletedBusinessMessage *Message `json:"deleted_business_messages,omitempty"` - MessageReaction *MessageReactionUpdated `json:"message_reaction,omitempty"` - MessageReactionCount *MessageReactionCountUpdated `json:"message_reaction_count,omitempty"` + BusinessConnection *BusinessConnection `json:"business_connection,omitempty"` + BusinessMessage *Message `json:"business_message,omitempty"` + EditedBusinessMessage *Message `json:"edited_business_message,omitempty"` + DeletedBusinessMessages *BusinessMessagesDeleted `json:"deleted_business_messages,omitempty"` + DeletedBusinessMessage *BusinessMessagesDeleted `json:"-"` + MessageReaction *MessageReactionUpdated `json:"message_reaction,omitempty"` + MessageReactionCount *MessageReactionCountUpdated `json:"message_reaction_count,omitempty"` InlineQuery *InlineQuery `json:"inline_query,omitempty"` ChosenInlineResult *ChosenInlineResult `json:"chosen_inline_result,omitempty"` CallbackQuery *CallbackQuery `json:"callback_query,omitempty"` - ShippingQuery ShippingQuery `json:"shipping_query,omitempty"` + ShippingQuery *ShippingQuery `json:"shipping_query,omitempty"` PreCheckoutQuery *PreCheckoutQuery `json:"pre_checkout_query,omitempty"` PurchasedPaidMedia *PaidMediaPurchased `json:"purchased_paid_media,omitempty"` @@ -86,6 +91,35 @@ type Update struct { RemovedChatBoost *ChatBoostRemoved `json:"removed_chat_boost,omitempty"` } +func (u *Update) syncDeletedBusinessMessages() { + if u.DeletedBusinessMessages != nil { + u.DeletedBusinessMessage = u.DeletedBusinessMessages + return + } + if u.DeletedBusinessMessage != nil { + u.DeletedBusinessMessages = u.DeletedBusinessMessage + } +} + +// UnmarshalJSON keeps the deprecated DeletedBusinessMessage alias in sync. +func (u *Update) UnmarshalJSON(data []byte) error { + type alias Update + var aux alias + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *u = Update(aux) + u.syncDeletedBusinessMessages() + return nil +} + +// MarshalJSON emits the canonical deleted_business_messages field. +func (u Update) MarshalJSON() ([]byte, error) { + u.syncDeletedBusinessMessages() + type alias Update + return json.Marshal(alias(u)) +} + // InlineQuery represents an incoming inline query. // See https://core.telegram.org/bots/api#inlinequery type InlineQuery struct { diff --git a/tgapi/types_test.go b/tgapi/types_test.go new file mode 100644 index 0000000..5b18efc --- /dev/null +++ b/tgapi/types_test.go @@ -0,0 +1,69 @@ +package tgapi + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestUpdateDeletedBusinessMessagesUnmarshalSetsAlias(t *testing.T) { + var update Update + err := json.Unmarshal([]byte(`{ + "update_id": 1, + "deleted_business_messages": { + "business_connection_id": "conn", + "chat": {"id": 42, "type": "private"}, + "message_ids": [3, 5] + } + }`), &update) + if err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + + if update.DeletedBusinessMessages == nil { + t.Fatal("expected DeletedBusinessMessages to be populated") + } + if update.DeletedBusinessMessage == nil { + t.Fatal("expected deprecated DeletedBusinessMessage alias to be populated") + } + if update.DeletedBusinessMessages != update.DeletedBusinessMessage { + t.Fatal("expected deleted business message fields to share the same payload") + } + if got := update.DeletedBusinessMessages.MessageIDs; len(got) != 2 || got[0] != 3 || got[1] != 5 { + t.Fatalf("unexpected message ids: %v", got) + } +} + +func TestUpdateMarshalUsesCanonicalDeletedBusinessMessagesField(t *testing.T) { + update := Update{ + UpdateID: 1, + DeletedBusinessMessage: &BusinessMessagesDeleted{ + BusinessConnectionID: "conn", + Chat: Chat{ID: 42, Type: string(ChatTypePrivate)}, + MessageIDs: []int{7}, + }, + } + + data, err := json.Marshal(update) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + + got := string(data) + if !strings.Contains(got, `"deleted_business_messages"`) { + t.Fatalf("expected canonical deleted_business_messages field, got %s", got) + } + if strings.Contains(got, `"deleted_business_message"`) { + t.Fatalf("unexpected singular deleted_business_message field, got %s", got) + } +} + +func TestUpdateShippingQueryIsNilWhenAbsent(t *testing.T) { + var update Update + if err := json.Unmarshal([]byte(`{"update_id":1}`), &update); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if update.ShippingQuery != nil { + t.Fatalf("expected ShippingQuery to be nil, got %+v", update.ShippingQuery) + } +} diff --git a/tgapi/uploader_api.go b/tgapi/uploader_api.go index 24cf8c5..a950030 100644 --- a/tgapi/uploader_api.go +++ b/tgapi/uploader_api.go @@ -126,7 +126,6 @@ func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R, req.Header.Set("Content-Type", contentType) req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) - req.Header.Set("Accept-Encoding", "gzip") req.ContentLength = int64(buf.Len()) up.logger.Debugln("UPLOADER REQ", r.method) diff --git a/tgapi/uploader_api_test.go b/tgapi/uploader_api_test.go new file mode 100644 index 0000000..7cf55dd --- /dev/null +++ b/tgapi/uploader_api_test.go @@ -0,0 +1,138 @@ +package tgapi + +import ( + "fmt" + "io" + "mime" + "mime/multipart" + "net/http" + "strings" + "testing" +) + +func TestUploaderEncodesJSONFieldsAndLeavesAcceptEncodingToHTTPTransport(t *testing.T) { + var ( + gotPath string + gotAcceptEncoding string + gotFields map[string]string + gotFileName string + gotFileData []byte + roundTripErr error + ) + + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + gotPath = req.URL.Path + gotAcceptEncoding = req.Header.Get("Accept-Encoding") + + gotFields, gotFileName, gotFileData, roundTripErr = readMultipartRequest(req) + if roundTripErr != nil { + roundTripErr = fmt.Errorf("readMultipartRequest: %w", roundTripErr) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"message_id":5,"date":1}}`)), + }, nil + }), + } + + api := NewAPI( + NewAPIOpts("token"). + SetAPIUrl("https://example.test"). + SetHTTPClient(client), + ) + defer func() { + if err := api.CloseApi(); err != nil { + t.Fatalf("CloseApi returned error: %v", err) + } + }() + + uploader := NewUploader(api) + defer func() { + if err := uploader.Close(); err != nil { + t.Fatalf("Close returned error: %v", err) + } + }() + + msg, err := uploader.SendPhoto( + UploadPhotoP{ + ChatID: 42, + CaptionEntities: []MessageEntity{{ + Type: MessageEntityBold, + Offset: 0, + Length: 4, + }}, + ReplyMarkup: &ReplyMarkup{ + InlineKeyboard: [][]InlineKeyboardButton{{ + {Text: "A", CallbackData: "b"}, + }}, + }, + }, + NewUploaderFile("photo.jpg", []byte("img")), + ) + if err != nil { + t.Fatalf("SendPhoto returned error: %v", err) + } + if msg.MessageID != 5 { + t.Fatalf("unexpected message id: %d", msg.MessageID) + } + if roundTripErr != nil { + t.Fatalf("multipart parse failed: %v", roundTripErr) + } + if gotPath != "/bottoken/sendPhoto" { + t.Fatalf("unexpected request path: %s", gotPath) + } + if gotAcceptEncoding != "" { + t.Fatalf("expected empty Accept-Encoding header, got %q", gotAcceptEncoding) + } + if got := gotFields["chat_id"]; got != "42" { + t.Fatalf("chat_id mismatch: %q", got) + } + if got := gotFields["caption_entities"]; got != `[{"type":"bold","offset":0,"length":4}]` { + t.Fatalf("caption_entities mismatch: %q", got) + } + if got := gotFields["reply_markup"]; got != `{"inline_keyboard":[[{"text":"A","callback_data":"b"}]]}` { + t.Fatalf("reply_markup mismatch: %q", got) + } + if gotFileName != "photo.jpg" { + t.Fatalf("unexpected file name: %q", gotFileName) + } + if string(gotFileData) != "img" { + t.Fatalf("unexpected file content: %q", string(gotFileData)) + } +} + +func readMultipartRequest(req *http.Request) (map[string]string, string, []byte, error) { + _, params, err := mime.ParseMediaType(req.Header.Get("Content-Type")) + if err != nil { + return nil, "", nil, err + } + reader := multipart.NewReader(req.Body, params["boundary"]) + + fields := make(map[string]string) + var fileName string + var fileData []byte + for { + part, err := reader.NextPart() + if err == io.EOF { + return fields, fileName, fileData, nil + } + if err != nil { + return nil, "", nil, err + } + + data, err := io.ReadAll(part) + if err != nil { + return nil, "", nil, err + } + + if part.FileName() != "" { + fileName = part.FileName() + fileData = data + continue + } + fields[part.FormName()] = string(data) + } +} diff --git a/utils/multipart.go b/utils/multipart.go index 8b2a72c..4ee0bfe 100644 --- a/utils/multipart.go +++ b/utils/multipart.go @@ -1,6 +1,7 @@ package utils import ( + "encoding/json" "fmt" "io" "mime/multipart" @@ -12,12 +13,8 @@ import ( // Encode writes struct fields into multipart form-data using json tags as field names. func Encode[T any](w *multipart.Writer, req T) error { - v := reflect.ValueOf(req) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - - if v.Kind() != reflect.Struct { + v := unwrapMultipartValue(reflect.ValueOf(req)) + if !v.IsValid() || v.Kind() != reflect.Struct { return fmt.Errorf("req must be a struct") } @@ -33,6 +30,9 @@ func Encode[T any](w *multipart.Writer, req T) error { parts := strings.Split(jsonTag, ",") fieldName := parts[0] + if fieldName == "" { + fieldName = fieldType.Name + } if fieldName == "-" { continue } @@ -43,96 +43,73 @@ func Encode[T any](w *multipart.Writer, req T) error { continue } - var ( - fw io.Writer - err error - ) - - switch field.Kind() { - case reflect.String: - fw, err = w.CreateFormField(fieldName) - if err == nil { - _, err = fw.Write([]byte(field.String())) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fw, err = w.CreateFormField(fieldName) - if err == nil { - _, err = fw.Write([]byte(strconv.FormatInt(field.Int(), 10))) - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - fw, err = w.CreateFormField(fieldName) - if err == nil { - _, err = fw.Write([]byte(strconv.FormatUint(field.Uint(), 10))) - } - case reflect.Float32: - fw, err = w.CreateFormField(fieldName) - if err == nil { - _, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 32))) - } - case reflect.Float64: - fw, err = w.CreateFormField(fieldName) - if err == nil { - _, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 64))) - } - - case reflect.Bool: - fw, err = w.CreateFormField(fieldName) - if err == nil { - _, err = fw.Write([]byte(strconv.FormatBool(field.Bool()))) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.Uint8 && !field.IsNil() { - // Handle []byte as file upload (e.g., thumbnail) - filename := fieldType.Tag.Get("filename") - if filename == "" { - filename = fieldName - } - fw, err = w.CreateFormFile(fieldName, filename) - if err == nil { - _, err = fw.Write(field.Bytes()) - } - } else if !field.IsNil() { - // Handle []string, []int, etc. — send as multiple fields with same name - for j := 0; j < field.Len(); j++ { - elem := field.Index(j) - fw, err = w.CreateFormField(fieldName) - if err != nil { - break - } - switch elem.Kind() { - case reflect.String: - _, err = fw.Write([]byte(elem.String())) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - _, err = fw.Write([]byte(strconv.FormatInt(elem.Int(), 10))) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - _, err = fw.Write([]byte(strconv.FormatUint(elem.Uint(), 10))) - case reflect.Bool: - _, err = fw.Write([]byte(strconv.FormatBool(elem.Bool()))) - case reflect.Float32: - _, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 32))) - case reflect.Float64: - _, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 64))) - default: - continue - } - if err != nil { - break - } - } - } - - case reflect.Struct: - // Don't serialize structs as JSON — flatten them! - // Telegram doesn't support nested JSON in form-data. - // If you need nested data, use separate fields (e.g., ParseMode, CaptionEntities) - // This is a design choice — you should avoid nested structs in params. - return fmt.Errorf("nested structs are not supported in params — use flat fields") - } - - if err != nil { + if err := writeMultipartField(w, fieldName, fieldType.Tag.Get("filename"), field); err != nil { return err } } return nil } + +func unwrapMultipartValue(v reflect.Value) reflect.Value { + for v.IsValid() && (v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface) { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + return v +} + +func writeMultipartField(w *multipart.Writer, fieldName, filename string, field reflect.Value) error { + value := unwrapMultipartValue(field) + if !value.IsValid() { + return nil + } + + switch value.Kind() { + case reflect.String: + return writeMultipartValue(w, fieldName, []byte(value.String())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return writeMultipartValue(w, fieldName, []byte(strconv.FormatInt(value.Int(), 10))) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return writeMultipartValue(w, fieldName, []byte(strconv.FormatUint(value.Uint(), 10))) + case reflect.Float32: + return writeMultipartValue(w, fieldName, []byte(strconv.FormatFloat(value.Float(), 'f', -1, 32))) + case reflect.Float64: + return writeMultipartValue(w, fieldName, []byte(strconv.FormatFloat(value.Float(), 'f', -1, 64))) + case reflect.Bool: + return writeMultipartValue(w, fieldName, []byte(strconv.FormatBool(value.Bool()))) + case reflect.Slice: + if value.Type().Elem().Kind() == reflect.Uint8 { + if filename == "" { + filename = fieldName + } + fw, err := w.CreateFormFile(fieldName, filename) + if err != nil { + return err + } + _, err = fw.Write(value.Bytes()) + return err + } + } + + // Telegram expects nested objects and arrays in multipart requests as JSON strings. + data, err := json.Marshal(value.Interface()) + if err != nil { + return err + } + if string(data) == "null" { + return nil + } + return writeMultipartValue(w, fieldName, data) +} + +func writeMultipartValue(w *multipart.Writer, fieldName string, value []byte) error { + fw, err := w.CreateFormField(fieldName) + if err != nil { + return err + } + _, err = io.Copy(fw, strings.NewReader(string(value))) + return err +} diff --git a/utils/multipart_test.go b/utils/multipart_test.go new file mode 100644 index 0000000..479ed05 --- /dev/null +++ b/utils/multipart_test.go @@ -0,0 +1,85 @@ +package utils_test + +import ( + "bytes" + "io" + "mime/multipart" + "testing" + + "git.nix13.pw/scuroneko/laniakea/tgapi" + "git.nix13.pw/scuroneko/laniakea/utils" +) + +type multipartEncodeParams struct { + ChatID int64 `json:"chat_id"` + MessageThreadID *int `json:"message_thread_id,omitempty"` + ReplyMarkup *tgapi.ReplyMarkup `json:"reply_markup,omitempty"` + CaptionEntities []tgapi.MessageEntity `json:"caption_entities,omitempty"` + ReplyParameters *tgapi.ReplyParameters `json:"reply_parameters,omitempty"` +} + +func TestEncodeMultipartJSONFields(t *testing.T) { + threadID := 7 + params := multipartEncodeParams{ + ChatID: 42, + MessageThreadID: &threadID, + ReplyMarkup: &tgapi.ReplyMarkup{ + InlineKeyboard: [][]tgapi.InlineKeyboardButton{{ + {Text: "A", CallbackData: "b"}, + }}, + }, + CaptionEntities: []tgapi.MessageEntity{{ + Type: tgapi.MessageEntityBold, + Offset: 0, + Length: 4, + }}, + } + + body := bytes.NewBuffer(nil) + writer := multipart.NewWriter(body) + if err := utils.Encode(writer, params); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("writer.Close returned error: %v", err) + } + + got := readMultipartFields(t, body.Bytes(), writer.Boundary()) + if got["chat_id"] != "42" { + t.Fatalf("chat_id mismatch: %q", got["chat_id"]) + } + if got["message_thread_id"] != "7" { + t.Fatalf("message_thread_id mismatch: %q", got["message_thread_id"]) + } + if got["reply_markup"] != `{"inline_keyboard":[[{"text":"A","callback_data":"b"}]]}` { + t.Fatalf("reply_markup mismatch: %q", got["reply_markup"]) + } + if got["caption_entities"] != `[{"type":"bold","offset":0,"length":4}]` { + t.Fatalf("caption_entities mismatch: %q", got["caption_entities"]) + } + if _, ok := got["reply_parameters"]; ok { + t.Fatalf("reply_parameters should be omitted when nil") + } +} + +func readMultipartFields(t *testing.T, body []byte, boundary string) map[string]string { + t.Helper() + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + fields := make(map[string]string) + for { + part, err := reader.NextPart() + if err == io.EOF { + return fields + } + if err != nil { + t.Fatalf("NextPart returned error: %v", err) + } + + data, err := io.ReadAll(part) + if err != nil { + t.Fatalf("ReadAll returned error: %v", err) + } + fields[part.FormName()] = string(data) + } +}