Compare commits
1 Commits
v1.0.0-bet
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
4ebe76dd4a
|
17
bot.go
17
bot.go
@@ -184,21 +184,34 @@ func NewBot[T any](opts *BotOpts) *Bot[T] {
|
|||||||
//
|
//
|
||||||
// Returns the first error encountered, if any.
|
// Returns the first error encountered, if any.
|
||||||
func (bot *Bot[T]) Close() error {
|
func (bot *Bot[T]) Close() error {
|
||||||
|
var firstErr error
|
||||||
|
|
||||||
if err := bot.uploader.Close(); err != nil {
|
if err := bot.uploader.Close(); err != nil {
|
||||||
bot.logger.Errorln(err)
|
bot.logger.Errorln(err)
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := bot.api.CloseApi(); err != nil {
|
if err := bot.api.CloseApi(); err != nil {
|
||||||
bot.logger.Errorln(err)
|
bot.logger.Errorln(err)
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if bot.RequestLogger != nil {
|
if bot.RequestLogger != nil {
|
||||||
if err := bot.RequestLogger.Close(); err != nil {
|
if err := bot.RequestLogger.Close(); err != nil {
|
||||||
bot.logger.Errorln(err)
|
bot.logger.Errorln(err)
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := bot.logger.Close(); err != nil {
|
if err := bot.logger.Close(); err != nil {
|
||||||
return err
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
}
|
}
|
||||||
return nil
|
}
|
||||||
|
return firstErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// initLoggers configures the main and optional request loggers.
|
// initLoggers configures the main and optional request loggers.
|
||||||
|
|||||||
30
bot_opts.go
30
bot_opts.go
@@ -85,10 +85,10 @@ func LoadOptsFromEnv() *BotOpts {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stringUpdateTypes := strings.Split(os.Getenv("UPDATE_TYPES"), ";")
|
stringUpdateTypes := splitEnvList(os.Getenv("UPDATE_TYPES"))
|
||||||
updateTypes := make([]tgapi.UpdateType, len(stringUpdateTypes))
|
updateTypes := make([]tgapi.UpdateType, 0, len(stringUpdateTypes))
|
||||||
for i, updateType := range stringUpdateTypes {
|
for _, updateType := range stringUpdateTypes {
|
||||||
updateTypes[i] = tgapi.UpdateType(updateType)
|
updateTypes = append(updateTypes, tgapi.UpdateType(updateType))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &BotOpts{
|
return &BotOpts{
|
||||||
@@ -222,5 +222,25 @@ func LoadPrefixesFromEnv() []string {
|
|||||||
if !exists {
|
if !exists {
|
||||||
return []string{"/"}
|
return []string{"/"}
|
||||||
}
|
}
|
||||||
return strings.Split(prefixesS, ";")
|
prefixes := splitEnvList(prefixesS)
|
||||||
|
if len(prefixes) == 0 {
|
||||||
|
return []string{"/"}
|
||||||
|
}
|
||||||
|
return prefixes
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitEnvList(value string) []string {
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parts := strings.Split(value, ";")
|
||||||
|
out := make([]string, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, part)
|
||||||
|
}
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|||||||
47
bot_opts_test.go
Normal file
47
bot_opts_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package laniakea
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.nix13.pw/scuroneko/laniakea/tgapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadOptsFromEnvIgnoresEmptyUpdateTypes(t *testing.T) {
|
||||||
|
t.Setenv("UPDATE_TYPES", "")
|
||||||
|
|
||||||
|
opts := LoadOptsFromEnv()
|
||||||
|
if len(opts.UpdateTypes) != 0 {
|
||||||
|
t.Fatalf("expected no update types, got %v", opts.UpdateTypes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadOptsFromEnvSplitsAndTrimsUpdateTypes(t *testing.T) {
|
||||||
|
t.Setenv("UPDATE_TYPES", "message; ; callback_query ")
|
||||||
|
|
||||||
|
opts := LoadOptsFromEnv()
|
||||||
|
want := []tgapi.UpdateType{tgapi.UpdateTypeMessage, tgapi.UpdateTypeCallbackQuery}
|
||||||
|
if !reflect.DeepEqual(opts.UpdateTypes, want) {
|
||||||
|
t.Fatalf("unexpected update types: got %v want %v", opts.UpdateTypes, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadPrefixesFromEnvDefaultsOnEmptyValue(t *testing.T) {
|
||||||
|
t.Setenv("PREFIXES", "")
|
||||||
|
|
||||||
|
got := LoadPrefixesFromEnv()
|
||||||
|
want := []string{"/"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("unexpected prefixes: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadPrefixesFromEnvDropsEmptyValues(t *testing.T) {
|
||||||
|
t.Setenv("PREFIXES", "/; ; ! ")
|
||||||
|
|
||||||
|
got := LoadPrefixesFromEnv()
|
||||||
|
want := []string{"/", "!"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("unexpected prefixes: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -100,17 +100,17 @@ func gatherCommands[T any](bot *Bot[T]) []tgapi.BotCommand {
|
|||||||
// log.Fatal(err)
|
// log.Fatal(err)
|
||||||
// }
|
// }
|
||||||
func (bot *Bot[T]) AutoGenerateCommands() error {
|
func (bot *Bot[T]) AutoGenerateCommands() error {
|
||||||
|
commands := gatherCommands(bot)
|
||||||
|
if len(commands) > 100 {
|
||||||
|
return ErrTooManyCommands
|
||||||
|
}
|
||||||
|
|
||||||
// Clear existing commands to avoid duplication or stale entries
|
// Clear existing commands to avoid duplication or stale entries
|
||||||
_, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{})
|
_, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete existing commands: %w", err)
|
return fmt.Errorf("failed to delete existing commands: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
commands := gatherCommands(bot)
|
|
||||||
if len(commands) > 100 {
|
|
||||||
return ErrTooManyCommands
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register commands for each scope
|
// Register commands for each scope
|
||||||
scopes := []*tgapi.BotCommandScope{
|
scopes := []*tgapi.BotCommandScope{
|
||||||
{Type: tgapi.BotCommandScopePrivateType},
|
{Type: tgapi.BotCommandScopePrivateType},
|
||||||
@@ -148,15 +148,16 @@ func (bot *Bot[T]) AutoGenerateCommands() error {
|
|||||||
// log.Fatal(err)
|
// log.Fatal(err)
|
||||||
// }
|
// }
|
||||||
func (bot *Bot[T]) AutoGenerateCommandsForScope(scope *tgapi.BotCommandScope) error {
|
func (bot *Bot[T]) AutoGenerateCommandsForScope(scope *tgapi.BotCommandScope) error {
|
||||||
_, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{Scope: scope})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete existing commands: %w", err)
|
|
||||||
}
|
|
||||||
commands := gatherCommands(bot)
|
commands := gatherCommands(bot)
|
||||||
if len(commands) > 100 {
|
if len(commands) > 100 {
|
||||||
return ErrTooManyCommands
|
return ErrTooManyCommands
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err := bot.api.DeleteMyCommands(tgapi.DeleteMyCommandsP{Scope: scope})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete existing commands: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = bot.api.SetMyCommands(tgapi.SetMyCommandsP{
|
_, err = bot.api.SetMyCommands(tgapi.SetMyCommandsP{
|
||||||
Commands: commands,
|
Commands: commands,
|
||||||
Scope: scope,
|
Scope: scope,
|
||||||
|
|||||||
64
cmd_generator_test.go
Normal file
64
cmd_generator_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package laniakea
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.nix13.pw/scuroneko/laniakea/tgapi"
|
||||||
|
"git.nix13.pw/scuroneko/slog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoGenerateCommandsChecksLimitBeforeDelete(t *testing.T) {
|
||||||
|
var calls atomic.Int64
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":true}`)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
api := tgapi.NewAPI(
|
||||||
|
tgapi.NewAPIOpts("token").
|
||||||
|
SetAPIUrl("https://example.test").
|
||||||
|
SetHTTPClient(client),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err := api.CloseApi(); err != nil {
|
||||||
|
t.Fatalf("CloseApi returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
plugin := NewPlugin[NoDB]("overflow")
|
||||||
|
exec := func(ctx *MsgContext, db *NoDB) {}
|
||||||
|
for i := 0; i < 101; i++ {
|
||||||
|
plugin.AddCommand(NewCommand(exec, "cmd"+strconv.Itoa(i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
bot := &Bot[NoDB]{
|
||||||
|
api: api,
|
||||||
|
logger: slog.CreateLogger(),
|
||||||
|
plugins: []Plugin[NoDB]{*plugin},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := bot.AutoGenerateCommands()
|
||||||
|
if !errors.Is(err, ErrTooManyCommands) {
|
||||||
|
t.Fatalf("expected ErrTooManyCommands, got %v", err)
|
||||||
|
}
|
||||||
|
if calls.Load() != 0 {
|
||||||
|
t.Fatalf("expected no HTTP calls before limit validation, got %d", calls.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
4
go.mod
4
go.mod
@@ -3,8 +3,8 @@ module git.nix13.pw/scuroneko/laniakea
|
|||||||
go 1.26
|
go 1.26
|
||||||
|
|
||||||
require (
|
require (
|
||||||
git.nix13.pw/scuroneko/extypes v1.2.1
|
git.nix13.pw/scuroneko/extypes v1.2.2
|
||||||
git.nix13.pw/scuroneko/slog v1.0.2
|
git.nix13.pw/scuroneko/slog v1.1.2
|
||||||
github.com/alitto/pond/v2 v2.7.0
|
github.com/alitto/pond/v2 v2.7.0
|
||||||
golang.org/x/time v0.15.0
|
golang.org/x/time v0.15.0
|
||||||
)
|
)
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -1,7 +1,7 @@
|
|||||||
git.nix13.pw/scuroneko/extypes v1.2.1 h1:IYrOjnWKL2EAuJYtYNa+luB1vBe6paE8VY/YD+5/RpQ=
|
git.nix13.pw/scuroneko/extypes v1.2.2 h1:N54c1ejrPs1yfIkvYuwqI7B1+8S9mDv2GqQA6sct4dk=
|
||||||
git.nix13.pw/scuroneko/extypes v1.2.1/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw=
|
git.nix13.pw/scuroneko/extypes v1.2.2/go.mod h1:b4XYk1OW1dVSiE2MT/OMuX/K/UItf1swytX6eroVYnk=
|
||||||
git.nix13.pw/scuroneko/slog v1.0.2 h1:vZyUROygxC2d5FJHUQM/30xFEHY1JT/aweDZXA4rm2g=
|
git.nix13.pw/scuroneko/slog v1.1.2 h1:pl7tV5FN25Yso7sLYoOgBXi9+jLo5BDJHWmHlNPjpY0=
|
||||||
git.nix13.pw/scuroneko/slog v1.0.2/go.mod h1:3Qm2wzkR5KjwOponMfG7TcGSDjmYaFqRAmLvSPTuWJI=
|
git.nix13.pw/scuroneko/slog v1.1.2/go.mod h1:UcfRIHDqpVQHahBGM93awLDK8//AsAvOqBwwbWqMkjM=
|
||||||
github.com/alitto/pond/v2 v2.7.0 h1:c76L+yN916m/DRXjGCeUBHHu92uWnh/g1bwVk4zyyXg=
|
github.com/alitto/pond/v2 v2.7.0 h1:c76L+yN916m/DRXjGCeUBHHu92uWnh/g1bwVk4zyyXg=
|
||||||
github.com/alitto/pond/v2 v2.7.0/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
|
github.com/alitto/pond/v2 v2.7.0/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
|
|||||||
@@ -138,6 +138,9 @@ func (bot *Bot[T]) handleCallback(update *tgapi.Update, ctx *MsgContext) {
|
|||||||
|
|
||||||
func (bot *Bot[T]) checkPrefixes(text string) (string, bool) {
|
func (bot *Bot[T]) checkPrefixes(text string) (string, bool) {
|
||||||
for _, prefix := range bot.prefixes {
|
for _, prefix := range bot.prefixes {
|
||||||
|
if prefix == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if strings.HasPrefix(text, prefix) {
|
if strings.HasPrefix(text, prefix) {
|
||||||
return prefix, true
|
return prefix, true
|
||||||
}
|
}
|
||||||
|
|||||||
14
handler_test.go
Normal file
14
handler_test.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package laniakea
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestCheckPrefixesSkipsEmptyPrefixes(t *testing.T) {
|
||||||
|
bot := &Bot[NoDB]{prefixes: []string{"", "/"}}
|
||||||
|
|
||||||
|
if prefix, ok := bot.checkPrefixes("hello"); ok {
|
||||||
|
t.Fatalf("unexpected prefix match for plain text: %q", prefix)
|
||||||
|
}
|
||||||
|
if prefix, ok := bot.checkPrefixes("/start"); !ok || prefix != "/" {
|
||||||
|
t.Fatalf("unexpected prefix result: prefix=%q ok=%v", prefix, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -270,6 +270,9 @@ func (ctx *MsgContext) answerPhoto(photoId, text string, kb *InlineKeyboard, par
|
|||||||
if ctx.Msg.MessageThreadID > 0 {
|
if ctx.Msg.MessageThreadID > 0 {
|
||||||
params.MessageThreadID = ctx.Msg.MessageThreadID
|
params.MessageThreadID = ctx.Msg.MessageThreadID
|
||||||
}
|
}
|
||||||
|
if ctx.Msg.DirectMessageTopic != nil {
|
||||||
|
params.DirectMessagesTopicID = int(ctx.Msg.DirectMessageTopic.TopicID)
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := ctx.Api.SendPhoto(params)
|
msg, err := ctx.Api.SendPhoto(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
64
msg_context_test.go
Normal file
64
msg_context_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package laniakea
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.nix13.pw/scuroneko/laniakea/tgapi"
|
||||||
|
"git.nix13.pw/scuroneko/slog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAnswerPhotoIncludesDirectMessagesTopicID(t *testing.T) {
|
||||||
|
var gotBody map[string]any
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read request body: %v", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &gotBody); err != nil {
|
||||||
|
t.Fatalf("failed to decode request body: %v", err)
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"message_id":9,"date":1}}`)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := tgapi.NewAPI(
|
||||||
|
tgapi.NewAPIOpts("token").
|
||||||
|
SetAPIUrl("https://example.test").
|
||||||
|
SetHTTPClient(client),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err := api.CloseApi(); err != nil {
|
||||||
|
t.Fatalf("CloseApi returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx := &MsgContext{
|
||||||
|
Api: api,
|
||||||
|
Msg: &tgapi.Message{
|
||||||
|
Chat: &tgapi.Chat{ID: 42, Type: string(tgapi.ChatTypePrivate)},
|
||||||
|
DirectMessageTopic: &tgapi.DirectMessageTopic{TopicID: 77},
|
||||||
|
},
|
||||||
|
botLogger: slog.CreateLogger(),
|
||||||
|
}
|
||||||
|
|
||||||
|
answer := ctx.AnswerPhoto("photo-id", "caption")
|
||||||
|
if answer == nil {
|
||||||
|
t.Fatal("expected answer message")
|
||||||
|
}
|
||||||
|
if answer.MessageID != 9 {
|
||||||
|
t.Fatalf("unexpected message id: %d", answer.MessageID)
|
||||||
|
}
|
||||||
|
if got := gotBody["direct_messages_topic_id"]; got != float64(77) {
|
||||||
|
t.Fatalf("unexpected direct_messages_topic_id: %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,11 +23,11 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
// CommandRegexInt matches one or more digits.
|
// CommandRegexInt matches one or more digits.
|
||||||
CommandRegexInt = regexp.MustCompile(`\d+`)
|
CommandRegexInt = regexp.MustCompile(`^\d+$`)
|
||||||
// CommandRegexString matches any non-empty string.
|
// CommandRegexString matches any non-empty string.
|
||||||
CommandRegexString = regexp.MustCompile(`.+`)
|
CommandRegexString = regexp.MustCompile(`^.+$`)
|
||||||
// CommandRegexBool matches true or false.
|
// CommandRegexBool matches true or false.
|
||||||
CommandRegexBool = regexp.MustCompile(`true|false`)
|
CommandRegexBool = regexp.MustCompile(`^(true|false)$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrCmdArgCountMismatch is returned when the number of provided arguments
|
// ErrCmdArgCountMismatch is returned when the number of provided arguments
|
||||||
@@ -64,6 +64,7 @@ func (c *CommandArg) SetValueType(t CommandValueType) *CommandArg {
|
|||||||
case CommandValueAnyType:
|
case CommandValueAnyType:
|
||||||
regex = nil // Skip validation
|
regex = nil // Skip validation
|
||||||
}
|
}
|
||||||
|
c.valueType = t
|
||||||
c.regex = regex
|
c.regex = regex
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|||||||
24
plugins_test.go
Normal file
24
plugins_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package laniakea
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateArgsRequiresFullMatch(t *testing.T) {
|
||||||
|
intCmd := NewCommand[NoDB](func(ctx *MsgContext, db *NoDB) {}, "int", *NewCommandArg("n").SetValueType(CommandValueIntType).SetRequired())
|
||||||
|
if err := intCmd.validateArgs([]string{"123"}); err != nil {
|
||||||
|
t.Fatalf("expected valid integer argument, got %v", err)
|
||||||
|
}
|
||||||
|
if err := intCmd.validateArgs([]string{"123abc"}); !errors.Is(err, ErrCmdArgRegexpMismatch) {
|
||||||
|
t.Fatalf("expected ErrCmdArgRegexpMismatch for partial int match, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
boolCmd := NewCommand[NoDB](func(ctx *MsgContext, db *NoDB) {}, "bool", *NewCommandArg("flag").SetValueType(CommandValueBoolType).SetRequired())
|
||||||
|
if err := boolCmd.validateArgs([]string{"false"}); err != nil {
|
||||||
|
t.Fatalf("expected valid bool argument, got %v", err)
|
||||||
|
}
|
||||||
|
if err := boolCmd.validateArgs([]string{"falsey"}); !errors.Is(err, ErrCmdArgRegexpMismatch) {
|
||||||
|
t.Fatalf("expected ErrCmdArgRegexpMismatch for partial bool match, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -202,7 +202,6 @@ func (r TelegramRequest[R, P]) doRequest(ctx context.Context, api *API) (R, erro
|
|||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
|
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Apply rate limiting before making the request
|
// Apply rate limiting before making the request
|
||||||
|
|||||||
56
tgapi/api_test.go
Normal file
56
tgapi/api_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package tgapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPILeavesAcceptEncodingToHTTPTransport(t *testing.T) {
|
||||||
|
var gotPath string
|
||||||
|
var gotAcceptEncoding string
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
gotPath = req.URL.Path
|
||||||
|
gotAcceptEncoding = req.Header.Get("Accept-Encoding")
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"Test"}}`)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewAPI(
|
||||||
|
NewAPIOpts("token").
|
||||||
|
SetAPIUrl("https://example.test").
|
||||||
|
SetHTTPClient(client),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err := api.CloseApi(); err != nil {
|
||||||
|
t.Fatalf("CloseApi returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
user, err := api.GetMe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetMe returned error: %v", err)
|
||||||
|
}
|
||||||
|
if user.FirstName != "Test" {
|
||||||
|
t.Fatalf("unexpected first name: %q", user.FirstName)
|
||||||
|
}
|
||||||
|
if gotPath != "/bottoken/getMe" {
|
||||||
|
t.Fatalf("unexpected request path: %s", gotPath)
|
||||||
|
}
|
||||||
|
if gotAcceptEncoding != "" {
|
||||||
|
t.Fatalf("expected empty Accept-Encoding header, got %q", gotAcceptEncoding)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -60,6 +60,14 @@ type BusinessConnection struct {
|
|||||||
IsEnabled bool `json:"is_enabled"`
|
IsEnabled bool `json:"is_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BusinessMessagesDeleted is received when messages are deleted from a connected business account.
|
||||||
|
// See https://core.telegram.org/bots/api#businessmessagesdeleted
|
||||||
|
type BusinessMessagesDeleted struct {
|
||||||
|
BusinessConnectionID string `json:"business_connection_id"`
|
||||||
|
Chat Chat `json:"chat"`
|
||||||
|
MessageIDs []int `json:"message_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
// InputStoryContentType indicates the type of input story content.
|
// InputStoryContentType indicates the type of input story content.
|
||||||
type InputStoryContentType string
|
type InputStoryContentType string
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ type LinkPreviewOptions struct {
|
|||||||
type ReplyMarkup struct {
|
type ReplyMarkup struct {
|
||||||
InlineKeyboard [][]InlineKeyboardButton `json:"inline_keyboard,omitempty"`
|
InlineKeyboard [][]InlineKeyboardButton `json:"inline_keyboard,omitempty"`
|
||||||
|
|
||||||
Keyboard [][]int `json:"keyboard,omitempty"`
|
Keyboard [][]KeyboardButton `json:"keyboard,omitempty"`
|
||||||
IsPersistent bool `json:"is_persistent,omitempty"`
|
IsPersistent bool `json:"is_persistent,omitempty"`
|
||||||
ResizeKeyboard bool `json:"resize_keyboard,omitempty"`
|
ResizeKeyboard bool `json:"resize_keyboard,omitempty"`
|
||||||
OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"`
|
OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"`
|
||||||
@@ -165,6 +165,60 @@ type InlineKeyboardMarkup struct {
|
|||||||
// KeyboardButtonStyle represents the style of a keyboard button.
|
// KeyboardButtonStyle represents the style of a keyboard button.
|
||||||
type KeyboardButtonStyle string
|
type KeyboardButtonStyle string
|
||||||
|
|
||||||
|
const (
|
||||||
|
KeyboardButtonStyleDanger KeyboardButtonStyle = "danger"
|
||||||
|
KeyboardButtonStyleSuccess KeyboardButtonStyle = "success"
|
||||||
|
KeyboardButtonStylePrimary KeyboardButtonStyle = "primary"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyboardButton represents one button of the reply keyboard.
|
||||||
|
// See https://core.telegram.org/bots/api#keyboardbutton
|
||||||
|
type KeyboardButton struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
IconCustomEmojiID string `json:"icon_custom_emoji_id,omitempty"`
|
||||||
|
Style KeyboardButtonStyle `json:"style,omitempty"`
|
||||||
|
RequestUsers *KeyboardButtonRequestUsers `json:"request_users,omitempty"`
|
||||||
|
RequestChat *KeyboardButtonRequestChat `json:"request_chat,omitempty"`
|
||||||
|
RequestContact bool `json:"request_contact,omitempty"`
|
||||||
|
RequestLocation bool `json:"request_location,omitempty"`
|
||||||
|
RequestPoll *KeyboardButtonPollType `json:"request_poll,omitempty"`
|
||||||
|
WebApp *WebAppInfo `json:"web_app,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyboardButtonRequestUsers defines criteria used to request suitable users.
|
||||||
|
// See https://core.telegram.org/bots/api#keyboardbuttonrequestusers
|
||||||
|
type KeyboardButtonRequestUsers struct {
|
||||||
|
RequestID int `json:"request_id"`
|
||||||
|
UserIsBot *bool `json:"user_is_bot,omitempty"`
|
||||||
|
UserIsPremium *bool `json:"user_is_premium,omitempty"`
|
||||||
|
MaxQuantity int `json:"max_quantity,omitempty"`
|
||||||
|
RequestName bool `json:"request_name,omitempty"`
|
||||||
|
RequestUsername bool `json:"request_username,omitempty"`
|
||||||
|
RequestPhoto bool `json:"request_photo,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyboardButtonRequestChat defines criteria used to request a suitable chat.
|
||||||
|
// See https://core.telegram.org/bots/api#keyboardbuttonrequestchat
|
||||||
|
type KeyboardButtonRequestChat struct {
|
||||||
|
RequestID int `json:"request_id"`
|
||||||
|
ChatIsChannel bool `json:"chat_is_channel"`
|
||||||
|
ChatIsForum *bool `json:"chat_is_forum,omitempty"`
|
||||||
|
ChatHasUsername *bool `json:"chat_has_username,omitempty"`
|
||||||
|
ChatIsCreated *bool `json:"chat_is_created,omitempty"`
|
||||||
|
UserAdministratorRights *ChatAdministratorRights `json:"user_administrator_rights,omitempty"`
|
||||||
|
BotAdministratorRights *ChatAdministratorRights `json:"bot_administrator_rights,omitempty"`
|
||||||
|
BotIsMember bool `json:"bot_is_member,omitempty"`
|
||||||
|
RequestTitle bool `json:"request_title,omitempty"`
|
||||||
|
RequestUsername bool `json:"request_username,omitempty"`
|
||||||
|
RequestPhoto bool `json:"request_photo,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyboardButtonPollType represents the type of a poll that may be created from a keyboard button.
|
||||||
|
// See https://core.telegram.org/bots/api#keyboardbuttonpolltype
|
||||||
|
type KeyboardButtonPollType struct {
|
||||||
|
Type PollType `json:"type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// InlineKeyboardButton represents one button of an inline keyboard.
|
// InlineKeyboardButton represents one button of an inline keyboard.
|
||||||
// See https://core.telegram.org/bots/api#inlinekeyboardbutton
|
// See https://core.telegram.org/bots/api#inlinekeyboardbutton
|
||||||
type InlineKeyboardButton struct {
|
type InlineKeyboardButton struct {
|
||||||
@@ -178,7 +232,12 @@ type InlineKeyboardButton struct {
|
|||||||
// ReplyKeyboardMarkup represents a custom keyboard with reply options.
|
// ReplyKeyboardMarkup represents a custom keyboard with reply options.
|
||||||
// See https://core.telegram.org/bots/api#replykeyboardmarkup
|
// See https://core.telegram.org/bots/api#replykeyboardmarkup
|
||||||
type ReplyKeyboardMarkup struct {
|
type ReplyKeyboardMarkup struct {
|
||||||
Keyboard [][]int `json:"keyboard"`
|
Keyboard [][]KeyboardButton `json:"keyboard"`
|
||||||
|
IsPersistent bool `json:"is_persistent,omitempty"`
|
||||||
|
ResizeKeyboard bool `json:"resize_keyboard,omitempty"`
|
||||||
|
OneTimeKeyboard bool `json:"one_time_keyboard,omitempty"`
|
||||||
|
InputFieldPlaceholder string `json:"input_field_placeholder,omitempty"`
|
||||||
|
Selective bool `json:"selective,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallbackQuery represents an incoming callback query from a callback button in an inline keyboard.
|
// CallbackQuery represents an incoming callback query from a callback button in an inline keyboard.
|
||||||
@@ -238,7 +297,8 @@ const (
|
|||||||
ChatActionUploadDocument ChatActionType = "upload_document"
|
ChatActionUploadDocument ChatActionType = "upload_document"
|
||||||
ChatActionChooseSticker ChatActionType = "choose_sticker"
|
ChatActionChooseSticker ChatActionType = "choose_sticker"
|
||||||
ChatActionFindLocation ChatActionType = "find_location"
|
ChatActionFindLocation ChatActionType = "find_location"
|
||||||
ChatActionUploadVideoNone ChatActionType = "upload_video_none"
|
ChatActionUploadVideoNote ChatActionType = "upload_video_note"
|
||||||
|
ChatActionUploadVideoNone ChatActionType = ChatActionUploadVideoNote
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageReactionUpdated represents a change of a reaction on a message.
|
// MessageReactionUpdated represents a change of a reaction on a message.
|
||||||
|
|||||||
37
tgapi/messages_types_test.go
Normal file
37
tgapi/messages_types_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package tgapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReplyKeyboardMarkupMarshalsKeyboardButtons(t *testing.T) {
|
||||||
|
markup := ReplyKeyboardMarkup{
|
||||||
|
Keyboard: [][]KeyboardButton{{
|
||||||
|
{
|
||||||
|
Text: "Create poll",
|
||||||
|
RequestPoll: &KeyboardButtonPollType{Type: PollTypeQuiz},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(markup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := string(data)
|
||||||
|
if !strings.Contains(got, `"keyboard":[[{"text":"Create poll","request_poll":{"type":"quiz"}}]]`) {
|
||||||
|
t.Fatalf("unexpected reply keyboard JSON: %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatActionUploadVideoNoteValue(t *testing.T) {
|
||||||
|
if ChatActionUploadVideoNote != "upload_video_note" {
|
||||||
|
t.Fatalf("unexpected chat action value: %q", ChatActionUploadVideoNote)
|
||||||
|
}
|
||||||
|
if ChatActionUploadVideoNone != ChatActionUploadVideoNote {
|
||||||
|
t.Fatalf("expected deprecated alias to match upload_video_note, got %q", ChatActionUploadVideoNone)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
package tgapi
|
package tgapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"git.nix13.pw/scuroneko/laniakea/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UpdateParams holds parameters for the getUpdates method.
|
// UpdateParams holds parameters for the getUpdates method.
|
||||||
@@ -12,7 +15,7 @@ type UpdateParams struct {
|
|||||||
Offset *int `json:"offset,omitempty"`
|
Offset *int `json:"offset,omitempty"`
|
||||||
Limit *int `json:"limit,omitempty"`
|
Limit *int `json:"limit,omitempty"`
|
||||||
Timeout *int `json:"timeout,omitempty"`
|
Timeout *int `json:"timeout,omitempty"`
|
||||||
AllowedUpdates []UpdateType `json:"allowed_updates"`
|
AllowedUpdates []UpdateType `json:"allowed_updates,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMe returns basic information about the bot.
|
// GetMe returns basic information about the bot.
|
||||||
@@ -103,13 +106,31 @@ func (api *API) GetFile(params GetFileP) (File, error) {
|
|||||||
// The link is usually obtained from File.FilePath.
|
// The link is usually obtained from File.FilePath.
|
||||||
// See https://core.telegram.org/bots/api#file
|
// See https://core.telegram.org/bots/api#file
|
||||||
func (api *API) GetFileByLink(link string) ([]byte, error) {
|
func (api *API) GetFileByLink(link string) ([]byte, error) {
|
||||||
u := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", api.token, link)
|
methodPrefix := ""
|
||||||
res, err := http.Get(u)
|
if api.useTestServer {
|
||||||
|
methodPrefix = "/test"
|
||||||
|
}
|
||||||
|
u := fmt.Sprintf("%s/file/bot%s%s/%s", api.apiUrl, api.token, methodPrefix, link)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, u, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
|
||||||
|
|
||||||
|
res, err := api.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = res.Body.Close()
|
_ = res.Body.Close()
|
||||||
}()
|
}()
|
||||||
|
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
body, readErr := io.ReadAll(io.LimitReader(res.Body, 4<<10))
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, fmt.Errorf("unexpected status %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected status %d: %s", res.StatusCode, string(body))
|
||||||
|
}
|
||||||
return io.ReadAll(res.Body)
|
return io.ReadAll(res.Body)
|
||||||
}
|
}
|
||||||
|
|||||||
115
tgapi/methods_test.go
Normal file
115
tgapi/methods_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package tgapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetFileByLinkUsesConfiguredAPIURL(t *testing.T) {
|
||||||
|
var gotPath string
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
gotPath = req.URL.Path
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("payload")),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewAPI(
|
||||||
|
NewAPIOpts("token").
|
||||||
|
SetAPIUrl("https://example.test").
|
||||||
|
SetHTTPClient(client),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err := api.CloseApi(); err != nil {
|
||||||
|
t.Fatalf("CloseApi returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
data, err := api.GetFileByLink("files/report.txt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFileByLink returned error: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "payload" {
|
||||||
|
t.Fatalf("unexpected payload: %q", string(data))
|
||||||
|
}
|
||||||
|
if gotPath != "/file/bottoken/files/report.txt" {
|
||||||
|
t.Fatalf("unexpected request path: %s", gotPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileByLinkReturnsHTTPStatusError(t *testing.T) {
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusNotFound,
|
||||||
|
Body: io.NopCloser(strings.NewReader("missing\n")),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewAPI(
|
||||||
|
NewAPIOpts("token").
|
||||||
|
SetAPIUrl("https://example.test").
|
||||||
|
SetHTTPClient(client),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err := api.CloseApi(); err != nil {
|
||||||
|
t.Fatalf("CloseApi returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := api.GetFileByLink("files/report.txt")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-2xx response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUpdatesOmitsAllowedUpdatesWhenEmpty(t *testing.T) {
|
||||||
|
var gotBody map[string]any
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read request body: %v", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &gotBody); err != nil {
|
||||||
|
t.Fatalf("failed to decode request body: %v", err)
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":[]}`)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewAPI(
|
||||||
|
NewAPIOpts("token").
|
||||||
|
SetAPIUrl("https://example.test").
|
||||||
|
SetHTTPClient(client),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err := api.CloseApi(); err != nil {
|
||||||
|
t.Fatalf("CloseApi returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
updates, err := api.GetUpdates(UpdateParams{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUpdates returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(updates) != 0 {
|
||||||
|
t.Fatalf("expected no updates, got %d", len(updates))
|
||||||
|
}
|
||||||
|
if _, exists := gotBody["allowed_updates"]; exists {
|
||||||
|
t.Fatalf("expected allowed_updates to be omitted, got %v", gotBody["allowed_updates"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
package tgapi
|
package tgapi
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
// UpdateType represents the type of incoming update.
|
// UpdateType represents the type of incoming update.
|
||||||
type UpdateType string
|
type UpdateType string
|
||||||
|
|
||||||
@@ -23,8 +25,10 @@ const (
|
|||||||
UpdateTypeBusinessMessage UpdateType = "business_message"
|
UpdateTypeBusinessMessage UpdateType = "business_message"
|
||||||
// UpdateTypeEditedBusinessMessage is an edited business message update.
|
// UpdateTypeEditedBusinessMessage is an edited business message update.
|
||||||
UpdateTypeEditedBusinessMessage UpdateType = "edited_business_message"
|
UpdateTypeEditedBusinessMessage UpdateType = "edited_business_message"
|
||||||
// UpdateTypeDeletedBusinessMessage is a deleted business message update.
|
// UpdateTypeDeletedBusinessMessages is a deleted business messages update.
|
||||||
UpdateTypeDeletedBusinessMessage UpdateType = "deleted_business_message"
|
UpdateTypeDeletedBusinessMessages UpdateType = "deleted_business_messages"
|
||||||
|
// UpdateTypeDeletedBusinessMessage is kept as a backward-compatible alias.
|
||||||
|
UpdateTypeDeletedBusinessMessage UpdateType = UpdateTypeDeletedBusinessMessages
|
||||||
|
|
||||||
// UpdateTypeInlineQuery is an inline query update.
|
// UpdateTypeInlineQuery is an inline query update.
|
||||||
UpdateTypeInlineQuery UpdateType = "inline_query"
|
UpdateTypeInlineQuery UpdateType = "inline_query"
|
||||||
@@ -66,14 +70,15 @@ type Update struct {
|
|||||||
BusinessConnection *BusinessConnection `json:"business_connection,omitempty"`
|
BusinessConnection *BusinessConnection `json:"business_connection,omitempty"`
|
||||||
BusinessMessage *Message `json:"business_message,omitempty"`
|
BusinessMessage *Message `json:"business_message,omitempty"`
|
||||||
EditedBusinessMessage *Message `json:"edited_business_message,omitempty"`
|
EditedBusinessMessage *Message `json:"edited_business_message,omitempty"`
|
||||||
DeletedBusinessMessage *Message `json:"deleted_business_messages,omitempty"`
|
DeletedBusinessMessages *BusinessMessagesDeleted `json:"deleted_business_messages,omitempty"`
|
||||||
|
DeletedBusinessMessage *BusinessMessagesDeleted `json:"-"`
|
||||||
MessageReaction *MessageReactionUpdated `json:"message_reaction,omitempty"`
|
MessageReaction *MessageReactionUpdated `json:"message_reaction,omitempty"`
|
||||||
MessageReactionCount *MessageReactionCountUpdated `json:"message_reaction_count,omitempty"`
|
MessageReactionCount *MessageReactionCountUpdated `json:"message_reaction_count,omitempty"`
|
||||||
|
|
||||||
InlineQuery *InlineQuery `json:"inline_query,omitempty"`
|
InlineQuery *InlineQuery `json:"inline_query,omitempty"`
|
||||||
ChosenInlineResult *ChosenInlineResult `json:"chosen_inline_result,omitempty"`
|
ChosenInlineResult *ChosenInlineResult `json:"chosen_inline_result,omitempty"`
|
||||||
CallbackQuery *CallbackQuery `json:"callback_query,omitempty"`
|
CallbackQuery *CallbackQuery `json:"callback_query,omitempty"`
|
||||||
ShippingQuery ShippingQuery `json:"shipping_query,omitempty"`
|
ShippingQuery *ShippingQuery `json:"shipping_query,omitempty"`
|
||||||
PreCheckoutQuery *PreCheckoutQuery `json:"pre_checkout_query,omitempty"`
|
PreCheckoutQuery *PreCheckoutQuery `json:"pre_checkout_query,omitempty"`
|
||||||
PurchasedPaidMedia *PaidMediaPurchased `json:"purchased_paid_media,omitempty"`
|
PurchasedPaidMedia *PaidMediaPurchased `json:"purchased_paid_media,omitempty"`
|
||||||
|
|
||||||
@@ -86,6 +91,35 @@ type Update struct {
|
|||||||
RemovedChatBoost *ChatBoostRemoved `json:"removed_chat_boost,omitempty"`
|
RemovedChatBoost *ChatBoostRemoved `json:"removed_chat_boost,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *Update) syncDeletedBusinessMessages() {
|
||||||
|
if u.DeletedBusinessMessages != nil {
|
||||||
|
u.DeletedBusinessMessage = u.DeletedBusinessMessages
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if u.DeletedBusinessMessage != nil {
|
||||||
|
u.DeletedBusinessMessages = u.DeletedBusinessMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON keeps the deprecated DeletedBusinessMessage alias in sync.
|
||||||
|
func (u *Update) UnmarshalJSON(data []byte) error {
|
||||||
|
type alias Update
|
||||||
|
var aux alias
|
||||||
|
if err := json.Unmarshal(data, &aux); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*u = Update(aux)
|
||||||
|
u.syncDeletedBusinessMessages()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON emits the canonical deleted_business_messages field.
|
||||||
|
func (u Update) MarshalJSON() ([]byte, error) {
|
||||||
|
u.syncDeletedBusinessMessages()
|
||||||
|
type alias Update
|
||||||
|
return json.Marshal(alias(u))
|
||||||
|
}
|
||||||
|
|
||||||
// InlineQuery represents an incoming inline query.
|
// InlineQuery represents an incoming inline query.
|
||||||
// See https://core.telegram.org/bots/api#inlinequery
|
// See https://core.telegram.org/bots/api#inlinequery
|
||||||
type InlineQuery struct {
|
type InlineQuery struct {
|
||||||
|
|||||||
69
tgapi/types_test.go
Normal file
69
tgapi/types_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package tgapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateDeletedBusinessMessagesUnmarshalSetsAlias(t *testing.T) {
|
||||||
|
var update Update
|
||||||
|
err := json.Unmarshal([]byte(`{
|
||||||
|
"update_id": 1,
|
||||||
|
"deleted_business_messages": {
|
||||||
|
"business_connection_id": "conn",
|
||||||
|
"chat": {"id": 42, "type": "private"},
|
||||||
|
"message_ids": [3, 5]
|
||||||
|
}
|
||||||
|
}`), &update)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unmarshal returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if update.DeletedBusinessMessages == nil {
|
||||||
|
t.Fatal("expected DeletedBusinessMessages to be populated")
|
||||||
|
}
|
||||||
|
if update.DeletedBusinessMessage == nil {
|
||||||
|
t.Fatal("expected deprecated DeletedBusinessMessage alias to be populated")
|
||||||
|
}
|
||||||
|
if update.DeletedBusinessMessages != update.DeletedBusinessMessage {
|
||||||
|
t.Fatal("expected deleted business message fields to share the same payload")
|
||||||
|
}
|
||||||
|
if got := update.DeletedBusinessMessages.MessageIDs; len(got) != 2 || got[0] != 3 || got[1] != 5 {
|
||||||
|
t.Fatalf("unexpected message ids: %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateMarshalUsesCanonicalDeletedBusinessMessagesField(t *testing.T) {
|
||||||
|
update := Update{
|
||||||
|
UpdateID: 1,
|
||||||
|
DeletedBusinessMessage: &BusinessMessagesDeleted{
|
||||||
|
BusinessConnectionID: "conn",
|
||||||
|
Chat: Chat{ID: 42, Type: string(ChatTypePrivate)},
|
||||||
|
MessageIDs: []int{7},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(update)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := string(data)
|
||||||
|
if !strings.Contains(got, `"deleted_business_messages"`) {
|
||||||
|
t.Fatalf("expected canonical deleted_business_messages field, got %s", got)
|
||||||
|
}
|
||||||
|
if strings.Contains(got, `"deleted_business_message"`) {
|
||||||
|
t.Fatalf("unexpected singular deleted_business_message field, got %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateShippingQueryIsNilWhenAbsent(t *testing.T) {
|
||||||
|
var update Update
|
||||||
|
if err := json.Unmarshal([]byte(`{"update_id":1}`), &update); err != nil {
|
||||||
|
t.Fatalf("Unmarshal returned error: %v", err)
|
||||||
|
}
|
||||||
|
if update.ShippingQuery != nil {
|
||||||
|
t.Fatalf("expected ShippingQuery to be nil, got %+v", update.ShippingQuery)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -126,7 +126,6 @@ func (r UploaderRequest[R, P]) doRequest(ctx context.Context, up *Uploader) (R,
|
|||||||
req.Header.Set("Content-Type", contentType)
|
req.Header.Set("Content-Type", contentType)
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
|
req.Header.Set("User-Agent", fmt.Sprintf("Laniakea/%s", utils.VersionString))
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
|
||||||
req.ContentLength = int64(buf.Len())
|
req.ContentLength = int64(buf.Len())
|
||||||
|
|
||||||
up.logger.Debugln("UPLOADER REQ", r.method)
|
up.logger.Debugln("UPLOADER REQ", r.method)
|
||||||
|
|||||||
138
tgapi/uploader_api_test.go
Normal file
138
tgapi/uploader_api_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package tgapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUploaderEncodesJSONFieldsAndLeavesAcceptEncodingToHTTPTransport(t *testing.T) {
|
||||||
|
var (
|
||||||
|
gotPath string
|
||||||
|
gotAcceptEncoding string
|
||||||
|
gotFields map[string]string
|
||||||
|
gotFileName string
|
||||||
|
gotFileData []byte
|
||||||
|
roundTripErr error
|
||||||
|
)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
gotPath = req.URL.Path
|
||||||
|
gotAcceptEncoding = req.Header.Get("Accept-Encoding")
|
||||||
|
|
||||||
|
gotFields, gotFileName, gotFileData, roundTripErr = readMultipartRequest(req)
|
||||||
|
if roundTripErr != nil {
|
||||||
|
roundTripErr = fmt.Errorf("readMultipartRequest: %w", roundTripErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"ok":true,"result":{"message_id":5,"date":1}}`)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewAPI(
|
||||||
|
NewAPIOpts("token").
|
||||||
|
SetAPIUrl("https://example.test").
|
||||||
|
SetHTTPClient(client),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err := api.CloseApi(); err != nil {
|
||||||
|
t.Fatalf("CloseApi returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
uploader := NewUploader(api)
|
||||||
|
defer func() {
|
||||||
|
if err := uploader.Close(); err != nil {
|
||||||
|
t.Fatalf("Close returned error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
msg, err := uploader.SendPhoto(
|
||||||
|
UploadPhotoP{
|
||||||
|
ChatID: 42,
|
||||||
|
CaptionEntities: []MessageEntity{{
|
||||||
|
Type: MessageEntityBold,
|
||||||
|
Offset: 0,
|
||||||
|
Length: 4,
|
||||||
|
}},
|
||||||
|
ReplyMarkup: &ReplyMarkup{
|
||||||
|
InlineKeyboard: [][]InlineKeyboardButton{{
|
||||||
|
{Text: "A", CallbackData: "b"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NewUploaderFile("photo.jpg", []byte("img")),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendPhoto returned error: %v", err)
|
||||||
|
}
|
||||||
|
if msg.MessageID != 5 {
|
||||||
|
t.Fatalf("unexpected message id: %d", msg.MessageID)
|
||||||
|
}
|
||||||
|
if roundTripErr != nil {
|
||||||
|
t.Fatalf("multipart parse failed: %v", roundTripErr)
|
||||||
|
}
|
||||||
|
if gotPath != "/bottoken/sendPhoto" {
|
||||||
|
t.Fatalf("unexpected request path: %s", gotPath)
|
||||||
|
}
|
||||||
|
if gotAcceptEncoding != "" {
|
||||||
|
t.Fatalf("expected empty Accept-Encoding header, got %q", gotAcceptEncoding)
|
||||||
|
}
|
||||||
|
if got := gotFields["chat_id"]; got != "42" {
|
||||||
|
t.Fatalf("chat_id mismatch: %q", got)
|
||||||
|
}
|
||||||
|
if got := gotFields["caption_entities"]; got != `[{"type":"bold","offset":0,"length":4}]` {
|
||||||
|
t.Fatalf("caption_entities mismatch: %q", got)
|
||||||
|
}
|
||||||
|
if got := gotFields["reply_markup"]; got != `{"inline_keyboard":[[{"text":"A","callback_data":"b"}]]}` {
|
||||||
|
t.Fatalf("reply_markup mismatch: %q", got)
|
||||||
|
}
|
||||||
|
if gotFileName != "photo.jpg" {
|
||||||
|
t.Fatalf("unexpected file name: %q", gotFileName)
|
||||||
|
}
|
||||||
|
if string(gotFileData) != "img" {
|
||||||
|
t.Fatalf("unexpected file content: %q", string(gotFileData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMultipartRequest(req *http.Request) (map[string]string, string, []byte, error) {
|
||||||
|
_, params, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", nil, err
|
||||||
|
}
|
||||||
|
reader := multipart.NewReader(req.Body, params["boundary"])
|
||||||
|
|
||||||
|
fields := make(map[string]string)
|
||||||
|
var fileName string
|
||||||
|
var fileData []byte
|
||||||
|
for {
|
||||||
|
part, err := reader.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
return fields, fileName, fileData, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(part)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if part.FileName() != "" {
|
||||||
|
fileName = part.FileName()
|
||||||
|
fileData = data
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields[part.FormName()] = string(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -12,12 +13,8 @@ import (
|
|||||||
|
|
||||||
// Encode writes struct fields into multipart form-data using json tags as field names.
|
// Encode writes struct fields into multipart form-data using json tags as field names.
|
||||||
func Encode[T any](w *multipart.Writer, req T) error {
|
func Encode[T any](w *multipart.Writer, req T) error {
|
||||||
v := reflect.ValueOf(req)
|
v := unwrapMultipartValue(reflect.ValueOf(req))
|
||||||
if v.Kind() == reflect.Ptr {
|
if !v.IsValid() || v.Kind() != reflect.Struct {
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if v.Kind() != reflect.Struct {
|
|
||||||
return fmt.Errorf("req must be a struct")
|
return fmt.Errorf("req must be a struct")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,6 +30,9 @@ func Encode[T any](w *multipart.Writer, req T) error {
|
|||||||
|
|
||||||
parts := strings.Split(jsonTag, ",")
|
parts := strings.Split(jsonTag, ",")
|
||||||
fieldName := parts[0]
|
fieldName := parts[0]
|
||||||
|
if fieldName == "" {
|
||||||
|
fieldName = fieldType.Name
|
||||||
|
}
|
||||||
if fieldName == "-" {
|
if fieldName == "-" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -43,96 +43,73 @@ func Encode[T any](w *multipart.Writer, req T) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
if err := writeMultipartField(w, fieldName, fieldType.Tag.Get("filename"), field); err != nil {
|
||||||
fw io.Writer
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
switch field.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
fw, err = w.CreateFormField(fieldName)
|
|
||||||
if err == nil {
|
|
||||||
_, err = fw.Write([]byte(field.String()))
|
|
||||||
}
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
||||||
fw, err = w.CreateFormField(fieldName)
|
|
||||||
if err == nil {
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatInt(field.Int(), 10)))
|
|
||||||
}
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
fw, err = w.CreateFormField(fieldName)
|
|
||||||
if err == nil {
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatUint(field.Uint(), 10)))
|
|
||||||
}
|
|
||||||
case reflect.Float32:
|
|
||||||
fw, err = w.CreateFormField(fieldName)
|
|
||||||
if err == nil {
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 32)))
|
|
||||||
}
|
|
||||||
case reflect.Float64:
|
|
||||||
fw, err = w.CreateFormField(fieldName)
|
|
||||||
if err == nil {
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatFloat(field.Float(), 'f', -1, 64)))
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Bool:
|
|
||||||
fw, err = w.CreateFormField(fieldName)
|
|
||||||
if err == nil {
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatBool(field.Bool())))
|
|
||||||
}
|
|
||||||
case reflect.Slice:
|
|
||||||
if field.Type().Elem().Kind() == reflect.Uint8 && !field.IsNil() {
|
|
||||||
// Handle []byte as file upload (e.g., thumbnail)
|
|
||||||
filename := fieldType.Tag.Get("filename")
|
|
||||||
if filename == "" {
|
|
||||||
filename = fieldName
|
|
||||||
}
|
|
||||||
fw, err = w.CreateFormFile(fieldName, filename)
|
|
||||||
if err == nil {
|
|
||||||
_, err = fw.Write(field.Bytes())
|
|
||||||
}
|
|
||||||
} else if !field.IsNil() {
|
|
||||||
// Handle []string, []int, etc. — send as multiple fields with same name
|
|
||||||
for j := 0; j < field.Len(); j++ {
|
|
||||||
elem := field.Index(j)
|
|
||||||
fw, err = w.CreateFormField(fieldName)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
switch elem.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
_, err = fw.Write([]byte(elem.String()))
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatInt(elem.Int(), 10)))
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatUint(elem.Uint(), 10)))
|
|
||||||
case reflect.Bool:
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatBool(elem.Bool())))
|
|
||||||
case reflect.Float32:
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 32)))
|
|
||||||
case reflect.Float64:
|
|
||||||
_, err = fw.Write([]byte(strconv.FormatFloat(elem.Float(), 'f', -1, 64)))
|
|
||||||
default:
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Struct:
|
|
||||||
// Don't serialize structs as JSON — flatten them!
|
|
||||||
// Telegram doesn't support nested JSON in form-data.
|
|
||||||
// If you need nested data, use separate fields (e.g., ParseMode, CaptionEntities)
|
|
||||||
// This is a design choice — you should avoid nested structs in params.
|
|
||||||
return fmt.Errorf("nested structs are not supported in params — use flat fields")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unwrapMultipartValue(v reflect.Value) reflect.Value {
|
||||||
|
for v.IsValid() && (v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface) {
|
||||||
|
if v.IsNil() {
|
||||||
|
return reflect.Value{}
|
||||||
|
}
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMultipartField(w *multipart.Writer, fieldName, filename string, field reflect.Value) error {
|
||||||
|
value := unwrapMultipartValue(field)
|
||||||
|
if !value.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
return writeMultipartValue(w, fieldName, []byte(value.String()))
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return writeMultipartValue(w, fieldName, []byte(strconv.FormatInt(value.Int(), 10)))
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return writeMultipartValue(w, fieldName, []byte(strconv.FormatUint(value.Uint(), 10)))
|
||||||
|
case reflect.Float32:
|
||||||
|
return writeMultipartValue(w, fieldName, []byte(strconv.FormatFloat(value.Float(), 'f', -1, 32)))
|
||||||
|
case reflect.Float64:
|
||||||
|
return writeMultipartValue(w, fieldName, []byte(strconv.FormatFloat(value.Float(), 'f', -1, 64)))
|
||||||
|
case reflect.Bool:
|
||||||
|
return writeMultipartValue(w, fieldName, []byte(strconv.FormatBool(value.Bool())))
|
||||||
|
case reflect.Slice:
|
||||||
|
if value.Type().Elem().Kind() == reflect.Uint8 {
|
||||||
|
if filename == "" {
|
||||||
|
filename = fieldName
|
||||||
|
}
|
||||||
|
fw, err := w.CreateFormFile(fieldName, filename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = fw.Write(value.Bytes())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Telegram expects nested objects and arrays in multipart requests as JSON strings.
|
||||||
|
data, err := json.Marshal(value.Interface())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return writeMultipartValue(w, fieldName, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMultipartValue(w *multipart.Writer, fieldName string, value []byte) error {
|
||||||
|
fw, err := w.CreateFormField(fieldName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = io.Copy(fw, strings.NewReader(string(value)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
85
utils/multipart_test.go
Normal file
85
utils/multipart_test.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package utils_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.nix13.pw/scuroneko/laniakea/tgapi"
|
||||||
|
"git.nix13.pw/scuroneko/laniakea/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type multipartEncodeParams struct {
|
||||||
|
ChatID int64 `json:"chat_id"`
|
||||||
|
MessageThreadID *int `json:"message_thread_id,omitempty"`
|
||||||
|
ReplyMarkup *tgapi.ReplyMarkup `json:"reply_markup,omitempty"`
|
||||||
|
CaptionEntities []tgapi.MessageEntity `json:"caption_entities,omitempty"`
|
||||||
|
ReplyParameters *tgapi.ReplyParameters `json:"reply_parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeMultipartJSONFields(t *testing.T) {
|
||||||
|
threadID := 7
|
||||||
|
params := multipartEncodeParams{
|
||||||
|
ChatID: 42,
|
||||||
|
MessageThreadID: &threadID,
|
||||||
|
ReplyMarkup: &tgapi.ReplyMarkup{
|
||||||
|
InlineKeyboard: [][]tgapi.InlineKeyboardButton{{
|
||||||
|
{Text: "A", CallbackData: "b"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
CaptionEntities: []tgapi.MessageEntity{{
|
||||||
|
Type: tgapi.MessageEntityBold,
|
||||||
|
Offset: 0,
|
||||||
|
Length: 4,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
body := bytes.NewBuffer(nil)
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
if err := utils.Encode(writer, params); err != nil {
|
||||||
|
t.Fatalf("Encode returned error: %v", err)
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("writer.Close returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := readMultipartFields(t, body.Bytes(), writer.Boundary())
|
||||||
|
if got["chat_id"] != "42" {
|
||||||
|
t.Fatalf("chat_id mismatch: %q", got["chat_id"])
|
||||||
|
}
|
||||||
|
if got["message_thread_id"] != "7" {
|
||||||
|
t.Fatalf("message_thread_id mismatch: %q", got["message_thread_id"])
|
||||||
|
}
|
||||||
|
if got["reply_markup"] != `{"inline_keyboard":[[{"text":"A","callback_data":"b"}]]}` {
|
||||||
|
t.Fatalf("reply_markup mismatch: %q", got["reply_markup"])
|
||||||
|
}
|
||||||
|
if got["caption_entities"] != `[{"type":"bold","offset":0,"length":4}]` {
|
||||||
|
t.Fatalf("caption_entities mismatch: %q", got["caption_entities"])
|
||||||
|
}
|
||||||
|
if _, ok := got["reply_parameters"]; ok {
|
||||||
|
t.Fatalf("reply_parameters should be omitted when nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMultipartFields(t *testing.T, body []byte, boundary string) map[string]string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||||
|
fields := make(map[string]string)
|
||||||
|
for {
|
||||||
|
part, err := reader.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NextPart returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(part)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll returned error: %v", err)
|
||||||
|
}
|
||||||
|
fields[part.FormName()] = string(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user