Compare commits

...

1 Commits

26 changed files with 992 additions and 140 deletions

17
bot.go
View File

@@ -184,21 +184,34 @@ func NewBot[T any](opts *BotOpts) *Bot[T] {
// //
// Returns the first error encountered, if any. // Returns the first error encountered, if any.
func (bot *Bot[T]) Close() error { func (bot *Bot[T]) Close() error {
var firstErr error
if err := bot.uploader.Close(); err != nil { if err := bot.uploader.Close(); err != nil {
bot.logger.Errorln(err) bot.logger.Errorln(err)
if firstErr == nil {
firstErr = err
}
} }
if err := bot.api.CloseApi(); err != nil { if err := bot.api.CloseApi(); err != nil {
bot.logger.Errorln(err) bot.logger.Errorln(err)
if firstErr == nil {
firstErr = err
}
} }
if bot.RequestLogger != nil { if bot.RequestLogger != nil {
if err := bot.RequestLogger.Close(); err != nil { if err := bot.RequestLogger.Close(); err != nil {
bot.logger.Errorln(err) bot.logger.Errorln(err)
if firstErr == nil {
firstErr = err
}
} }
} }
if err := bot.logger.Close(); err != nil { if err := bot.logger.Close(); err != nil {
return err if firstErr == nil {
firstErr = err
}
} }
return nil return firstErr
} }
// initLoggers configures the main and optional request loggers. // initLoggers configures the main and optional request loggers.

View File

@@ -85,10 +85,10 @@ func LoadOptsFromEnv() *BotOpts {
} }
} }
stringUpdateTypes := strings.Split(os.Getenv("UPDATE_TYPES"), ";") stringUpdateTypes := splitEnvList(os.Getenv("UPDATE_TYPES"))
updateTypes := make([]tgapi.UpdateType, len(stringUpdateTypes)) updateTypes := make([]tgapi.UpdateType, 0, len(stringUpdateTypes))
for i, updateType := range stringUpdateTypes { for _, updateType := range stringUpdateTypes {
updateTypes[i] = tgapi.UpdateType(updateType) updateTypes = append(updateTypes, tgapi.UpdateType(updateType))
} }
return &BotOpts{ return &BotOpts{
@@ -222,5 +222,25 @@ func LoadPrefixesFromEnv() []string {
if !exists { if !exists {
return []string{"/"} return []string{"/"}
} }
return strings.Split(prefixesS, ";") prefixes := splitEnvList(prefixesS)
if len(prefixes) == 0 {
return []string{"/"}
}
return prefixes
}
func splitEnvList(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ";")
out := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
out = append(out, part)
}
return out
} }

47
bot_opts_test.go Normal file
View File

@@ -0,0 +1,47 @@
package laniakea
import (
"reflect"
"testing"
"git.nix13.pw/scuroneko/laniakea/tgapi"
)
func TestLoadOptsFromEnvIgnoresEmptyUpdateTypes(t *testing.T) {
t.Setenv("UPDATE_TYPES", "")
opts := LoadOptsFromEnv()
if len(opts.UpdateTypes) != 0 {
t.Fatalf("expected no update types, got %v", opts.UpdateTypes)
}
}
func TestLoadOptsFromEnvSplitsAndTrimsUpdateTypes(t *testing.T) {
t.Setenv("UPDATE_TYPES", "message; ; callback_query ")
opts := LoadOptsFromEnv()
want := []tgapi.UpdateType{tgapi.UpdateTypeMessage, tgapi.UpdateTypeCallbackQuery}
if !reflect.DeepEqual(opts.UpdateTypes, want) {
t.Fatalf("unexpected update types: got %v want %v", opts.UpdateTypes, want)
}
}
func TestLoadPrefixesFromEnvDefaultsOnEmptyValue(t *testing.T) {
t.Setenv("PREFIXES", "")
got := LoadPrefixesFromEnv()
want := []string{"/"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected prefixes: got %v want %v", got, want)
}
}
func TestLoadPrefixesFromEnvDropsEmptyValues(t *testing.T) {
t.Setenv("PREFIXES", "/; ; ! ")
got := LoadPrefixesFromEnv()
want := []string{"/", "!"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected prefixes: got %v want %v", got, want)
}
}

View File

@@ -100,17 +100,17 @@ func gatherCommands[T any](bot *Bot[T]) []tgapi.BotCommand {
// log.Fatal(err) // log.Fatal(err)
// } // }
func (bot *Bot[T]) AutoGenerateCommands() error { func (bot *Bot[T]) AutoGenerateCommands() error {
commands := gatherCommands(bot)
if len(commands) > 100 {
return ErrTooManyCommands
}
// Clear existing commands to avoid duplication or stale entries // Clear existing commands to avoid duplication or stale entries
_, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{}) _, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{})
if err != nil { if err != nil {
return fmt.Errorf("failed to delete existing commands: %w", err) return fmt.Errorf("failed to delete existing commands: %w", err)
} }
commands := gatherCommands(bot)
if len(commands) > 100 {
return ErrTooManyCommands
}
// Register commands for each scope // Register commands for each scope
scopes := []*tgapi.BotCommandScope{ scopes := []*tgapi.BotCommandScope{
{Type: tgapi.BotCommandScopePrivateType}, {Type: tgapi.BotCommandScopePrivateType},
@@ -148,15 +148,16 @@ func (bot *Bot[T]) AutoGenerateCommands() error {
// log.Fatal(err) // log.Fatal(err)
// } // }
func (bot *Bot[T]) AutoGenerateCommandsForScope(scope *tgapi.BotCommandScope) error { func (bot *Bot[T]) AutoGenerateCommandsForScope(scope *tgapi.BotCommandScope) error {
_, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{Scope: scope})
if err != nil {
return fmt.Errorf("failed to delete existing commands: %w", err)
}
commands := gatherCommands(bot) commands := gatherCommands(bot)
if len(commands) > 100 { if len(commands) > 100 {
return ErrTooManyCommands return ErrTooManyCommands
} }
_, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{Scope: scope})
if err != nil {
return fmt.Errorf("failed to delete existing commands: %w", err)
}
_, err = bot.api.SetMyCommands(tgapi.SetMyCommandsP{ _, err = bot.api.SetMyCommands(tgapi.SetMyCommandsP{
Commands: commands, Commands: commands,
Scope: scope, Scope: scope,

64
cmd_generator_test.go Normal file
View File

@@ -0,0 +1,64 @@
package laniakea
import (
"errors"
"io"
"net/http"
"strconv"
"strings"
"sync/atomic"
"testing"
"git.nix13.pw/scuroneko/laniakea/tgapi"
"git.nix13.pw/scuroneko/slog"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestAutoGenerateCommandsChecksLimitBeforeDelete(t *testing.T) {
var calls atomic.Int64
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":true}`)),
}, nil
}),
}
api := tgapi.NewAPI(
tgapi.NewAPIOpts("token").
SetAPIUrl("https://example.test").
SetHTTPClient(client),
)
defer func() {
if err := api.CloseApi(); err != nil {
t.Fatalf("CloseApi returned error: %v", err)
}
}()
plugin := NewPlugin[NoDB]("overflow")
exec := func(ctx *MsgContext, db *NoDB) {}
for i := 0; i < 101; i++ {
plugin.AddCommand(NewCommand(exec, "cmd"+strconv.Itoa(i)))
}
bot := &Bot[NoDB]{
api: api,
logger: slog.CreateLogger(),
plugins: []Plugin[NoDB]{*plugin},
}
err := bot.AutoGenerateCommands()
if !errors.Is(err, ErrTooManyCommands) {
t.Fatalf("expected ErrTooManyCommands, got %v", err)
}
if calls.Load() != 0 {
t.Fatalf("expected no HTTP calls before limit validation, got %d", calls.Load())
}
}

4
go.mod
View File

@@ -3,8 +3,8 @@ module git.nix13.pw/scuroneko/laniakea
go 1.26 go 1.26
require ( require (
git.nix13.pw/scuroneko/extypes v1.2.1 git.nix13.pw/scuroneko/extypes v1.2.2
git.nix13.pw/scuroneko/slog v1.0.2 git.nix13.pw/scuroneko/slog v1.1.2
github.com/alitto/pond/v2 v2.7.0 github.com/alitto/pond/v2 v2.7.0
golang.org/x/time v0.15.0 golang.org/x/time v0.15.0
) )

8
go.sum
View File

@@ -1,7 +1,7 @@
git.nix13.pw/scuroneko/extypes v1.2.1 h1:IYrOjnWKL2EAuJYtYNa+luB1vBe6paE8VY/YD+5/RpQ= git.nix13.pw/scuroneko/extypes v1.2.2 h1:N54c1ejrPs1yfIkvYuwqI7B1+8S9mDv2GqQA6sct4dk=
git.nix13.pw/scuroneko/extypes v1.2.1/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw= git.nix13.pw/scuroneko/extypes v1.2.2/go.mod h1:b4XYk1OW1dVSiE2MT/OMuX/K/UItf1swytX6eroVYnk=
git.nix13.pw/scuroneko/slog v1.0.2 h1:vZyUROygxC2d5FJHUQM/30xFEHY1JT/aweDZXA4rm2g= git.nix13.pw/scuroneko/slog v1.1.2 h1:pl7tV5FN25Yso7sLYoOgBXi9+jLo5BDJHWmHlNPjpY0=
git.nix13.pw/scuroneko/slog v1.0.2/go.mod h1:3Qm2wzkR5KjwOponMfG7TcGSDjmYaFqRAmLvSPTuWJI= git.nix13.pw/scuroneko/slog v1.1.2/go.mod h1:UcfRIHDqpVQHahBGM93awLDK8//AsAvOqBwwbWqMkjM=
github.com/alitto/pond/v2 v2.7.0 h1:c76L+yN916m/DRXjGCeUBHHu92uWnh/g1bwVk4zyyXg= github.com/alitto/pond/v2 v2.7.0 h1:c76L+yN916m/DRXjGCeUBHHu92uWnh/g1bwVk4zyyXg=
github.com/alitto/pond/v2 v2.7.0/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/alitto/pond/v2 v2.7.0/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=

View File

@@ -138,6 +138,9 @@ func (bot *Bot[T]) handleCallback(update *tgapi.Update, ctx *MsgContext) {
func (bot *Bot[T]) checkPrefixes(text string) (string, bool) { func (bot *Bot[T]) checkPrefixes(text string) (string, bool) {
for _, prefix := range bot.prefixes { for _, prefix := range bot.prefixes {
if prefix == "" {
continue
}
if strings.HasPrefix(text, prefix) { if strings.HasPrefix(text, prefix) {
return prefix, true return prefix, true
} }

14
handler_test.go Normal file
View File

@@ -0,0 +1,14 @@
package laniakea
import "testing"
func TestCheckPrefixesSkipsEmptyPrefixes(t *testing.T) {
bot := &Bot[NoDB]{prefixes: []string{"", "/"}}
if prefix, ok := bot.checkPrefixes("hello"); ok {
t.Fatalf("unexpected prefix match for plain text: %q", prefix)
}
if prefix, ok := bot.checkPrefixes("/start"); !ok || prefix != "/" {
t.Fatalf("unexpected prefix result: prefix=%q ok=%v", prefix, ok)
}
}

View File

@@ -270,6 +270,9 @@ func (ctx *MsgContext) answerPhoto(photoId, text string, kb *InlineKeyboard, par
if ctx.Msg.MessageThreadID > 0 { if ctx.Msg.MessageThreadID > 0 {
params.MessageThreadID = ctx.Msg.MessageThreadID params.MessageThreadID = ctx.Msg.MessageThreadID
} }
if ctx.Msg.DirectMessageTopic != nil {
params.DirectMessagesTopicID = int(ctx.Msg.DirectMessageTopic.TopicID)
}
msg, err := ctx.Api.SendPhoto(params) msg, err := ctx.Api.SendPhoto(params)
if err != nil { if err != nil {

64
msg_context_test.go Normal file
View File

@@ -0,0 +1,64 @@
package laniakea
import (
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"git.nix13.pw/scuroneko/laniakea/tgapi"
"git.nix13.pw/scuroneko/slog"
)
func TestAnswerPhotoIncludesDirectMessagesTopicID(t *testing.T) {
var gotBody map[string]any
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("failed to read request body: %v", err)
}
if err := json.Unmarshal(body, &gotBody); err != nil {
t.Fatalf("failed to decode request body: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"message_id":9,"date":1}}`)),
}, nil
}),
}
api := tgapi.NewAPI(
tgapi.NewAPIOpts("token").
SetAPIUrl("https://example.test").
SetHTTPClient(client),
)
defer func() {
if err := api.CloseApi(); err != nil {
t.Fatalf("CloseApi returned error: %v", err)
}
}()
ctx := &MsgContext{
Api: api,
Msg: &tgapi.Message{
Chat: &tgapi.Chat{ID: 42, Type: string(tgapi.ChatTypePrivate)},
DirectMessageTopic: &tgapi.DirectMessageTopic{TopicID: 77},
},
botLogger: slog.CreateLogger(),
}
answer := ctx.AnswerPhoto("photo-id", "caption")
if answer == nil {
t.Fatal("expected answer message")
}
if answer.MessageID != 9 {
t.Fatalf("unexpected message id: %d", answer.MessageID)
}
if got := gotBody["direct_messages_topic_id"]; got != float64(77) {
t.Fatalf("unexpected direct_messages_topic_id: %v", got)
}
}

View File

@@ -23,11 +23,11 @@ const (
var ( var (
// CommandRegexInt matches one or more digits. // CommandRegexInt matches one or more digits.
CommandRegexInt = regexp.MustCompile(`\d+`) CommandRegexInt = regexp.MustCompile(`^\d+$`)
// CommandRegexString matches any non-empty string. // CommandRegexString matches any non-empty string.
CommandRegexString = regexp.MustCompile(`.+`) CommandRegexString = regexp.MustCompile(`^.+$`)
// CommandRegexBool matches true or false. // CommandRegexBool matches true or false.
CommandRegexBool = regexp.MustCompile(`true|false`) CommandRegexBool = regexp.MustCompile(`^(true|false)$`)
) )
// ErrCmdArgCountMismatch is returned when the number of provided arguments // ErrCmdArgCountMismatch is returned when the number of provided arguments
@@ -64,6 +64,7 @@ func (c *CommandArg) SetValueType(t CommandValueType) *CommandArg {
case CommandValueAnyType: case CommandValueAnyType:
regex = nil // Skip validation regex = nil // Skip validation
} }
c.valueType = t
c.regex = regex c.regex = regex
return c return c
} }

24
plugins_test.go Normal file
View File

@@ -0,0 +1,24 @@
package laniakea
import (
"errors"
"testing"
)
func TestValidateArgsRequiresFullMatch(t *testing.T) {
intCmd := NewCommand[NoDB](func(ctx *MsgContext, db *NoDB) {}, "int", *NewCommandArg("n").SetValueType(CommandValueIntType).SetRequired())
if err := intCmd.validateArgs([]string{"123"}); err != nil {
t.Fatalf("expected valid integer argument, got %v", err)
}
if err := intCmd.validateArgs([]string{"123abc"}); !errors.Is(err, ErrCmdArgRegexpMismatch) {
t.Fatalf("expected ErrCmdArgRegexpMismatch for partial int match, got %v", err)
}
boolCmd := NewCommand[NoDB](func(ctx *MsgContext, db *NoDB) {}, "bool", *NewCommandArg("flag").SetValueType(CommandValueBoolType).SetRequired())
if err := boolCmd.validateArgs([]string{"false"}); err != nil {
t.Fatalf("expected valid bool argument, got %v", err)
}
if err := boolCmd.validateArgs([]string{"falsey"}); !errors.Is(err, ErrCmdArgRegexpMismatch) {
t.Fatalf("expected ErrCmdArgRegexpMismatch for partial bool match, got %v", err)
}
}

View File

@@ -202,7 +202,6 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
req.Header.Set("Accept-Encoding", "gzip")
for { for {
// Apply rate limiting before making the request // Apply rate limiting before making the request

56
tgapi/api_test.go Normal file
View File

@@ -0,0 +1,56 @@
package tgapi
import (
"io"
"net/http"
"strings"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestAPILeavesAcceptEncodingToHTTPTransport(t *testing.T) {
var gotPath string
var gotAcceptEncoding string
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
gotPath = req.URL.Path
gotAcceptEncoding = req.Header.Get("Accept-Encoding")
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"Test"}}`)),
}, nil
}),
}
api := NewAPI(
NewAPIOpts("token").
SetAPIUrl("https://example.test").
SetHTTPClient(client),
)
defer func() {
if err := api.CloseApi(); err != nil {
t.Fatalf("CloseApi returned error: %v", err)
}
}()
user, err := api.GetMe()
if err != nil {
t.Fatalf("GetMe returned error: %v", err)
}
if user.FirstName != "Test" {
t.Fatalf("unexpected first name: %q", user.FirstName)
}
if gotPath != "/bottoken/getMe" {
t.Fatalf("unexpected request path: %s", gotPath)
}
if gotAcceptEncoding != "" {
t.Fatalf("expected empty Accept-Encoding header, got %q", gotAcceptEncoding)
}
}

View File

@@ -60,6 +60,14 @@ type BusinessConnection struct {
IsEnabled bool `json:"is_enabled"` IsEnabled bool `json:"is_enabled"`
} }
// BusinessMessagesDeleted is received when messages are deleted from a connected business account.
// See https://core.telegram.org/bots/api#businessmessagesdeleted
type BusinessMessagesDeleted struct {
BusinessConnectionID string `json:"business_connection_id"`
Chat Chat `json:"chat"`
MessageIDs []int `json:"message_ids"`
}
// InputStoryContentType indicates the type of input story content. // InputStoryContentType indicates the type of input story content.
type InputStoryContentType string type InputStoryContentType string

View File

@@ -144,12 +144,12 @@ type LinkPreviewOptions struct {
type ReplyMarkup struct { type ReplyMarkup struct {
InlineKeyboard [][]InlineKeyboardButton `json:"inline_keyboard,omitempty"` InlineKeyboard [][]InlineKeyboardButton `json:"inline_keyboard,omitempty"`
Keyboard [][]int `json:"keyboard,omitempty"` Keyboard [][]KeyboardButton `json:"keyboard,omitempty"`
IsPersistent bool `json:"is_persistent,omitempty"` IsPersistent bool `json:"is_persistent,omitempty"`
ResizeKeyboard bool `json:"resize_keyboard,omitempty"` ResizeKeyboard bool `json:"resize_keyboard,omitempty"`
OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"` OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"`
InputFieldPlaceholder string `json:"input_field_placeholder,omitempty"` InputFieldPlaceholder string `json:"input_field_placeholder,omitempty"`
Selective bool `json:"selective,omitempty"` Selective bool `json:"selective,omitempty"`
RemoveKeyboard bool `json:"remove_keyboard,omitempty"` RemoveKeyboard bool `json:"remove_keyboard,omitempty"`
@@ -165,6 +165,60 @@ type InlineKeyboardMarkup struct {
// KeyboardButtonStyle represents the style of a keyboard button. // KeyboardButtonStyle represents the style of a keyboard button.
type KeyboardButtonStyle string type KeyboardButtonStyle string
const (
KeyboardButtonStyleDanger KeyboardButtonStyle = "danger"
KeyboardButtonStyleSuccess KeyboardButtonStyle = "success"
KeyboardButtonStylePrimary KeyboardButtonStyle = "primary"
)
// KeyboardButton represents one button of the reply keyboard.
// See https://core.telegram.org/bots/api#keyboardbutton
type KeyboardButton struct {
Text string `json:"text"`
IconCustomEmojiID string `json:"icon_custom_emoji_id,omitempty"`
Style KeyboardButtonStyle `json:"style,omitempty"`
RequestUsers *KeyboardButtonRequestUsers `json:"request_users,omitempty"`
RequestChat *KeyboardButtonRequestChat `json:"request_chat,omitempty"`
RequestContact bool `json:"request_contact,omitempty"`
RequestLocation bool `json:"request_location,omitempty"`
RequestPoll *KeyboardButtonPollType `json:"request_poll,omitempty"`
WebApp *WebAppInfo `json:"web_app,omitempty"`
}
// KeyboardButtonRequestUsers defines criteria used to request suitable users.
// See https://core.telegram.org/bots/api#keyboardbuttonrequestusers
type KeyboardButtonRequestUsers struct {
RequestID int `json:"request_id"`
UserIsBot *bool `json:"user_is_bot,omitempty"`
UserIsPremium *bool `json:"user_is_premium,omitempty"`
MaxQuantity int `json:"max_quantity,omitempty"`
RequestName bool `json:"request_name,omitempty"`
RequestUsername bool `json:"request_username,omitempty"`
RequestPhoto bool `json:"request_photo,omitempty"`
}
// KeyboardButtonRequestChat defines criteria used to request a suitable chat.
// See https://core.telegram.org/bots/api#keyboardbuttonrequestchat
type KeyboardButtonRequestChat struct {
RequestID int `json:"request_id"`
ChatIsChannel bool `json:"chat_is_channel"`
ChatIsForum *bool `json:"chat_is_forum,omitempty"`
ChatHasUsername *bool `json:"chat_has_username,omitempty"`
ChatIsCreated *bool `json:"chat_is_created,omitempty"`
UserAdministratorRights *ChatAdministratorRights `json:"user_administrator_rights,omitempty"`
BotAdministratorRights *ChatAdministratorRights `json:"bot_administrator_rights,omitempty"`
BotIsMember bool `json:"bot_is_member,omitempty"`
RequestTitle bool `json:"request_title,omitempty"`
RequestUsername bool `json:"request_username,omitempty"`
RequestPhoto bool `json:"request_photo,omitempty"`
}
// KeyboardButtonPollType represents the type of a poll that may be created from a keyboard button.
// See https://core.telegram.org/bots/api#keyboardbuttonpolltype
type KeyboardButtonPollType struct {
Type PollType `json:"type,omitempty"`
}
// InlineKeyboardButton represents one button of an inline keyboard. // InlineKeyboardButton represents one button of an inline keyboard.
// See https://core.telegram.org/bots/api#inlinekeyboardbutton // See https://core.telegram.org/bots/api#inlinekeyboardbutton
type InlineKeyboardButton struct { type InlineKeyboardButton struct {
@@ -178,7 +232,12 @@ type InlineKeyboardButton struct {
// ReplyKeyboardMarkup represents a custom keyboard with reply options. // ReplyKeyboardMarkup represents a custom keyboard with reply options.
// See https://core.telegram.org/bots/api#replykeyboardmarkup // See https://core.telegram.org/bots/api#replykeyboardmarkup
type ReplyKeyboardMarkup struct { type ReplyKeyboardMarkup struct {
Keyboard [][]int `json:"keyboard"` Keyboard [][]KeyboardButton `json:"keyboard"`
IsPersistent bool `json:"is_persistent,omitempty"`
ResizeKeyboard bool `json:"resize_keyboard,omitempty"`
OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"`
InputFieldPlaceholder string `json:"input_field_placeholder,omitempty"`
Selective bool `json:"selective,omitempty"`
} }
// CallbackQuery represents an incoming callback query from a callback button in an inline keyboard. // CallbackQuery represents an incoming callback query from a callback button in an inline keyboard.
@@ -238,7 +297,8 @@ const (
ChatActionUploadDocument ChatActionType = "upload_document" ChatActionUploadDocument ChatActionType = "upload_document"
ChatActionChooseSticker ChatActionType = "choose_sticker" ChatActionChooseSticker ChatActionType = "choose_sticker"
ChatActionFindLocation ChatActionType = "find_location" ChatActionFindLocation ChatActionType = "find_location"
ChatActionUploadVideoNone ChatActionType = "upload_video_none" ChatActionUploadVideoNote ChatActionType = "upload_video_note"
ChatActionUploadVideoNone ChatActionType = ChatActionUploadVideoNote
) )
// MessageReactionUpdated represents a change of a reaction on a message. // MessageReactionUpdated represents a change of a reaction on a message.

View File

@@ -0,0 +1,37 @@
package tgapi
import (
"encoding/json"
"strings"
"testing"
)
func TestReplyKeyboardMarkupMarshalsKeyboardButtons(t *testing.T) {
markup := ReplyKeyboardMarkup{
Keyboard: [][]KeyboardButton{{
{
Text: "Create poll",
RequestPoll: &KeyboardButtonPollType{Type: PollTypeQuiz},
},
}},
}
data, err := json.Marshal(markup)
if err != nil {
t.Fatalf("Marshal returned error: %v", err)
}
got := string(data)
if !strings.Contains(got, `"keyboard":[[{"text":"Create poll","request_poll":{"type":"quiz"}}]]`) {
t.Fatalf("unexpected reply keyboard JSON: %s", got)
}
}
func TestChatActionUploadVideoNoteValue(t *testing.T) {
if ChatActionUploadVideoNote != "upload_video_note" {
t.Fatalf("unexpected chat action value: %q", ChatActionUploadVideoNote)
}
if ChatActionUploadVideoNone != ChatActionUploadVideoNote {
t.Fatalf("expected deprecated alias to match upload_video_note, got %q", ChatActionUploadVideoNone)
}
}

View File

@@ -1,9 +1,12 @@
package tgapi package tgapi
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"git.nix13.pw/scuroneko/laniakea/utils"
) )
// UpdateParams holds parameters for the getUpdates method. // UpdateParams holds parameters for the getUpdates method.
@@ -12,7 +15,7 @@ type UpdateParams struct {
Offset *int `json:"offset,omitempty"` Offset *int `json:"offset,omitempty"`
Limit *int `json:"limit,omitempty"` Limit *int `json:"limit,omitempty"`
Timeout *int `json:"timeout,omitempty"` Timeout *int `json:"timeout,omitempty"`
AllowedUpdates []UpdateType `json:"allowed_updates"` AllowedUpdates []UpdateType `json:"allowed_updates,omitempty"`
} }
// GetMe returns basic information about the bot. // GetMe returns basic information about the bot.
@@ -103,13 +106,31 @@ func (api *API) GetFile(params GetFileP) (File, error) {
// The link is usually obtained from File.FilePath. // The link is usually obtained from File.FilePath.
// See https://core.telegram.org/bots/api#file // See https://core.telegram.org/bots/api#file
func (api *API) GetFileByLink(link string) ([]byte, error) { func (api *API) GetFileByLink(link string) ([]byte, error) {
u := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", api.token, link) methodPrefix := ""
res, err := http.Get(u) if api.useTestServer {
methodPrefix = "/test"
}
u := fmt.Sprintf("%s/file/bot%s%s/%s", api.apiUrl, api.token, methodPrefix, link)
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, u, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
res, err := api.client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
_ = res.Body.Close() _ = res.Body.Close()
}() }()
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
body, readErr := io.ReadAll(io.LimitReader(res.Body, 4<<10))
if readErr != nil {
return nil, fmt.Errorf("unexpected status %d", res.StatusCode)
}
return nil, fmt.Errorf("unexpected status %d: %s", res.StatusCode, string(body))
}
return io.ReadAll(res.Body) return io.ReadAll(res.Body)
} }

115
tgapi/methods_test.go Normal file
View File

@@ -0,0 +1,115 @@
package tgapi
import (
"encoding/json"
"io"
"net/http"
"strings"
"testing"
)
func TestGetFileByLinkUsesConfiguredAPIURL(t *testing.T) {
var gotPath string
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
gotPath = req.URL.Path
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("payload")),
}, nil
}),
}
api := NewAPI(
NewAPIOpts("token").
SetAPIUrl("https://example.test").
SetHTTPClient(client),
)
defer func() {
if err := api.CloseApi(); err != nil {
t.Fatalf("CloseApi returned error: %v", err)
}
}()
data, err := api.GetFileByLink("files/report.txt")
if err != nil {
t.Fatalf("GetFileByLink returned error: %v", err)
}
if string(data) != "payload" {
t.Fatalf("unexpected payload: %q", string(data))
}
if gotPath != "/file/bottoken/files/report.txt" {
t.Fatalf("unexpected request path: %s", gotPath)
}
}
func TestGetFileByLinkReturnsHTTPStatusError(t *testing.T) {
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(strings.NewReader("missing\n")),
}, nil
}),
}
api := NewAPI(
NewAPIOpts("token").
SetAPIUrl("https://example.test").
SetHTTPClient(client),
)
defer func() {
if err := api.CloseApi(); err != nil {
t.Fatalf("CloseApi returned error: %v", err)
}
}()
_, err := api.GetFileByLink("files/report.txt")
if err == nil {
t.Fatal("expected error for non-2xx response")
}
}
func TestGetUpdatesOmitsAllowedUpdatesWhenEmpty(t *testing.T) {
var gotBody map[string]any
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("failed to read request body: %v", err)
}
if err := json.Unmarshal(body, &gotBody); err != nil {
t.Fatalf("failed to decode request body: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":[]}`)),
}, nil
}),
}
api := NewAPI(
NewAPIOpts("token").
SetAPIUrl("https://example.test").
SetHTTPClient(client),
)
defer func() {
if err := api.CloseApi(); err != nil {
t.Fatalf("CloseApi returned error: %v", err)
}
}()
updates, err := api.GetUpdates(UpdateParams{})
if err != nil {
t.Fatalf("GetUpdates returned error: %v", err)
}
if len(updates) != 0 {
t.Fatalf("expected no updates, got %d", len(updates))
}
if _, exists := gotBody["allowed_updates"]; exists {
t.Fatalf("expected allowed_updates to be omitted, got %v", gotBody["allowed_updates"])
}
}

View File

@@ -1,5 +1,7 @@
package tgapi package tgapi
import "encoding/json"
// UpdateType represents the type of incoming update. // UpdateType represents the type of incoming update.
type UpdateType string type UpdateType string
@@ -23,8 +25,10 @@ const (
UpdateTypeBusinessMessage UpdateType = "business_message" UpdateTypeBusinessMessage UpdateType = "business_message"
// UpdateTypeEditedBusinessMessage is an edited business message update. // UpdateTypeEditedBusinessMessage is an edited business message update.
UpdateTypeEditedBusinessMessage UpdateType = "edited_business_message" UpdateTypeEditedBusinessMessage UpdateType = "edited_business_message"
// UpdateTypeDeletedBusinessMessage is a deleted business message update. // UpdateTypeDeletedBusinessMessages is a deleted business messages update.
UpdateTypeDeletedBusinessMessage UpdateType = "deleted_business_message" UpdateTypeDeletedBusinessMessages UpdateType = "deleted_business_messages"
// UpdateTypeDeletedBusinessMessage is kept as a backward-compatible alias.
UpdateTypeDeletedBusinessMessage UpdateType = UpdateTypeDeletedBusinessMessages
// UpdateTypeInlineQuery is an inline query update. // UpdateTypeInlineQuery is an inline query update.
UpdateTypeInlineQuery UpdateType = "inline_query" UpdateTypeInlineQuery UpdateType = "inline_query"
@@ -63,17 +67,18 @@ type Update struct {
ChannelPost *Message `json:"channel_post,omitempty"` ChannelPost *Message `json:"channel_post,omitempty"`
EditedChannelPost *Message `json:"edited_channel_post,omitempty"` EditedChannelPost *Message `json:"edited_channel_post,omitempty"`
BusinessConnection *BusinessConnection `json:"business_connection,omitempty"` BusinessConnection *BusinessConnection `json:"business_connection,omitempty"`
BusinessMessage *Message `json:"business_message,omitempty"` BusinessMessage *Message `json:"business_message,omitempty"`
EditedBusinessMessage *Message `json:"edited_business_message,omitempty"` EditedBusinessMessage *Message `json:"edited_business_message,omitempty"`
DeletedBusinessMessage *Message `json:"deleted_business_messages,omitempty"` DeletedBusinessMessages *BusinessMessagesDeleted `json:"deleted_business_messages,omitempty"`
MessageReaction *MessageReactionUpdated `json:"message_reaction,omitempty"` DeletedBusinessMessage *BusinessMessagesDeleted `json:"-"`
MessageReactionCount *MessageReactionCountUpdated `json:"message_reaction_count,omitempty"` MessageReaction *MessageReactionUpdated `json:"message_reaction,omitempty"`
MessageReactionCount *MessageReactionCountUpdated `json:"message_reaction_count,omitempty"`
InlineQuery *InlineQuery `json:"inline_query,omitempty"` InlineQuery *InlineQuery `json:"inline_query,omitempty"`
ChosenInlineResult *ChosenInlineResult `json:"chosen_inline_result,omitempty"` ChosenInlineResult *ChosenInlineResult `json:"chosen_inline_result,omitempty"`
CallbackQuery *CallbackQuery `json:"callback_query,omitempty"` CallbackQuery *CallbackQuery `json:"callback_query,omitempty"`
ShippingQuery ShippingQuery `json:"shipping_query,omitempty"` ShippingQuery *ShippingQuery `json:"shipping_query,omitempty"`
PreCheckoutQuery *PreCheckoutQuery `json:"pre_checkout_query,omitempty"` PreCheckoutQuery *PreCheckoutQuery `json:"pre_checkout_query,omitempty"`
PurchasedPaidMedia *PaidMediaPurchased `json:"purchased_paid_media,omitempty"` PurchasedPaidMedia *PaidMediaPurchased `json:"purchased_paid_media,omitempty"`
@@ -86,6 +91,35 @@ type Update struct {
RemovedChatBoost *ChatBoostRemoved `json:"removed_chat_boost,omitempty"` RemovedChatBoost *ChatBoostRemoved `json:"removed_chat_boost,omitempty"`
} }
func (u *Update) syncDeletedBusinessMessages() {
if u.DeletedBusinessMessages != nil {
u.DeletedBusinessMessage = u.DeletedBusinessMessages
return
}
if u.DeletedBusinessMessage != nil {
u.DeletedBusinessMessages = u.DeletedBusinessMessage
}
}
// UnmarshalJSON keeps the deprecated DeletedBusinessMessage alias in sync.
func (u *Update) UnmarshalJSON(data []byte) error {
type alias Update
var aux alias
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
*u = Update(aux)
u.syncDeletedBusinessMessages()
return nil
}
// MarshalJSON emits the canonical deleted_business_messages field.
func (u Update) MarshalJSON() ([]byte, error) {
u.syncDeletedBusinessMessages()
type alias Update
return json.Marshal(alias(u))
}
// InlineQuery represents an incoming inline query. // InlineQuery represents an incoming inline query.
// See https://core.telegram.org/bots/api#inlinequery // See https://core.telegram.org/bots/api#inlinequery
type InlineQuery struct { type InlineQuery struct {

69
tgapi/types_test.go Normal file
View File

@@ -0,0 +1,69 @@
package tgapi
import (
"encoding/json"
"strings"
"testing"
)
func TestUpdateDeletedBusinessMessagesUnmarshalSetsAlias(t *testing.T) {
var update Update
err := json.Unmarshal([]byte(`{
"update_id": 1,
"deleted_business_messages": {
"business_connection_id": "conn",
"chat": {"id": 42, "type": "private"},
"message_ids": [3, 5]
}
}`), &update)
if err != nil {
t.Fatalf("Unmarshal returned error: %v", err)
}
if update.DeletedBusinessMessages == nil {
t.Fatal("expected DeletedBusinessMessages to be populated")
}
if update.DeletedBusinessMessage == nil {
t.Fatal("expected deprecated DeletedBusinessMessage alias to be populated")
}
if update.DeletedBusinessMessages != update.DeletedBusinessMessage {
t.Fatal("expected deleted business message fields to share the same payload")
}
if got := update.DeletedBusinessMessages.MessageIDs; len(got) != 2 || got[0] != 3 || got[1] != 5 {
t.Fatalf("unexpected message ids: %v", got)
}
}
func TestUpdateMarshalUsesCanonicalDeletedBusinessMessagesField(t *testing.T) {
update := Update{
UpdateID: 1,
DeletedBusinessMessage: &BusinessMessagesDeleted{
BusinessConnectionID: "conn",
Chat: Chat{ID: 42, Type: string(ChatTypePrivate)},
MessageIDs: []int{7},
},
}
data, err := json.Marshal(update)
if err != nil {
t.Fatalf("Marshal returned error: %v", err)
}
got := string(data)
if !strings.Contains(got, `"deleted_business_messages"`) {
t.Fatalf("expected canonical deleted_business_messages field, got %s", got)
}
if strings.Contains(got, `"deleted_business_message"`) {
t.Fatalf("unexpected singular deleted_business_message field, got %s", got)
}
}
func TestUpdateShippingQueryIsNilWhenAbsent(t *testing.T) {
var update Update
if err := json.Unmarshal([]byte(`{"update_id":1}`), &update); err != nil {
t.Fatalf("Unmarshal returned error: %v", err)
}
if update.ShippingQuery != nil {
t.Fatalf("expected ShippingQuery to be nil, got %+v", update.ShippingQuery)
}
}

View File

@@ -126,7 +126,6 @@ func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R,
req.Header.Set("Content-Type", contentType) req.Header.Set("Content-Type", contentType)
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString)) req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
req.Header.Set("Accept-Encoding", "gzip")
req.ContentLength = int64(buf.Len()) req.ContentLength = int64(buf.Len())
up.logger.Debugln("UPLOADER REQ", r.method) up.logger.Debugln("UPLOADER REQ", r.method)

138
tgapi/uploader_api_test.go Normal file
View File

@@ -0,0 +1,138 @@
package tgapi
import (
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"strings"
"testing"
)
func TestUploaderEncodesJSONFieldsAndLeavesAcceptEncodingToHTTPTransport(t *testing.T) {
var (
gotPath string
gotAcceptEncoding string
gotFields map[string]string
gotFileName string
gotFileData []byte
roundTripErr error
)
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
gotPath = req.URL.Path
gotAcceptEncoding = req.Header.Get("Accept-Encoding")
gotFields, gotFileName, gotFileData, roundTripErr = readMultipartRequest(req)
if roundTripErr != nil {
roundTripErr = fmt.Errorf("readMultipartRequest: %w", roundTripErr)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"message_id":5,"date":1}}`)),
}, nil
}),
}
api := NewAPI(
NewAPIOpts("token").
SetAPIUrl("https://example.test").
SetHTTPClient(client),
)
defer func() {
if err := api.CloseApi(); err != nil {
t.Fatalf("CloseApi returned error: %v", err)
}
}()
uploader := NewUploader(api)
defer func() {
if err := uploader.Close(); err != nil {
t.Fatalf("Close returned error: %v", err)
}
}()
msg, err := uploader.SendPhoto(
UploadPhotoP{
ChatID: 42,
CaptionEntities: []MessageEntity{{
Type: MessageEntityBold,
Offset: 0,
Length: 4,
}},
ReplyMarkup: &ReplyMarkup{
InlineKeyboard: [][]InlineKeyboardButton{{
{Text: "A", CallbackData: "b"},
}},
},
},
NewUploaderFile("photo.jpg", []byte("img")),
)
if err != nil {
t.Fatalf("SendPhoto returned error: %v", err)
}
if msg.MessageID != 5 {
t.Fatalf("unexpected message id: %d", msg.MessageID)
}
if roundTripErr != nil {
t.Fatalf("multipart parse failed: %v", roundTripErr)
}
if gotPath != "/bottoken/sendPhoto" {
t.Fatalf("unexpected request path: %s", gotPath)
}
if gotAcceptEncoding != "" {
t.Fatalf("expected empty Accept-Encoding header, got %q", gotAcceptEncoding)
}
if got := gotFields["chat_id"]; got != "42" {
t.Fatalf("chat_id mismatch: %q", got)
}
if got := gotFields["caption_entities"]; got != `[{"type":"bold","offset":0,"length":4}]` {
t.Fatalf("caption_entities mismatch: %q", got)
}
if got := gotFields["reply_markup"]; got != `{"inline_keyboard":[[{"text":"A","callback_data":"b"}]]}` {
t.Fatalf("reply_markup mismatch: %q", got)
}
if gotFileName != "photo.jpg" {
t.Fatalf("unexpected file name: %q", gotFileName)
}
if string(gotFileData) != "img" {
t.Fatalf("unexpected file content: %q", string(gotFileData))
}
}
func readMultipartRequest(req *http.Request) (map[string]string, string, []byte, error) {
_, params, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
if err != nil {
return nil, "", nil, err
}
reader := multipart.NewReader(req.Body, params["boundary"])
fields := make(map[string]string)
var fileName string
var fileData []byte
for {
part, err := reader.NextPart()
if err == io.EOF {
return fields, fileName, fileData, nil
}
if err != nil {
return nil, "", nil, err
}
data, err := io.ReadAll(part)
if err != nil {
return nil, "", nil, err
}
if part.FileName() != "" {
fileName = part.FileName()
fileData = data
continue
}
fields[part.FormName()] = string(data)
}
}

View File

@@ -1,6 +1,7 @@
package utils package utils
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
@@ -12,12 +13,8 @@ import (
// Encode writes struct fields into multipart form-data using json tags as field names. // Encode writes struct fields into multipart form-data using json tags as field names.
func Encode[T any](w *multipart.Writer, req T) error { func Encode[T any](w *multipart.Writer, req T) error {
v := reflect.ValueOf(req) v := unwrapMultipartValue(reflect.ValueOf(req))
if v.Kind() == reflect.Ptr { if !v.IsValid() || v.Kind() != reflect.Struct {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return fmt.Errorf("req must be a struct") return fmt.Errorf("req must be a struct")
} }
@@ -33,6 +30,9 @@ func Encode[T any](w *multipart.Writer, req T) error {
parts := strings.Split(jsonTag, ",") parts := strings.Split(jsonTag, ",")
fieldName := parts[0] fieldName := parts[0]
if fieldName == "" {
fieldName = fieldType.Name
}
if fieldName == "-" { if fieldName == "-" {
continue continue
} }
@@ -43,96 +43,73 @@ func Encode[T any](w *multipart.Writer, req T) error {
continue continue
} }
var ( if err := writeMultipartField(w, fieldName, fieldType.Tag.Get("filename"), field); err != nil {
fw io.Writer
err error
)
switch field.Kind() {
case reflect.String:
fw, err = w.CreateFormField(fieldName)
if err == nil {
_, err = fw.Write([]byte(field.String()))
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fw, err = w.CreateFormField(fieldName)
if err == nil {
_, err = fw.Write([]byte(strconv.FormatInt(field.Int(), 10)))
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
fw, err = w.CreateFormField(fieldName)
if err == nil {
_, err = fw.Write([]byte(strconv.FormatUint(field.Uint(), 10)))
}
case reflect.Float32:
fw, err = w.CreateFormField(fieldName)
if err == nil {
_, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 32)))
}
case reflect.Float64:
fw, err = w.CreateFormField(fieldName)
if err == nil {
_, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 64)))
}
case reflect.Bool:
fw, err = w.CreateFormField(fieldName)
if err == nil {
_, err = fw.Write([]byte(strconv.FormatBool(field.Bool())))
}
case reflect.Slice:
if field.Type().Elem().Kind() == reflect.Uint8 && !field.IsNil() {
// Handle []byte as file upload (e.g., thumbnail)
filename := fieldType.Tag.Get("filename")
if filename == "" {
filename = fieldName
}
fw, err = w.CreateFormFile(fieldName, filename)
if err == nil {
_, err = fw.Write(field.Bytes())
}
} else if !field.IsNil() {
// Handle []string, []int, etc. — send as multiple fields with same name
for j := 0; j < field.Len(); j++ {
elem := field.Index(j)
fw, err = w.CreateFormField(fieldName)
if err != nil {
break
}
switch elem.Kind() {
case reflect.String:
_, err = fw.Write([]byte(elem.String()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
_, err = fw.Write([]byte(strconv.FormatInt(elem.Int(), 10)))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
_, err = fw.Write([]byte(strconv.FormatUint(elem.Uint(), 10)))
case reflect.Bool:
_, err = fw.Write([]byte(strconv.FormatBool(elem.Bool())))
case reflect.Float32:
_, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 32)))
case reflect.Float64:
_, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 64)))
default:
continue
}
if err != nil {
break
}
}
}
case reflect.Struct:
// Don't serialize structs as JSON — flatten them!
// Telegram doesn't support nested JSON in form-data.
// If you need nested data, use separate fields (e.g., ParseMode, CaptionEntities)
// This is a design choice — you should avoid nested structs in params.
return fmt.Errorf("nested structs are not supported in params — use flat fields")
}
if err != nil {
return err return err
} }
} }
return nil return nil
} }
func unwrapMultipartValue(v reflect.Value) reflect.Value {
for v.IsValid() && (v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface) {
if v.IsNil() {
return reflect.Value{}
}
v = v.Elem()
}
return v
}
func writeMultipartField(w *multipart.Writer, fieldName, filename string, field reflect.Value) error {
value := unwrapMultipartValue(field)
if !value.IsValid() {
return nil
}
switch value.Kind() {
case reflect.String:
return writeMultipartValue(w, fieldName, []byte(value.String()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return writeMultipartValue(w, fieldName, []byte(strconv.FormatInt(value.Int(), 10)))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return writeMultipartValue(w, fieldName, []byte(strconv.FormatUint(value.Uint(), 10)))
case reflect.Float32:
return writeMultipartValue(w, fieldName, []byte(strconv.FormatFloat(value.Float(), 'f', -1, 32)))
case reflect.Float64:
return writeMultipartValue(w, fieldName, []byte(strconv.FormatFloat(value.Float(), 'f', -1, 64)))
case reflect.Bool:
return writeMultipartValue(w, fieldName, []byte(strconv.FormatBool(value.Bool())))
case reflect.Slice:
if value.Type().Elem().Kind() == reflect.Uint8 {
if filename == "" {
filename = fieldName
}
fw, err := w.CreateFormFile(fieldName, filename)
if err != nil {
return err
}
_, err = fw.Write(value.Bytes())
return err
}
}
// Telegram expects nested objects and arrays in multipart requests as JSON strings.
data, err := json.Marshal(value.Interface())
if err != nil {
return err
}
if string(data) == "null" {
return nil
}
return writeMultipartValue(w, fieldName, data)
}
func writeMultipartValue(w *multipart.Writer, fieldName string, value []byte) error {
fw, err := w.CreateFormField(fieldName)
if err != nil {
return err
}
_, err = io.Copy(fw, strings.NewReader(string(value)))
return err
}

85
utils/multipart_test.go Normal file
View File

@@ -0,0 +1,85 @@
package utils_test
import (
"bytes"
"io"
"mime/multipart"
"testing"
"git.nix13.pw/scuroneko/laniakea/tgapi"
"git.nix13.pw/scuroneko/laniakea/utils"
)
type multipartEncodeParams struct {
ChatID int64 `json:"chat_id"`
MessageThreadID *int `json:"message_thread_id,omitempty"`
ReplyMarkup *tgapi.ReplyMarkup `json:"reply_markup,omitempty"`
CaptionEntities []tgapi.MessageEntity `json:"caption_entities,omitempty"`
ReplyParameters *tgapi.ReplyParameters `json:"reply_parameters,omitempty"`
}
func TestEncodeMultipartJSONFields(t *testing.T) {
threadID := 7
params := multipartEncodeParams{
ChatID: 42,
MessageThreadID: &threadID,
ReplyMarkup: &tgapi.ReplyMarkup{
InlineKeyboard: [][]tgapi.InlineKeyboardButton{{
{Text: "A", CallbackData: "b"},
}},
},
CaptionEntities: []tgapi.MessageEntity{{
Type: tgapi.MessageEntityBold,
Offset: 0,
Length: 4,
}},
}
body := bytes.NewBuffer(nil)
writer := multipart.NewWriter(body)
if err := utils.Encode(writer, params); err != nil {
t.Fatalf("Encode returned error: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("writer.Close returned error: %v", err)
}
got := readMultipartFields(t, body.Bytes(), writer.Boundary())
if got["chat_id"] != "42" {
t.Fatalf("chat_id mismatch: %q", got["chat_id"])
}
if got["message_thread_id"] != "7" {
t.Fatalf("message_thread_id mismatch: %q", got["message_thread_id"])
}
if got["reply_markup"] != `{"inline_keyboard":[[{"text":"A","callback_data":"b"}]]}` {
t.Fatalf("reply_markup mismatch: %q", got["reply_markup"])
}
if got["caption_entities"] != `[{"type":"bold","offset":0,"length":4}]` {
t.Fatalf("caption_entities mismatch: %q", got["caption_entities"])
}
if _, ok := got["reply_parameters"]; ok {
t.Fatalf("reply_parameters should be omitted when nil")
}
}
func readMultipartFields(t *testing.T, body []byte, boundary string) map[string]string {
t.Helper()
reader := multipart.NewReader(bytes.NewReader(body), boundary)
fields := make(map[string]string)
for {
part, err := reader.NextPart()
if err == io.EOF {
return fields
}
if err != nil {
t.Fatalf("NextPart returned error: %v", err)
}
data, err := io.ReadAll(part)
if err != nil {
t.Fatalf("ReadAll returned error: %v", err)
}
fields[part.FormName()] = string(data)
}
}