some changes
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"ymgb/database/mdb"
|
||||
"ymgb/database/psql"
|
||||
"ymgb/database/red"
|
||||
"ymgb/openai"
|
||||
"ymgb/utils"
|
||||
"ymgb/utils/ai"
|
||||
|
||||
@@ -557,12 +558,12 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Answer("Описание пользователя было обновлено")
|
||||
}
|
||||
|
||||
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Message, error) {
|
||||
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]openai.Message, error) {
|
||||
redRep := red.NewRPRepository(db)
|
||||
psqlRep := psql.NewRPRepository(db)
|
||||
waifuRep := psql.NewWaifuRepository(db)
|
||||
|
||||
messages := make([]ai.Message, 0)
|
||||
messages := make([]openai.Message, 0)
|
||||
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
|
||||
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
|
||||
if err != nil {
|
||||
@@ -590,7 +591,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
|
||||
//if err != nil {
|
||||
// return messages, err
|
||||
//}
|
||||
beforeHistory := ai.Message{
|
||||
beforeHistory := openai.Message{
|
||||
Role: "system",
|
||||
Content: fmt.Sprintf(
|
||||
"%s %s %s %s",
|
||||
@@ -603,7 +604,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
|
||||
userPrompt,
|
||||
),
|
||||
}
|
||||
afterHistory := ai.Message{
|
||||
afterHistory := openai.Message{
|
||||
Role: "system",
|
||||
Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName),
|
||||
}
|
||||
@@ -615,7 +616,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
|
||||
|
||||
messages = append(messages, beforeHistory)
|
||||
for _, m := range history {
|
||||
messages = append(messages, ai.Message{Role: m.Role, Content: m.Message})
|
||||
messages = append(messages, openai.Message{Role: m.Role, Content: m.Message})
|
||||
}
|
||||
messages = append(messages, afterHistory)
|
||||
return messages, nil
|
||||
@@ -653,16 +654,41 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
kb.AddCallbackButtonStyle("Отменить", laniakea.ButtonStyleDanger, "rp.cancel")
|
||||
m := ctx.Keyboard("Генерация запущена...", kb)
|
||||
ctx.SendAction(tgapi.ChatActionTyping)
|
||||
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
defer api.Close()
|
||||
res, err := api.CreateCompletion(messages, userMessage, 0.5)
|
||||
res, err := api.CreateCompletionStream(messages, userMessage, 0.5)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
if len(res.Choices) == 0 {
|
||||
m.Edit("Не удалось сгенерировать ответ. Попробуйте снова позже")
|
||||
return
|
||||
|
||||
answerContent := ""
|
||||
draft := ctx.NewDraft()
|
||||
for r, err := range res {
|
||||
if m != nil {
|
||||
m.Delete()
|
||||
m = nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(r.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
content := r.Choices[0].Delta.Content
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
answerContent += content
|
||||
err = draft.Push(content)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
//draft.Flush()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
counter := redisRpRep.GetCounter(ctx.FromID, waifuId)
|
||||
@@ -671,9 +697,7 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
agentAnswer := res.Choices[0].Message
|
||||
answerContent := strings.TrimSpace(agentAnswer.Content)
|
||||
err = mdb.UpdateRPChatHistory(db, chatId, agentAnswer.Role, answerContent, counter+2)
|
||||
err = mdb.UpdateRPChatHistory(db, chatId, "assistant", answerContent, counter+2)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
}
|
||||
@@ -695,7 +719,6 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Error(err)
|
||||
}
|
||||
|
||||
m.Delete()
|
||||
kb = laniakea.NewInlineKeyboard(1)
|
||||
kb.AddCallbackButtonStyle("🔄 Перегенерировать 🔄", laniakea.ButtonStyleSuccess, "rp.regenerate", counter+2)
|
||||
//kb.AddButton(laniakea.NewInlineKbButton("Тест").SetStyle(laniakea.ButtonStyleSuccess).SetIconCustomEmojiId("5375155835846534814").SetCallbackData("rp.test"))
|
||||
@@ -729,7 +752,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
var messages extypes.Slice[ai.Message]
|
||||
var messages extypes.Slice[openai.Message]
|
||||
messages, err = _getChatHistory(ctx, db)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
@@ -748,7 +771,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
}
|
||||
|
||||
//if messages.Len() == count {
|
||||
// ctx.Bot.Logger().Errorln("len(messages) == count")
|
||||
// ctx.Bot.logger().Errorln("len(messages) == count")
|
||||
// return
|
||||
//}
|
||||
|
||||
@@ -766,7 +789,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
defer api.Close()
|
||||
|
||||
messages = messages.Pop(count - 2)
|
||||
@@ -818,9 +841,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
messages := make([]ai.Message, 0)
|
||||
messages := make([]openai.Message, 0)
|
||||
for _, m := range history {
|
||||
messages = append(messages, ai.Message{
|
||||
messages = append(messages, openai.Message{
|
||||
Role: m.Role,
|
||||
Content: m.Message,
|
||||
})
|
||||
@@ -833,9 +856,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
|
||||
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
|
||||
defer api.Close()
|
||||
res, err := api.CompressChat(messages)
|
||||
res, err := ai.CompressChat(api, messages)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user