From 3e0d3db47ec1267a5f13a8dac2c6dabe6eb637bc Mon Sep 17 00:00:00 2001 From: ScuroNeko Date: Mon, 2 Mar 2026 00:58:43 +0300 Subject: [PATCH] some changes --- Dockerfile | 1 + build.bat | 4 +- go.mod | 4 +- go.sum | 4 +- openai/api.go | 204 +++++++++++++++++++++++++++++++++++ openai/completitions.go | 49 +++++++++ openai/sse.go | 37 +++++++ openai/types.go | 41 +++++++ plugins/admin.go | 14 +++ plugins/ai.go | 122 ++++++++++++++------- plugins/ban.go | 2 +- plugins/rp.go | 65 +++++++---- utils/ai/openai.go | 231 +--------------------------------------- 13 files changed, 486 insertions(+), 292 deletions(-) create mode 100644 openai/api.go create mode 100644 openai/completitions.go create mode 100644 openai/sse.go create mode 100644 openai/types.go diff --git a/Dockerfile b/Dockerfile index d108f54..08c681d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,6 +8,7 @@ RUN --mount=type=cache,target=/go/pkg/mod go mod download -x COPY ./database ./database COPY ./plugins ./plugins COPY ./utils ./utils +COPY ./openai ./openai COPY ./main.go ./ RUN --mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/go/pkg/mod \ diff --git a/build.bat b/build.bat index d84c88e..3eaf48f 100644 --- a/build.bat +++ b/build.bat @@ -1,3 +1,3 @@ go mod tidy -docker build --build-arg GIT_COMMIT="DEV" --build-arg BUILD_TIME="DEV" -t git.nix13.pw/scuroneko/kurumibotgo:dev -f ./Dockerfile . -docker push git.nix13.pw/scuroneko/kurumibotgo:dev \ No newline at end of file +docker build --build-arg GIT_COMMIT="DEV" --build-arg BUILD_TIME="DEV" -t git.nix13.pw/scuroneko/yaemikobot:dev -f ./Dockerfile . +docker push git.nix13.pw/scuroneko/yaemikobot:dev \ No newline at end of file diff --git a/go.mod b/go.mod index 66854f9..8b06166 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.26.0 require ( git.nix13.pw/scuroneko/extypes v1.2.1 - git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3 + git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6 git.nix13.pw/scuroneko/slog v1.0.2 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 @@ -15,7 +15,7 @@ require ( go.mongodb.org/mongo-driver/v2 v2.5.0 ) -//replace git.nix13.pw/scuroneko/laniakea v1.0.0-beta.2 => ./laniakea +//replace git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6 => ./laniakea //replace git.nix13.pw/scuroneko/extypes v1.2.1 => ../go-extypes //replace git.nix13.pw/scuroneko/slog v1.0.2 => ../slog diff --git a/go.sum b/go.sum index 99f30a6..b314ed8 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= git.nix13.pw/scuroneko/extypes v1.2.1 h1:IYrOjnWKL2EAuJYtYNa+luB1vBe6paE8VY/YD+5/RpQ= git.nix13.pw/scuroneko/extypes v1.2.1/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw= -git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3 h1:FtEpeJ6Hi8/RGyT3m7Ysf2AIkwVLflc75HMSQzxPvnc= -git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3/go.mod h1:DZgCqOazRzoa+f/GSNuKnTB2wIZ1eJD3cGf34Qya31U= +git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6 h1:4XgYsXgx68/UkXORku2245yLaT+NnKnipanKpUKsuoI= +git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6/go.mod h1:DZgCqOazRzoa+f/GSNuKnTB2wIZ1eJD3cGf34Qya31U= git.nix13.pw/scuroneko/slog v1.0.2 h1:vZyUROygxC2d5FJHUQM/30xFEHY1JT/aweDZXA4rm2g= git.nix13.pw/scuroneko/slog v1.0.2/go.mod h1:3Qm2wzkR5KjwOponMfG7TcGSDjmYaFqRAmLvSPTuWJI= github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw= diff --git a/openai/api.go b/openai/api.go new file mode 100644 index 0000000..4a463bc --- /dev/null +++ b/openai/api.go @@ -0,0 +1,204 @@ +package openai + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "iter" + "net/http" + "net/url" + "os" + "strings" + "time" + + "git.nix13.pw/scuroneko/slog" +) + +type API struct { + token string + model string + baseUrl string + logger *slog.Logger + client *http.Client + stream bool +} + +func NewOpenAIAPI(baseURL, token, model string) *API { + logger := slog.CreateLogger() + level := slog.FATAL + if os.Getenv("DEBUG") == "true" { + level = slog.DEBUG + } + logger = logger.Prefix("AI").Level(level) + // FIXME Leak here + //logger = logger.AddWriter(logger.CreateJsonStdoutWriter()) + proxy, err := url.Parse(os.Getenv("HTTPS_PROXY")) + if err != nil { + logger.Error(err) + logger.Close() + return nil + } + t := &http.Transport{} + if proxy.Host != "" { + t.Proxy = http.ProxyURL(proxy) + } + client := &http.Client{ + Timeout: 5 * time.Minute, + Transport: t, + } + return &API{ + token: token, + model: model, + baseUrl: baseURL, + logger: logger, + client: client, + } +} +func (api *API) Close() error { + return api.logger.Close() +} +func (api *API) SetStream(stream bool) *API { + api.stream = stream + return api +} +func (api *API) GetModel() string { return api.model } +func (api *API) GetBaseURL() string { return api.baseUrl } + +type Request[P any] struct { + params P + method string +} + +func NewRequest[P any](method string, params P) *Request[P] { + return &Request[P]{params, method} +} + +func (r *Request[P]) doWithContext(ctx context.Context, api *API) (io.ReadCloser, error) { + data, err := json.Marshal(r.params) + if err != nil { + return nil, err + } + + u := fmt.Sprintf("%s/v1/%s", api.baseUrl, r.method) + req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + if api.token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", api.token)) + } + + res, err := api.client.Do(req) + if err != nil { + return nil, err + } + if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 { + api.logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status)) + res.Body.Close() + return nil, fmt.Errorf("[%d] %s", res.StatusCode, res.Status) + } + return res.Body, nil +} +func (r *Request[P]) do(api *API) (io.ReadCloser, error) { + ctx := context.Background() + return r.doWithContext(ctx, api) +} + +func (r *Request[P]) DoWithContext(ctx context.Context, api *API) (AIResponse, error) { + var zero AIResponse + body, err := r.doWithContext(ctx, api) + if err != nil { + return zero, err + } + defer body.Close() + data, err := io.ReadAll(body) + if err != nil { + return zero, err + } + err = api.handleAIError(data) + if err != nil { + return zero, err + } + + err = json.Unmarshal(data, &zero) + return zero, err +} +func (r *Request[P]) Do(api *API) (AIResponse, error) { + ctx := context.Background() + return r.DoWithContext(ctx, api) +} +func (r *Request[P]) DoStreamWithContext(ctx context.Context, api *API) (iter.Seq2[AIResponse, error], error) { + body, err := r.doWithContext(ctx, api) + if err != nil { + return nil, err + } + + reader := bufio.NewReader(body) + return func(yield func(AIResponse, error) bool) { + defer body.Close() + var zero AIResponse + for { + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + yield(zero, err) + return + } + if line == "" || line == "\n" { + continue + } + + if strings.HasPrefix(line, "data: ") { + line = line[len("data: "):] + } + line = strings.Trim(strings.Trim(strings.TrimSpace(line), "\r"), "\n") + if strings.HasPrefix(line, "[DONE]") { + return + } + + var resp AIResponse + err = json.Unmarshal([]byte(line), &resp) + if err != nil { + yield(zero, fmt.Errorf("%v\n%s", err, line)) + return + } + if !yield(resp, nil) { + return + } + time.Sleep(time.Millisecond * 100) + } + }, nil +} +func (r *Request[P]) DoStream(api *API) (iter.Seq2[AIResponse, error], error) { + ctx := context.Background() + return r.DoStreamWithContext(ctx, api) +} + +func (api *API) handleAIError(body []byte) error { + var tempData any + err := json.Unmarshal(body, &tempData) + if err != nil { + return err + } + + // {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}} + if eData, ok := tempData.(map[string]any); ok { + if errorData, ok := eData["error"]; ok { + if errorPayload, ok := errorData.(map[string]any); ok { + code, ok := errorPayload["code"] + if !ok { + return errors.New("unknown error code") + } + return errors.New(fmt.Sprintf("%v", code)) + } + return errors.New(string(body)) + } + } else if eData, ok := tempData.(string); ok { + return errors.New(eData) + } + return nil +} diff --git a/openai/completitions.go b/openai/completitions.go new file mode 100644 index 0000000..b6effd1 --- /dev/null +++ b/openai/completitions.go @@ -0,0 +1,49 @@ +package openai + +import ( + "fmt" + "iter" +) + +var MaxRetriesErr = fmt.Errorf("max retries exceeded") +var BadResponseErr = fmt.Errorf("bad_response_status_code") + +type CreateCompletionReq struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Verbosity string `json:"verbosity,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + TopP int `json:"top_p,omitempty"` + MaxCompletionTokens int `json:"max_completition_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +func (api *API) CreateCompletionStream(history []Message, message string, temp float32) (iter.Seq2[AIResponse, error], error) { + params := CreateCompletionReq{ + Model: api.model, + Messages: append(history, Message{ + Role: "user", + Content: message, + }), + Temperature: temp, + Stream: true, + } + req := NewRequest("chat/completions", params) + return req.DoStream(api) +} + +func (api *API) CreateCompletion(history []Message, message string, temp float32) (AIResponse, error) { + params := CreateCompletionReq{ + Model: api.model, + Messages: append(history, Message{ + Role: "user", + Content: message, + }), + Temperature: temp, + Stream: false, + } + req := NewRequest("chat/completions", params) + return req.Do(api) +} diff --git a/openai/sse.go b/openai/sse.go new file mode 100644 index 0000000..60b881f --- /dev/null +++ b/openai/sse.go @@ -0,0 +1,37 @@ +package openai + +import ( + "bufio" + "io" + "iter" + "strings" +) + +// Server-sent event + +func ReadSSE(r io.ReadCloser) iter.Seq[string] { + reader := bufio.NewReader(r) + return func(yield func(string) bool) { + for { + line, err := reader.ReadString('\n') + if err != nil { + return + } + if line == "" || line == "\n" { + continue + } + if strings.HasPrefix(line, "data: ") { + line = line[len("data: "):] + } + line = strings.TrimSpace(line) + line = strings.Trim(line, "\r") + line = strings.Trim(line, "\n") + if strings.HasPrefix(line, "[DONE]") { + return + } + if !yield(line) { + return + } + } + } +} diff --git a/openai/types.go b/openai/types.go new file mode 100644 index 0000000..414ebf5 --- /dev/null +++ b/openai/types.go @@ -0,0 +1,41 @@ +package openai + +type AIResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` + ServiceTier string `json:"service_tier"` +} +type Choice struct { + Index int64 `json:"index"` + Message Message `json:"message"` + Delta Message `json:"delta"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + Refusal interface{} `json:"refusal"` + Annotations []interface{} `json:"annotations"` +} +type Usage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"` + CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"` +} +type CompletionTokensDetails struct { + ReasoningTokens int64 `json:"reasoning_tokens"` + AudioTokens int64 `json:"audio_tokens"` + AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"` +} +type PromptTokensDetails struct { + CachedTokens int64 `json:"cached_tokens"` + AudioTokens int64 `json:"audio_tokens"` +} diff --git a/plugins/admin.go b/plugins/admin.go index 1ca20a3..e2c1af1 100644 --- a/plugins/admin.go +++ b/plugins/admin.go @@ -2,8 +2,10 @@ package plugins import ( "encoding/json" + "fmt" "path/filepath" "strings" + "time" "ymgb/database" "ymgb/database/psql" @@ -18,6 +20,7 @@ func RegisterAdmin() *laniakea.Plugin[database.Context] { p.AddCommand(laniakea.NewCommand(uploadPhoto, "uploadPhoto").SkipCommandAutoGen()) p.AddCommand(laniakea.NewCommand(emojiId, "emojiId").SkipCommandAutoGen()) p.AddCommand(laniakea.NewCommand(execSql, "sql").SkipCommandAutoGen()) + p.AddCommand(laniakea.NewCommand(test, "test").SkipCommandAutoGen()) p.AddMiddleware(AdminMiddleware()) return p @@ -33,6 +36,17 @@ func AdminMiddleware() laniakea.Middleware[database.Context] { }) return *m } +func test(ctx *laniakea.MsgContext, _ *database.Context) { + draft := ctx.NewDraft() + for i := 0; i < 10; i++ { + err := draft.Push(fmt.Sprintf("%d", i)) + if err != nil { + ctx.Error(err) + return + } + time.Sleep(1 * time.Second) + } +} func execSql(ctx *laniakea.MsgContext, db *database.Context) { stmt := strings.Join(ctx.Args, " ") diff --git a/plugins/ai.go b/plugins/ai.go index 14e807c..5e486c5 100644 --- a/plugins/ai.go +++ b/plugins/ai.go @@ -1,10 +1,10 @@ package plugins import ( + "io" "strings" "ymgb/database" - "ymgb/database/mdb" - "ymgb/database/red" + "ymgb/openai" "ymgb/utils/ai" "git.nix13.pw/scuroneko/laniakea" @@ -12,50 +12,96 @@ import ( func RegisterAi() *laniakea.Plugin[database.Context] { p := laniakea.NewPlugin[database.Context]("AI") - p.AddCommand(p.NewCommand(gpt, "gpt").SkipCommandAutoGen()) + p.AddCommand(p.NewCommand(gptTest, "gpt").SkipCommandAutoGen()) + //p.AddCommand(p.NewCommand(gptTest, "gpt2").SkipCommandAutoGen()) return p } -func gpt(ctx *laniakea.MsgContext, db *database.Context) { +func gptTest(ctx *laniakea.MsgContext, _ *database.Context) { q := strings.Join(ctx.Args, " ") - api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4") - defer api.Close() - - aiRedisRep := red.NewAiRepository(db) - chatId, err := aiRedisRep.GetOrCreateChatId(ctx.FromID) + m := ctx.Answer("Генерация запущена") + api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4") + resp, err := api.CreateCompletionStream([]openai.Message{}, q, 1.0) if err != nil { + m.Delete() ctx.Error(err) return } - history, err := mdb.GetGptChatHistory(db, chatId) - if err != nil { - ctx.Error(err) - return - } - aiHistory := make([]ai.Message, len(history)) - for _, m := range history { - aiHistory = append(aiHistory, ai.Message{ - Role: m.Role, - Content: m.Message, - }) - } - m := ctx.Answer("Генерация запущена...") - res, err := api.CreateCompletion(aiHistory, q, 1.0) - if err != nil { - ctx.Error(err) - return - } - answer := res.Choices[0].Message.Content - m.Delete() - err = mdb.UpdateGptChatHistory(db, chatId, "user", q) - if err != nil { - ctx.Error(err) - } - err = mdb.UpdateGptChatHistory(db, chatId, "assistant", answer) - if err != nil { - ctx.Error(err) - } + draft := ctx.NewDraft() + for r, err := range resp { + if err == io.EOF { + break + } + if m != nil { + m.Delete() + m = nil + } + if err != nil { + ctx.Error(err) + return + } - ctx.Answer(answer) + if len(r.Choices) == 0 { + continue + } + content := r.Choices[0].Delta.Content + if content == "" { + continue + } + err = draft.Push(content) + if err != nil { + ctx.Error(err) + //draft.Flush() + break + } + } + err = draft.Flush() + if err != nil { + ctx.Error(err) + } } + +//func gpt(ctx *laniakea.MsgContext, db *database.Context) { +// q := strings.Join(ctx.Args, " ") +// api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4") +// defer api.Close() +// +// aiRedisRep := red.NewAiRepository(db) +// chatId, err := aiRedisRep.GetOrCreateChatId(ctx.FromID) +// if err != nil { +// ctx.Error(err) +// return +// } +// history, err := mdb.GetGptChatHistory(db, chatId) +// if err != nil { +// ctx.Error(err) +// return +// } +// aiHistory := make([]openai.Message, len(history)) +// for _, m := range history { +// aiHistory = append(aiHistory, openai.Message{ +// Role: m.Role, +// Content: m.Message, +// }) +// } +// +// m := ctx.Answer("Генерация запущена...") +// res, err := api.CreateCompletion(aiHistory, q, 1.0) +// if err != nil { +// ctx.Error(err) +// return +// } +// answer := res.Choices[0].Message.Content +// m.Delete() +// err = mdb.UpdateGptChatHistory(db, chatId, "user", q) +// if err != nil { +// ctx.Error(err) +// } +// err = mdb.UpdateGptChatHistory(db, chatId, "assistant", answer) +// if err != nil { +// ctx.Error(err) +// } +// +// ctx.Answer(answer) +//} diff --git a/plugins/ban.go b/plugins/ban.go index 5ff56d6..6089a79 100644 --- a/plugins/ban.go +++ b/plugins/ban.go @@ -16,7 +16,7 @@ func RegisterProxy() *laniakea.Plugin[database.Context] { func getH2Link(ctx *laniakea.MsgContext, db *database.Context) { api := utils.NewHysteria2API() - url, err := api.GetConnectLink(1, "K1321xt90RUS") + url, err := api.GetConnectLink(1, "") if err != nil { ctx.Error(err) return diff --git a/plugins/rp.go b/plugins/rp.go index 2032288..6a9901e 100644 --- a/plugins/rp.go +++ b/plugins/rp.go @@ -11,6 +11,7 @@ import ( "ymgb/database/mdb" "ymgb/database/psql" "ymgb/database/red" + "ymgb/openai" "ymgb/utils" "ymgb/utils/ai" @@ -557,12 +558,12 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *database.Context) { ctx.Answer("Описание пользователя было обновлено") } -func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Message, error) { +func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]openai.Message, error) { redRep := red.NewRPRepository(db) psqlRep := psql.NewRPRepository(db) waifuRep := psql.NewWaifuRepository(db) - messages := make([]ai.Message, 0) + messages := make([]openai.Message, 0) waifuId := redRep.GetSelectedWaifu(ctx.FromID) chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId) if err != nil { @@ -590,7 +591,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa //if err != nil { // return messages, err //} - beforeHistory := ai.Message{ + beforeHistory := openai.Message{ Role: "system", Content: fmt.Sprintf( "%s %s %s %s", @@ -603,7 +604,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa userPrompt, ), } - afterHistory := ai.Message{ + afterHistory := openai.Message{ Role: "system", Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName), } @@ -615,7 +616,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa messages = append(messages, beforeHistory) for _, m := range history { - messages = append(messages, ai.Message{Role: m.Role, Content: m.Message}) + messages = append(messages, openai.Message{Role: m.Role, Content: m.Message}) } messages = append(messages, afterHistory) return messages, nil @@ -653,16 +654,41 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) { kb.AddCallbackButtonStyle("Отменить", laniakea.ButtonStyleDanger, "rp.cancel") m := ctx.Keyboard("Генерация запущена...", kb) ctx.SendAction(tgapi.ChatActionTyping) - api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key) + api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key) defer api.Close() - res, err := api.CreateCompletion(messages, userMessage, 0.5) + res, err := api.CreateCompletionStream(messages, userMessage, 0.5) if err != nil { ctx.Error(err) return } - if len(res.Choices) == 0 { - m.Edit("Не удалось сгенерировать ответ. Попробуйте снова позже") - return + + answerContent := "" + draft := ctx.NewDraft() + for r, err := range res { + if m != nil { + m.Delete() + m = nil + } + + if err != nil { + ctx.Error(err) + return + } + + if len(r.Choices) == 0 { + continue + } + content := r.Choices[0].Delta.Content + if content == "" { + continue + } + answerContent += content + err = draft.Push(content) + if err != nil { + ctx.Error(err) + //draft.Flush() + break + } } counter := redisRpRep.GetCounter(ctx.FromID, waifuId) @@ -671,9 +697,7 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) { ctx.Error(err) return } - agentAnswer := res.Choices[0].Message - answerContent := strings.TrimSpace(agentAnswer.Content) - err = mdb.UpdateRPChatHistory(db, chatId, agentAnswer.Role, answerContent, counter+2) + err = mdb.UpdateRPChatHistory(db, chatId, "assistant", answerContent, counter+2) if err != nil { ctx.Error(err) } @@ -695,7 +719,6 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) { ctx.Error(err) } - m.Delete() kb = laniakea.NewInlineKeyboard(1) kb.AddCallbackButtonStyle("🔄 Перегенерировать 🔄", laniakea.ButtonStyleSuccess, "rp.regenerate", counter+2) //kb.AddButton(laniakea.NewInlineKbButton("Тест").SetStyle(laniakea.ButtonStyleSuccess).SetIconCustomEmojiId("5375155835846534814").SetCallbackData("rp.test")) @@ -729,7 +752,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) { ctx.Error(err) return } - var messages extypes.Slice[ai.Message] + var messages extypes.Slice[openai.Message] messages, err = _getChatHistory(ctx, db) if err != nil { ctx.Error(err) @@ -748,7 +771,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) { } //if messages.Len() == count { - // ctx.Bot.Logger().Errorln("len(messages) == count") + // ctx.Bot.logger().Errorln("len(messages) == count") // return //} @@ -766,7 +789,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) { return } - api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key) + api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key) defer api.Close() messages = messages.Pop(count - 2) @@ -818,9 +841,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) { return } - messages := make([]ai.Message, 0) + messages := make([]openai.Message, 0) for _, m := range history { - messages = append(messages, ai.Message{ + messages = append(messages, openai.Message{ Role: m.Role, Content: m.Message, }) @@ -833,9 +856,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) { return } - api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key) + api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key) defer api.Close() - res, err := api.CompressChat(messages) + res, err := ai.CompressChat(api, messages) if err != nil { ctx.Error(err) return diff --git a/utils/ai/openai.go b/utils/ai/openai.go index d6c3fff..7c85d64 100644 --- a/utils/ai/openai.go +++ b/utils/ai/openai.go @@ -1,232 +1,11 @@ package ai import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/url" - "os" - "time" - - "git.nix13.pw/scuroneko/slog" + "ymgb/openai" ) -type OpenAIResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage"` - ServiceTier string `json:"service_tier"` -} -type Choice struct { - Index int64 `json:"index"` - Message Message `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` -} -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - Refusal interface{} `json:"refusal"` - Annotations []interface{} `json:"annotations"` -} -type Usage struct { - PromptTokens int64 `json:"prompt_tokens"` - CompletionTokens int64 `json:"completion_tokens"` - TotalTokens int64 `json:"total_tokens"` - PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"` - CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"` -} -type CompletionTokensDetails struct { - ReasoningTokens int64 `json:"reasoning_tokens"` - AudioTokens int64 `json:"audio_tokens"` - AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"` - RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"` -} -type PromptTokensDetails struct { - CachedTokens int64 `json:"cached_tokens"` - AudioTokens int64 `json:"audio_tokens"` -} -type OpenAIAPI struct { - Token string - Model string - BaseURL string - Logger *slog.Logger - client *http.Client -} - -func NewOpenAIAPI(baseURL, token, model string) *OpenAIAPI { - logger := slog.CreateLogger() - level := slog.FATAL - if os.Getenv("DEBUG") == "true" { - level = slog.DEBUG - } - logger = logger.Prefix("AI").Level(level) - // FIXME Leak here - //logger = logger.AddWriter(logger.CreateJsonStdoutWriter()) - proxy, err := url.Parse(os.Getenv("HTTPS_PROXY")) - if err != nil { - logger.Error(err) - logger.Close() - } - t := &http.Transport{} - if proxy.Host != "" { - t.Proxy = http.ProxyURL(proxy) - } - client := &http.Client{ - Timeout: 5 * time.Minute, - Transport: t, - } - return &OpenAIAPI{ - Token: token, - Model: model, - BaseURL: baseURL, - Logger: logger, - client: client, - } -} -func (o *OpenAIAPI) Close() error { - return o.Logger.Close() -} - -type CreateCompletionReq struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Verbosity string `json:"verbosity,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - TopP int `json:"top_p,omitempty"` - MaxCompletionTokens int `json:"max_completition_tokens,omitempty"` -} - -var MaxRetriesErr = fmt.Errorf("max retries exceeded") -var BadResponseErr = fmt.Errorf("bad_response_status_code") - -func (o *OpenAIAPI) DoRequest(url string, params any, retries int) ([]byte, error) { - responseBody := make([]byte, 0) - data, err := json.Marshal(params) - if err != nil { - log.Println("json marshal failed:", err) - return responseBody, err - } - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(data)) - if err != nil { - log.Println("create request failed:", err) - return responseBody, err - } - req.Header.Set("Content-Type", "application/json") - if o.Token != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token)) - } - - res, err := o.client.Do(req) - if err != nil { - log.Println("do request failed:", err) - return nil, err - } - defer res.Body.Close() - if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 { - o.Logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status)) - if retries >= 3 { - return responseBody, MaxRetriesErr - } - time.Sleep(1 * time.Second) - return o.DoRequest(url, params, retries+1) - } - responseBody, err = io.ReadAll(res.Body) - if err != nil { - log.Println("read response failed:", err) - return responseBody, err - } - - var tempData any - err = json.Unmarshal(responseBody, &tempData) - if err != nil { - log.Println("json unmarshal failed:", err) - return responseBody, err - } - // {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}} - if eData, ok := tempData.(map[string]any); ok { - if errorData, ok := eData["error"]; ok { - if errorPayload, ok := errorData.(map[string]any); ok { - code := errorPayload["code"].(string) - if code == "bad_response_status_code" { - if retries >= 3 { - return responseBody, BadResponseErr - } - o.Logger.Warnln("Retrying because of bad response status code") - return o.DoRequest(url, params, retries+1) - } - return nil, errors.New(code) - } - o.Logger.Errorln("Unknown error", errorData) - return nil, errors.New(string(responseBody)) - } - } else if eData, ok := tempData.(string); ok { - return responseBody, errors.New(eData) - } - - return responseBody, err -} - -func (o *OpenAIAPI) CreateCompletion(history []Message, message string, temp float32) (*OpenAIResponse, error) { - u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL) - request := CreateCompletionReq{ - Model: o.Model, - Messages: append(history, Message{ - Role: "user", - Content: message, - }), - Temperature: temp, - } - data, err := json.Marshal(request) - if err != nil { - return nil, err - } - - o.Logger.Debug("REQ", u, string(data)) - body, err := o.DoRequest(u, request, 0) - if err != nil { - return nil, err - } - o.Logger.Debug("RES", u, string(body)) - - response := new(OpenAIResponse) - err = json.Unmarshal(body, response) - return response, err -} - -func (o *OpenAIAPI) CompressChat(history []Message) (*OpenAIResponse, error) { - request := CreateCompletionReq{ - Model: o.Model, - Messages: append(history, Message{ - Role: "user", - Content: CompressPrompt, - }), - Temperature: 1.0, - } - u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL) - data, err := json.Marshal(request) - if err != nil { - return nil, err - } - - o.Logger.Debug("COMPRESS REQ", u, string(data)) - body, err := o.DoRequest(u, request, 0) - if err != nil { - return nil, err - } - o.Logger.Debug("COMPRESS RES", u, string(body)) - - response := new(OpenAIResponse) - err = json.Unmarshal(body, response) - return response, err +//https://github.com/sashabaranov/go-openai + +func CompressChat(api *openai.API, history []openai.Message) (openai.AIResponse, error) { + return api.CreateCompletion(history, CompressPrompt, 0.0) }