some changes

This commit is contained in:
2026-03-02 00:58:43 +03:00
parent b394c0be68
commit 3e0d3db47e
13 changed files with 486 additions and 292 deletions

View File

@@ -11,6 +11,7 @@ import (
"ymgb/database/mdb"
"ymgb/database/psql"
"ymgb/database/red"
"ymgb/openai"
"ymgb/utils"
"ymgb/utils/ai"
@@ -557,12 +558,12 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Answer("Описание пользователя было обновлено")
}
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Message, error) {
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]openai.Message, error) {
redRep := red.NewRPRepository(db)
psqlRep := psql.NewRPRepository(db)
waifuRep := psql.NewWaifuRepository(db)
messages := make([]ai.Message, 0)
messages := make([]openai.Message, 0)
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
if err != nil {
@@ -590,7 +591,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
//if err != nil {
// return messages, err
//}
beforeHistory := ai.Message{
beforeHistory := openai.Message{
Role: "system",
Content: fmt.Sprintf(
"%s %s %s %s",
@@ -603,7 +604,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
userPrompt,
),
}
afterHistory := ai.Message{
afterHistory := openai.Message{
Role: "system",
Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName),
}
@@ -615,7 +616,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
messages = append(messages, beforeHistory)
for _, m := range history {
messages = append(messages, ai.Message{Role: m.Role, Content: m.Message})
messages = append(messages, openai.Message{Role: m.Role, Content: m.Message})
}
messages = append(messages, afterHistory)
return messages, nil
@@ -653,16 +654,41 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
kb.AddCallbackButtonStyle("Отменить", laniakea.ButtonStyleDanger, "rp.cancel")
m := ctx.Keyboard("Генерация запущена...", kb)
ctx.SendAction(tgapi.ChatActionTyping)
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
defer api.Close()
res, err := api.CreateCompletion(messages, userMessage, 0.5)
res, err := api.CreateCompletionStream(messages, userMessage, 0.5)
if err != nil {
ctx.Error(err)
return
}
if len(res.Choices) == 0 {
m.Edit("Не удалось сгенерировать ответ. Попробуйте снова позже")
return
answerContent := ""
draft := ctx.NewDraft()
for r, err := range res {
if m != nil {
m.Delete()
m = nil
}
if err != nil {
ctx.Error(err)
return
}
if len(r.Choices) == 0 {
continue
}
content := r.Choices[0].Delta.Content
if content == "" {
continue
}
answerContent += content
err = draft.Push(content)
if err != nil {
ctx.Error(err)
//draft.Flush()
break
}
}
counter := redisRpRep.GetCounter(ctx.FromID, waifuId)
@@ -671,9 +697,7 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Error(err)
return
}
agentAnswer := res.Choices[0].Message
answerContent := strings.TrimSpace(agentAnswer.Content)
err = mdb.UpdateRPChatHistory(db, chatId, agentAnswer.Role, answerContent, counter+2)
err = mdb.UpdateRPChatHistory(db, chatId, "assistant", answerContent, counter+2)
if err != nil {
ctx.Error(err)
}
@@ -695,7 +719,6 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Error(err)
}
m.Delete()
kb = laniakea.NewInlineKeyboard(1)
kb.AddCallbackButtonStyle("🔄 Перегенерировать 🔄", laniakea.ButtonStyleSuccess, "rp.regenerate", counter+2)
//kb.AddButton(laniakea.NewInlineKbButton("Тест").SetStyle(laniakea.ButtonStyleSuccess).SetIconCustomEmojiId("5375155835846534814").SetCallbackData("rp.test"))
@@ -729,7 +752,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
ctx.Error(err)
return
}
var messages extypes.Slice[ai.Message]
var messages extypes.Slice[openai.Message]
messages, err = _getChatHistory(ctx, db)
if err != nil {
ctx.Error(err)
@@ -748,7 +771,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
}
//if messages.Len() == count {
// ctx.Bot.Logger().Errorln("len(messages) == count")
// ctx.Bot.logger().Errorln("len(messages) == count")
// return
//}
@@ -766,7 +789,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
return
}
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
defer api.Close()
messages = messages.Pop(count - 2)
@@ -818,9 +841,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
return
}
messages := make([]ai.Message, 0)
messages := make([]openai.Message, 0)
for _, m := range history {
messages = append(messages, ai.Message{
messages = append(messages, openai.Message{
Role: m.Role,
Content: m.Message,
})
@@ -833,9 +856,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
return
}
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
defer api.Close()
res, err := api.CompressChat(messages)
res, err := ai.CompressChat(api, messages)
if err != nil {
ctx.Error(err)
return