Extract OpenAI client into dedicated openai package with SSE streaming completions; switch RP/GPT commands to streamed draft output

This commit is contained in:
2026-03-02 00:58:43 +03:00
parent b394c0be68
commit 3e0d3db47e
13 changed files with 486 additions and 292 deletions

View File

@@ -8,6 +8,7 @@ RUN --mount=type=cache,target=/go/pkg/mod go mod download -x
COPY ./database ./database
COPY ./plugins ./plugins
COPY ./utils ./utils
COPY ./openai ./openai
COPY ./main.go ./
RUN --mount=type=cache,target=/root/.cache/go-build \
--mount=type=cache,target=/go/pkg/mod \

View File

@@ -1,3 +1,3 @@
go mod tidy
docker build --build-arg GIT_COMMIT="DEV" --build-arg BUILD_TIME="DEV" -t git.nix13.pw/scuroneko/kurumibotgo:dev -f ./Dockerfile .
docker push git.nix13.pw/scuroneko/kurumibotgo:dev
docker build --build-arg GIT_COMMIT="DEV" --build-arg BUILD_TIME="DEV" -t git.nix13.pw/scuroneko/yaemikobot:dev -f ./Dockerfile .
docker push git.nix13.pw/scuroneko/yaemikobot:dev

4
go.mod
View File

@@ -4,7 +4,7 @@ go 1.26.0
require (
git.nix13.pw/scuroneko/extypes v1.2.1
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6
git.nix13.pw/scuroneko/slog v1.0.2
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
@@ -15,7 +15,7 @@ require (
go.mongodb.org/mongo-driver/v2 v2.5.0
)
//replace git.nix13.pw/scuroneko/laniakea v1.0.0-beta.2 => ./laniakea
//replace git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6 => ./laniakea
//replace git.nix13.pw/scuroneko/extypes v1.2.1 => ../go-extypes
//replace git.nix13.pw/scuroneko/slog v1.0.2 => ../slog

4
go.sum
View File

@@ -2,8 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
git.nix13.pw/scuroneko/extypes v1.2.1 h1:IYrOjnWKL2EAuJYtYNa+luB1vBe6paE8VY/YD+5/RpQ=
git.nix13.pw/scuroneko/extypes v1.2.1/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw=
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3 h1:FtEpeJ6Hi8/RGyT3m7Ysf2AIkwVLflc75HMSQzxPvnc=
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3/go.mod h1:DZgCqOazRzoa+f/GSNuKnTB2wIZ1eJD3cGf34Qya31U=
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6 h1:4XgYsXgx68/UkXORku2245yLaT+NnKnipanKpUKsuoI=
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6/go.mod h1:DZgCqOazRzoa+f/GSNuKnTB2wIZ1eJD3cGf34Qya31U=
git.nix13.pw/scuroneko/slog v1.0.2 h1:vZyUROygxC2d5FJHUQM/30xFEHY1JT/aweDZXA4rm2g=
git.nix13.pw/scuroneko/slog v1.0.2/go.mod h1:3Qm2wzkR5KjwOponMfG7TcGSDjmYaFqRAmLvSPTuWJI=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=

204
openai/api.go Normal file
View File

@@ -0,0 +1,204 @@
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"iter"
"net/http"
"net/url"
"os"
"strings"
"time"
"git.nix13.pw/scuroneko/slog"
)
// API is a minimal client for an OpenAI-compatible chat completions server.
type API struct {
	token   string       // bearer token; empty means no Authorization header is sent
	model   string       // model identifier sent with completion requests
	baseUrl string       // server base URL; "/v1/<method>" is appended per request
	logger  *slog.Logger // package-local logger, released by Close
	client  *http.Client // shared HTTP client with optional proxy and 5-minute timeout
	stream  bool         // NOTE(review): set via SetStream but never read in this file — confirm it is used elsewhere
}
// NewOpenAIAPI builds an API client for the given base URL, bearer token and
// model. An empty token disables the Authorization header.
//
// The HTTP client honors the HTTPS_PROXY environment variable and uses a
// 5-minute request timeout. Logging is at FATAL level unless DEBUG=true.
//
// NOTE(review): returns nil when HTTPS_PROXY fails to parse; callers that
// immediately `defer api.Close()` would panic on nil — confirm every call
// site checks the result.
func NewOpenAIAPI(baseURL, token, model string) *API {
	logger := slog.CreateLogger()
	level := slog.FATAL
	if os.Getenv("DEBUG") == "true" {
		level = slog.DEBUG
	}
	logger = logger.Prefix("AI").Level(level)
	// FIXME Leak here
	//logger = logger.AddWriter(logger.CreateJsonStdoutWriter())
	proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
	if err != nil {
		logger.Error(err)
		logger.Close()
		return nil
	}
	t := &http.Transport{}
	// An unset HTTPS_PROXY parses to an empty URL; only install the proxy
	// function when a host was actually configured.
	if proxy.Host != "" {
		t.Proxy = http.ProxyURL(proxy)
	}
	client := &http.Client{
		Timeout:   5 * time.Minute,
		Transport: t,
	}
	return &API{
		token:   token,
		model:   model,
		baseUrl: baseURL,
		logger:  logger,
		client:  client,
	}
}
// Close releases the resources held by the client's logger.
func (api *API) Close() error {
	return api.logger.Close()
}

// SetStream toggles the stream flag and returns the receiver for chaining.
// NOTE(review): the flag is not read anywhere in this file — confirm usage.
func (api *API) SetStream(stream bool) *API {
	api.stream = stream
	return api
}

// GetModel returns the model identifier this client sends with requests.
func (api *API) GetModel() string { return api.model }

// GetBaseURL returns the configured API base URL.
func (api *API) GetBaseURL() string { return api.baseUrl }
// Request is a pending API call: an endpoint method name (relative to /v1/)
// paired with a JSON-serializable parameter payload of type P.
type Request[P any] struct {
	params P      // request body, marshalled to JSON
	method string // endpoint path, e.g. "chat/completions"
}

// NewRequest pairs an endpoint method name with its parameters.
func NewRequest[P any](method string, params P) *Request[P] {
	return &Request[P]{params, method}
}
// doWithContext marshals the request parameters, POSTs them to the API and
// returns the raw response body; the caller owns the returned ReadCloser.
// Responses with status 400, 502 or 504 are logged and turned into errors.
func (r *Request[P]) doWithContext(ctx context.Context, api *API) (io.ReadCloser, error) {
	payload, err := json.Marshal(r.params)
	if err != nil {
		return nil, err
	}
	endpoint := fmt.Sprintf("%s/v1/%s", api.baseUrl, r.method)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(payload))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if api.token != "" {
		httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", api.token))
	}
	resp, err := api.client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	switch resp.StatusCode {
	case 504, 400, 502:
		api.logger.Warn(fmt.Sprintf("[%d] %s", resp.StatusCode, resp.Status))
		resp.Body.Close()
		return nil, fmt.Errorf("[%d] %s", resp.StatusCode, resp.Status)
	}
	return resp.Body, nil
}
// do performs the request with a background context.
func (r *Request[P]) do(api *API) (io.ReadCloser, error) {
	return r.doWithContext(context.Background(), api)
}
// DoWithContext executes the request, surfaces any API-level error embedded
// in the response body, and decodes the body into an AIResponse.
func (r *Request[P]) DoWithContext(ctx context.Context, api *API) (AIResponse, error) {
	var out AIResponse
	body, err := r.doWithContext(ctx, api)
	if err != nil {
		return out, err
	}
	defer body.Close()
	raw, err := io.ReadAll(body)
	if err != nil {
		return out, err
	}
	if err := api.handleAIError(raw); err != nil {
		return out, err
	}
	err = json.Unmarshal(raw, &out)
	return out, err
}
// Do executes the request with a background context.
func (r *Request[P]) Do(api *API) (AIResponse, error) {
	return r.DoWithContext(context.Background(), api)
}
// DoStreamWithContext executes the request and returns an iterator over the
// server-sent-event chunks of the response, each decoded as an AIResponse.
// The response body is closed when iteration finishes. Iteration stops on a
// "[DONE]" sentinel, a read/decode error, or end of stream.
//
// Bug fix: the previous version spun forever when the stream ended without a
// "[DONE]" sentinel — ReadString kept returning ("", io.EOF), which was
// skipped as an empty line in a tight loop. We now terminate at EOF, after
// processing any final partial line.
func (r *Request[P]) DoStreamWithContext(ctx context.Context, api *API) (iter.Seq2[AIResponse, error], error) {
	body, err := r.doWithContext(ctx, api)
	if err != nil {
		return nil, err
	}
	reader := bufio.NewReader(body)
	return func(yield func(AIResponse, error) bool) {
		defer body.Close()
		var zero AIResponse
		for {
			line, readErr := reader.ReadString('\n')
			if readErr != nil && readErr != io.EOF {
				yield(zero, readErr)
				return
			}
			atEOF := readErr == io.EOF
			// Strip the SSE "data: " prefix and surrounding whitespace
			// (TrimSpace already covers \r and \n).
			line = strings.TrimSpace(strings.TrimPrefix(line, "data: "))
			if line == "" {
				if atEOF {
					return
				}
				continue
			}
			if strings.HasPrefix(line, "[DONE]") {
				return
			}
			var resp AIResponse
			if err := json.Unmarshal([]byte(line), &resp); err != nil {
				yield(zero, fmt.Errorf("%v\n%s", err, line))
				return
			}
			if !yield(resp, nil) {
				return
			}
			if atEOF {
				return
			}
			// Throttle chunk delivery so downstream consumers (message
			// drafts) are not flooded with edits.
			time.Sleep(time.Millisecond * 100)
		}
	}, nil
}
// DoStream executes the streaming request with a background context.
func (r *Request[P]) DoStream(api *API) (iter.Seq2[AIResponse, error], error) {
	return r.DoStreamWithContext(context.Background(), api)
}
// handleAIError inspects a raw response body for an API-level error payload
// of the form {"error":{"code":...}} (or a bare JSON string) and converts it
// to a Go error. A body that is not valid JSON yields the unmarshal error;
// a normal response object yields nil.
//
// Fix: replaced errors.New(fmt.Sprintf(...)) with fmt.Errorf (go vet S1028)
// and flattened the nested type assertions into a type switch.
func (api *API) handleAIError(body []byte) error {
	var parsed any
	if err := json.Unmarshal(body, &parsed); err != nil {
		return err
	}
	// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
	switch data := parsed.(type) {
	case map[string]any:
		errorData, ok := data["error"]
		if !ok {
			return nil // no "error" key: a normal response
		}
		payload, ok := errorData.(map[string]any)
		if !ok {
			return errors.New(string(body))
		}
		code, ok := payload["code"]
		if !ok {
			return errors.New("unknown error code")
		}
		return fmt.Errorf("%v", code)
	case string:
		return errors.New(data)
	}
	return nil
}

49
openai/completitions.go Normal file
View File

@@ -0,0 +1,49 @@
package openai
import (
"fmt"
"iter"
)
// MaxRetriesErr reports that a request was retried too many times.
// NOTE(review): Go convention would name these ErrMaxRetries/ErrBadResponse;
// the exported names are kept for backward compatibility.
var MaxRetriesErr = fmt.Errorf("max retries exceeded")

// BadResponseErr mirrors the upstream "bad_response_status_code" error code.
var BadResponseErr = fmt.Errorf("bad_response_status_code")

// CreateCompletionReq is the JSON body of a chat/completions request.
type CreateCompletionReq struct {
	Model            string    `json:"model"`
	Messages         []Message `json:"messages"`
	Verbosity        string    `json:"verbosity,omitempty"`
	Temperature      float32   `json:"temperature,omitempty"`
	PresencePenalty  int       `json:"presence_penalty,omitempty"`
	FrequencyPenalty int       `json:"frequency_penalty,omitempty"`
	TopP             int       `json:"top_p,omitempty"`
	// Fixed JSON tag: was misspelled "max_completition_tokens", which an
	// OpenAI-compatible server would silently ignore.
	MaxCompletionTokens int  `json:"max_completion_tokens,omitempty"`
	Stream              bool `json:"stream,omitempty"`
}
// CreateCompletionStream sends the chat history plus a new user message to
// chat/completions with stream=true and returns an iterator over the
// streamed response chunks.
//
// Bug fix: the previous version appended the user message directly onto the
// caller's history slice; when that slice had spare capacity the append
// overwrote the caller's backing array. A fresh slice is built instead.
func (api *API) CreateCompletionStream(history []Message, message string, temp float32) (iter.Seq2[AIResponse, error], error) {
	msgs := make([]Message, 0, len(history)+1)
	msgs = append(msgs, history...)
	msgs = append(msgs, Message{Role: "user", Content: message})
	params := CreateCompletionReq{
		Model:       api.model,
		Messages:    msgs,
		Temperature: temp,
		Stream:      true,
	}
	return NewRequest("chat/completions", params).DoStream(api)
}
// CreateCompletion sends the chat history plus a new user message to
// chat/completions and returns the complete (non-streamed) response.
//
// Bug fix: the previous version appended the user message directly onto the
// caller's history slice; when that slice had spare capacity the append
// overwrote the caller's backing array. A fresh slice is built instead.
func (api *API) CreateCompletion(history []Message, message string, temp float32) (AIResponse, error) {
	msgs := make([]Message, 0, len(history)+1)
	msgs = append(msgs, history...)
	msgs = append(msgs, Message{Role: "user", Content: message})
	params := CreateCompletionReq{
		Model:       api.model,
		Messages:    msgs,
		Temperature: temp,
		Stream:      false,
	}
	return NewRequest("chat/completions", params).Do(api)
}

37
openai/sse.go Normal file
View File

@@ -0,0 +1,37 @@
package openai
import (
"bufio"
"io"
"iter"
"strings"
)
// Server-sent event
func ReadSSE(r io.ReadCloser) iter.Seq[string] {
reader := bufio.NewReader(r)
return func(yield func(string) bool) {
for {
line, err := reader.ReadString('\n')
if err != nil {
return
}
if line == "" || line == "\n" {
continue
}
if strings.HasPrefix(line, "data: ") {
line = line[len("data: "):]
}
line = strings.TrimSpace(line)
line = strings.Trim(line, "\r")
line = strings.Trim(line, "\n")
if strings.HasPrefix(line, "[DONE]") {
return
}
if !yield(line) {
return
}
}
}
}

41
openai/types.go Normal file
View File

@@ -0,0 +1,41 @@
package openai
// AIResponse is the top-level chat/completions response object. Full
// responses populate Choice.Message; streamed chunks populate Choice.Delta.
type AIResponse struct {
	ID          string   `json:"id"`
	Object      string   `json:"object"`
	Created     int64    `json:"created"`
	Model       string   `json:"model"`
	Choices     []Choice `json:"choices"`
	Usage       Usage    `json:"usage"`
	ServiceTier string   `json:"service_tier"`
}

// Choice is one candidate completion within an AIResponse.
type Choice struct {
	Index        int64       `json:"index"`
	Message      Message     `json:"message"` // full (non-streamed) content
	Delta        Message     `json:"delta"`   // incremental content for streamed chunks
	Logprobs     interface{} `json:"logprobs"`
	FinishReason string      `json:"finish_reason"`
}

// Message is a single chat message with its role ("system", "user",
// "assistant").
type Message struct {
	Role        string        `json:"role"`
	Content     string        `json:"content"`
	Refusal     interface{}   `json:"refusal"`
	Annotations []interface{} `json:"annotations"`
}

// Usage reports token accounting for a completion.
type Usage struct {
	PromptTokens            int64                   `json:"prompt_tokens"`
	CompletionTokens        int64                   `json:"completion_tokens"`
	TotalTokens             int64                   `json:"total_tokens"`
	PromptTokensDetails     PromptTokensDetails     `json:"prompt_tokens_details"`
	CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
}

// CompletionTokensDetails breaks down the completion-side token usage.
type CompletionTokensDetails struct {
	ReasoningTokens          int64 `json:"reasoning_tokens"`
	AudioTokens              int64 `json:"audio_tokens"`
	AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
	RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
}

// PromptTokensDetails breaks down the prompt-side token usage.
type PromptTokensDetails struct {
	CachedTokens int64 `json:"cached_tokens"`
	AudioTokens  int64 `json:"audio_tokens"`
}

View File

@@ -2,8 +2,10 @@ package plugins
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
"time"
"ymgb/database"
"ymgb/database/psql"
@@ -18,6 +20,7 @@ func RegisterAdmin() *laniakea.Plugin[database.Context] {
p.AddCommand(laniakea.NewCommand(uploadPhoto, "uploadPhoto").SkipCommandAutoGen())
p.AddCommand(laniakea.NewCommand(emojiId, "emojiId").SkipCommandAutoGen())
p.AddCommand(laniakea.NewCommand(execSql, "sql").SkipCommandAutoGen())
p.AddCommand(laniakea.NewCommand(test, "test").SkipCommandAutoGen())
p.AddMiddleware(AdminMiddleware())
return p
@@ -33,6 +36,17 @@ func AdminMiddleware() laniakea.Middleware[database.Context] {
})
return *m
}
func test(ctx *laniakea.MsgContext, _ *database.Context) {
draft := ctx.NewDraft()
for i := 0; i < 10; i++ {
err := draft.Push(fmt.Sprintf("%d", i))
if err != nil {
ctx.Error(err)
return
}
time.Sleep(1 * time.Second)
}
}
func execSql(ctx *laniakea.MsgContext, db *database.Context) {
stmt := strings.Join(ctx.Args, " ")

View File

@@ -1,10 +1,10 @@
package plugins
import (
"io"
"strings"
"ymgb/database"
"ymgb/database/mdb"
"ymgb/database/red"
"ymgb/openai"
"ymgb/utils/ai"
"git.nix13.pw/scuroneko/laniakea"
@@ -12,50 +12,96 @@ import (
func RegisterAi() *laniakea.Plugin[database.Context] {
p := laniakea.NewPlugin[database.Context]("AI")
p.AddCommand(p.NewCommand(gpt, "gpt").SkipCommandAutoGen())
p.AddCommand(p.NewCommand(gptTest, "gpt").SkipCommandAutoGen())
//p.AddCommand(p.NewCommand(gptTest, "gpt2").SkipCommandAutoGen())
return p
}
func gpt(ctx *laniakea.MsgContext, db *database.Context) {
func gptTest(ctx *laniakea.MsgContext, _ *database.Context) {
q := strings.Join(ctx.Args, " ")
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4")
defer api.Close()
aiRedisRep := red.NewAiRepository(db)
chatId, err := aiRedisRep.GetOrCreateChatId(ctx.FromID)
m := ctx.Answer("Генерация запущена")
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4")
resp, err := api.CreateCompletionStream([]openai.Message{}, q, 1.0)
if err != nil {
m.Delete()
ctx.Error(err)
return
}
history, err := mdb.GetGptChatHistory(db, chatId)
if err != nil {
ctx.Error(err)
return
}
aiHistory := make([]ai.Message, len(history))
for _, m := range history {
aiHistory = append(aiHistory, ai.Message{
Role: m.Role,
Content: m.Message,
})
}
m := ctx.Answer("Генерация запущена...")
res, err := api.CreateCompletion(aiHistory, q, 1.0)
if err != nil {
ctx.Error(err)
return
}
answer := res.Choices[0].Message.Content
m.Delete()
err = mdb.UpdateGptChatHistory(db, chatId, "user", q)
if err != nil {
ctx.Error(err)
}
err = mdb.UpdateGptChatHistory(db, chatId, "assistant", answer)
if err != nil {
ctx.Error(err)
}
draft := ctx.NewDraft()
for r, err := range resp {
if err == io.EOF {
break
}
if m != nil {
m.Delete()
m = nil
}
if err != nil {
ctx.Error(err)
return
}
ctx.Answer(answer)
if len(r.Choices) == 0 {
continue
}
content := r.Choices[0].Delta.Content
if content == "" {
continue
}
err = draft.Push(content)
if err != nil {
ctx.Error(err)
//draft.Flush()
break
}
}
err = draft.Flush()
if err != nil {
ctx.Error(err)
}
}
//func gpt(ctx *laniakea.MsgContext, db *database.Context) {
// q := strings.Join(ctx.Args, " ")
// api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4")
// defer api.Close()
//
// aiRedisRep := red.NewAiRepository(db)
// chatId, err := aiRedisRep.GetOrCreateChatId(ctx.FromID)
// if err != nil {
// ctx.Error(err)
// return
// }
// history, err := mdb.GetGptChatHistory(db, chatId)
// if err != nil {
// ctx.Error(err)
// return
// }
// aiHistory := make([]openai.Message, len(history))
// for _, m := range history {
// aiHistory = append(aiHistory, openai.Message{
// Role: m.Role,
// Content: m.Message,
// })
// }
//
// m := ctx.Answer("Генерация запущена...")
// res, err := api.CreateCompletion(aiHistory, q, 1.0)
// if err != nil {
// ctx.Error(err)
// return
// }
// answer := res.Choices[0].Message.Content
// m.Delete()
// err = mdb.UpdateGptChatHistory(db, chatId, "user", q)
// if err != nil {
// ctx.Error(err)
// }
// err = mdb.UpdateGptChatHistory(db, chatId, "assistant", answer)
// if err != nil {
// ctx.Error(err)
// }
//
// ctx.Answer(answer)
//}

View File

@@ -16,7 +16,7 @@ func RegisterProxy() *laniakea.Plugin[database.Context] {
func getH2Link(ctx *laniakea.MsgContext, db *database.Context) {
api := utils.NewHysteria2API()
url, err := api.GetConnectLink(1, "K1321xt90RUS")
url, err := api.GetConnectLink(1, "")
if err != nil {
ctx.Error(err)
return

View File

@@ -11,6 +11,7 @@ import (
"ymgb/database/mdb"
"ymgb/database/psql"
"ymgb/database/red"
"ymgb/openai"
"ymgb/utils"
"ymgb/utils/ai"
@@ -557,12 +558,12 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Answer("Описание пользователя было обновлено")
}
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Message, error) {
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]openai.Message, error) {
redRep := red.NewRPRepository(db)
psqlRep := psql.NewRPRepository(db)
waifuRep := psql.NewWaifuRepository(db)
messages := make([]ai.Message, 0)
messages := make([]openai.Message, 0)
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
if err != nil {
@@ -590,7 +591,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
//if err != nil {
// return messages, err
//}
beforeHistory := ai.Message{
beforeHistory := openai.Message{
Role: "system",
Content: fmt.Sprintf(
"%s %s %s %s",
@@ -603,7 +604,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
userPrompt,
),
}
afterHistory := ai.Message{
afterHistory := openai.Message{
Role: "system",
Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName),
}
@@ -615,7 +616,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
messages = append(messages, beforeHistory)
for _, m := range history {
messages = append(messages, ai.Message{Role: m.Role, Content: m.Message})
messages = append(messages, openai.Message{Role: m.Role, Content: m.Message})
}
messages = append(messages, afterHistory)
return messages, nil
@@ -653,16 +654,41 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
kb.AddCallbackButtonStyle("Отменить", laniakea.ButtonStyleDanger, "rp.cancel")
m := ctx.Keyboard("Генерация запущена...", kb)
ctx.SendAction(tgapi.ChatActionTyping)
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
defer api.Close()
res, err := api.CreateCompletion(messages, userMessage, 0.5)
res, err := api.CreateCompletionStream(messages, userMessage, 0.5)
if err != nil {
ctx.Error(err)
return
}
if len(res.Choices) == 0 {
m.Edit("Не удалось сгенерировать ответ. Попробуйте снова позже")
return
answerContent := ""
draft := ctx.NewDraft()
for r, err := range res {
if m != nil {
m.Delete()
m = nil
}
if err != nil {
ctx.Error(err)
return
}
if len(r.Choices) == 0 {
continue
}
content := r.Choices[0].Delta.Content
if content == "" {
continue
}
answerContent += content
err = draft.Push(content)
if err != nil {
ctx.Error(err)
//draft.Flush()
break
}
}
counter := redisRpRep.GetCounter(ctx.FromID, waifuId)
@@ -671,9 +697,7 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Error(err)
return
}
agentAnswer := res.Choices[0].Message
answerContent := strings.TrimSpace(agentAnswer.Content)
err = mdb.UpdateRPChatHistory(db, chatId, agentAnswer.Role, answerContent, counter+2)
err = mdb.UpdateRPChatHistory(db, chatId, "assistant", answerContent, counter+2)
if err != nil {
ctx.Error(err)
}
@@ -695,7 +719,6 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Error(err)
}
m.Delete()
kb = laniakea.NewInlineKeyboard(1)
kb.AddCallbackButtonStyle("🔄 Перегенерировать 🔄", laniakea.ButtonStyleSuccess, "rp.regenerate", counter+2)
//kb.AddButton(laniakea.NewInlineKbButton("Тест").SetStyle(laniakea.ButtonStyleSuccess).SetIconCustomEmojiId("5375155835846534814").SetCallbackData("rp.test"))
@@ -729,7 +752,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Error(err)
return
}
var messages extypes.Slice[ai.Message]
var messages extypes.Slice[openai.Message]
messages, err = _getChatHistory(ctx, db)
if err != nil {
ctx.Error(err)
@@ -748,7 +771,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
}
//if messages.Len() == count {
// ctx.Bot.Logger().Errorln("len(messages) == count")
// ctx.Bot.logger().Errorln("len(messages) == count")
// return
//}
@@ -766,7 +789,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
return
}
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
defer api.Close()
messages = messages.Pop(count - 2)
@@ -818,9 +841,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
return
}
messages := make([]ai.Message, 0)
messages := make([]openai.Message, 0)
for _, m := range history {
messages = append(messages, ai.Message{
messages = append(messages, openai.Message{
Role: m.Role,
Content: m.Message,
})
@@ -833,9 +856,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
return
}
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
defer api.Close()
res, err := api.CompressChat(messages)
res, err := ai.CompressChat(api, messages)
if err != nil {
ctx.Error(err)
return

View File

@@ -1,232 +1,11 @@
package ai
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"time"
"git.nix13.pw/scuroneko/slog"
"ymgb/openai"
)
type OpenAIResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
ServiceTier string `json:"service_tier"`
}
type Choice struct {
Index int64 `json:"index"`
Message Message `json:"message"`
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Refusal interface{} `json:"refusal"`
Annotations []interface{} `json:"annotations"`
}
type Usage struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"`
CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
}
type CompletionTokensDetails struct {
ReasoningTokens int64 `json:"reasoning_tokens"`
AudioTokens int64 `json:"audio_tokens"`
AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
}
type PromptTokensDetails struct {
CachedTokens int64 `json:"cached_tokens"`
AudioTokens int64 `json:"audio_tokens"`
}
type OpenAIAPI struct {
Token string
Model string
BaseURL string
Logger *slog.Logger
client *http.Client
}
func NewOpenAIAPI(baseURL, token, model string) *OpenAIAPI {
logger := slog.CreateLogger()
level := slog.FATAL
if os.Getenv("DEBUG") == "true" {
level = slog.DEBUG
}
logger = logger.Prefix("AI").Level(level)
// FIXME Leak here
//logger = logger.AddWriter(logger.CreateJsonStdoutWriter())
proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
if err != nil {
logger.Error(err)
logger.Close()
}
t := &http.Transport{}
if proxy.Host != "" {
t.Proxy = http.ProxyURL(proxy)
}
client := &http.Client{
Timeout: 5 * time.Minute,
Transport: t,
}
return &OpenAIAPI{
Token: token,
Model: model,
BaseURL: baseURL,
Logger: logger,
client: client,
}
}
func (o *OpenAIAPI) Close() error {
return o.Logger.Close()
}
type CreateCompletionReq struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Verbosity string `json:"verbosity,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
PresencePenalty int `json:"presence_penalty,omitempty"`
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
TopP int `json:"top_p,omitempty"`
MaxCompletionTokens int `json:"max_completition_tokens,omitempty"`
}
var MaxRetriesErr = fmt.Errorf("max retries exceeded")
var BadResponseErr = fmt.Errorf("bad_response_status_code")
func (o *OpenAIAPI) DoRequest(url string, params any, retries int) ([]byte, error) {
responseBody := make([]byte, 0)
data, err := json.Marshal(params)
if err != nil {
log.Println("json marshal failed:", err)
return responseBody, err
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
if err != nil {
log.Println("create request failed:", err)
return responseBody, err
}
req.Header.Set("Content-Type", "application/json")
if o.Token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token))
}
res, err := o.client.Do(req)
if err != nil {
log.Println("do request failed:", err)
return nil, err
}
defer res.Body.Close()
if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 {
o.Logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
if retries >= 3 {
return responseBody, MaxRetriesErr
}
time.Sleep(1 * time.Second)
return o.DoRequest(url, params, retries+1)
}
responseBody, err = io.ReadAll(res.Body)
if err != nil {
log.Println("read response failed:", err)
return responseBody, err
}
var tempData any
err = json.Unmarshal(responseBody, &tempData)
if err != nil {
log.Println("json unmarshal failed:", err)
return responseBody, err
}
// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
if eData, ok := tempData.(map[string]any); ok {
if errorData, ok := eData["error"]; ok {
if errorPayload, ok := errorData.(map[string]any); ok {
code := errorPayload["code"].(string)
if code == "bad_response_status_code" {
if retries >= 3 {
return responseBody, BadResponseErr
}
o.Logger.Warnln("Retrying because of bad response status code")
return o.DoRequest(url, params, retries+1)
}
return nil, errors.New(code)
}
o.Logger.Errorln("Unknown error", errorData)
return nil, errors.New(string(responseBody))
}
} else if eData, ok := tempData.(string); ok {
return responseBody, errors.New(eData)
}
return responseBody, err
}
func (o *OpenAIAPI) CreateCompletion(history []Message, message string, temp float32) (*OpenAIResponse, error) {
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
request := CreateCompletionReq{
Model: o.Model,
Messages: append(history, Message{
Role: "user",
Content: message,
}),
Temperature: temp,
}
data, err := json.Marshal(request)
if err != nil {
return nil, err
}
o.Logger.Debug("REQ", u, string(data))
body, err := o.DoRequest(u, request, 0)
if err != nil {
return nil, err
}
o.Logger.Debug("RES", u, string(body))
response := new(OpenAIResponse)
err = json.Unmarshal(body, response)
return response, err
}
func (o *OpenAIAPI) CompressChat(history []Message) (*OpenAIResponse, error) {
request := CreateCompletionReq{
Model: o.Model,
Messages: append(history, Message{
Role: "user",
Content: CompressPrompt,
}),
Temperature: 1.0,
}
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
data, err := json.Marshal(request)
if err != nil {
return nil, err
}
o.Logger.Debug("COMPRESS REQ", u, string(data))
body, err := o.DoRequest(u, request, 0)
if err != nil {
return nil, err
}
o.Logger.Debug("COMPRESS RES", u, string(body))
response := new(OpenAIResponse)
err = json.Unmarshal(body, response)
return response, err
//https://github.com/sashabaranov/go-openai

// CompressChat asks the model to summarize the given chat history by
// appending the compression prompt as a final user message.
// NOTE(review): temperature changed from 1.0 (previous implementation) to
// 0.0 — confirm this was intentional.
func CompressChat(api *openai.API, history []openai.Message) (openai.AIResponse, error) {
	return api.CreateCompletion(history, CompressPrompt, 0.0)
}