diff --git a/database/mdb/rp_chats.go b/database/mdb/rp_chats.go index 21df1f5..bf326d6 100644 --- a/database/mdb/rp_chats.go +++ b/database/mdb/rp_chats.go @@ -10,9 +10,10 @@ import ( ) type RPChatMessage struct { - ChatID string `bson:"chat_id"` - Role string `bson:"role"` - Message string `bson:"message"` + Id bson.ObjectID `bson:"_id"` + ChatID string `bson:"chat_id"` + Role string `bson:"role"` + Message string `bson:"message"` } func GetChatHistory(db *laniakea.DatabaseContext, chatId string) ([]*RPChatMessage, error) { @@ -32,6 +33,7 @@ func UpdateChatHistory(db *laniakea.DatabaseContext, chatId, role, message strin defer cancel() col := database.GetMongoCollection(db, "rp_chat_messages") _, err := col.InsertOne(ctx, RPChatMessage{ + bson.NewObjectID(), chatId, role, message, @@ -44,3 +46,10 @@ func GetChatHistorySize(db *laniakea.DatabaseContext, chatId string) (int64, err col := database.GetMongoCollection(db, "rp_chat_messages") return col.CountDocuments(ctx, bson.M{"chat_id": chatId}) } +func DeleteChatEntry(db *laniakea.DatabaseContext, entry *RPChatMessage) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + col := database.GetMongoCollection(db, "rp_chat_messages") + _, err := col.DeleteOne(ctx, bson.M{"chat_id": entry.ChatID}) + return err +} diff --git a/plugins/rp.go b/plugins/rp.go index 769f8ac..116466e 100644 --- a/plugins/rp.go +++ b/plugins/rp.go @@ -35,6 +35,7 @@ func RegisterRP(bot *laniakea.Bot) { rp.Payload(newChat, "rp.new_chat") rp.Command(generate, "g", "gen", "г") rp.Payload(compress, "rp.compress_chat") + rp.Payload(regenerateResponse, "rp.regenerate") rp.Payload(compressSettingStage1, "rp.compress_setting_s1") rp.Payload(compressSettingStage2, "rp.compress_setting_s2") @@ -485,35 +486,34 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { ctx.Answer("Описание пользователя было обновлено") } -func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { - redisRpRep := red.NewRPRepository(db) - rpRep := psql.NewRPRepository(db) - waifuId := redisRpRep.GetSelectedWaifu(ctx.FromID) - if waifuId == 0 { - ctx.Answer("Не выбрана вайфу") - return - } +func _getChatHistory(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) ([]ai.Message, error) { + redRep := red.NewRPRepository(db) + psqlRep := psql.NewRPRepository(db) waifuRep := psql.NewWaifuRepository(db) - waifu, err := waifuRep.GetById(waifuId) + + messages := make([]ai.Message, 0) + waifuId := redRep.GetSelectedWaifu(ctx.FromID) + chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId) if err != nil { - ctx.Error(err) - return + return messages, err } - rpUser, err := rpRep.GetOrCreateUser(int64(ctx.FromID)) + waifu, err := waifuRep.GetById(waifuId) if err != nil { - ctx.Error(err) - return + return messages, err } - preset, err := rpRep.GetUserPreset(rpUser) + user, err := psqlRep.GetUser(int64(ctx.FromID)) if err != nil { - ctx.Error(err) - return + return messages, err + } + preset, err := psqlRep.GetPreset(user.SelectedPreset) + if err != nil { + return messages, err } userPrompt := "" - if rpUser.UserPrompt != "" { - userPrompt = fmt.Sprintf("Вот описание моего персонажа %s.", rpUser.UserPrompt) + if user.UserPrompt != "" { + userPrompt = fmt.Sprintf("Вот описание моего персонажа %s.", user.UserPrompt) } beforeHistory := ai.Message{ Role: "system", @@ -521,7 +521,7 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { "%s %s %s %s", ai.FormatPrompt(preset.PreHistory, waifu.Name, ctx.From.FirstName), fmt.Sprintf("Вот краткое описание твоего персонажа: %s.", waifu.RpPrompt), - redisRpRep.GetChatPrompt(ctx.FromID, waifuId), + redRep.GetChatPrompt(ctx.FromID, waifuId), userPrompt, ), } @@ -530,26 +530,46 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName), } + history, err := mdb.GetChatHistory(db, chatId) + if err != nil { + return messages, err + } + + messages = append(messages, beforeHistory) + for _, m := range history { + messages = append(messages, ai.Message{Role: m.Role, Content: m.Message}) + } + messages = append(messages, afterHistory) + return messages, nil +} + +func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { + redisRpRep := red.NewRPRepository(db) + rpRep := psql.NewRPRepository(db) + waifuId := redisRpRep.GetSelectedWaifu(ctx.FromID) + if waifuId == 0 { + ctx.Answer("Не выбрана вайфу") + return + } + + rpUser, err := rpRep.GetOrCreateUser(int64(ctx.FromID)) + if err != nil { + ctx.Error(err) + return + } + chatId, err := redisRpRep.GetOrCreateChatId(ctx.FromID, waifuId) if err != nil { ctx.Error(err) return } - history, err := mdb.GetChatHistory(db, chatId) + + messages, err := _getChatHistory(ctx, db) if err != nil { ctx.Error(err) return } - messages := []ai.Message{beforeHistory} - for _, m := range history { - messages = append(messages, ai.Message{ - Role: m.Role, - Content: strings.TrimSpace(m.Message), - }) - } - userMessage := strings.TrimSpace(strings.Join(ctx.Args, " ")) - messages = append(messages, afterHistory) kb := laniakea.NewInlineKeyboard(1).AddCallbackButton("Отменить", "rp.cancel") m := ctx.Keyboard("Генерация запущена...", kb) @@ -596,7 +616,9 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { } m.Delete() - ctx.Answer(laniakea.EscapeMarkdown(answerContent)) + kb = laniakea.NewInlineKeyboard(1) + kb.AddCallbackButton("Перегенерировать", "rp.regenerate", counter+2) + ctx.Keyboard(laniakea.EscapeMarkdown(answerContent), kb) // Auto compress compressMethod := rpUser.CompressMethod @@ -615,6 +637,68 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { } } +func regenerateResponse(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { + m := ctx.Answer("Запущена повторная генерация…") + + redRep := red.NewRPRepository(db) + waifuId := redRep.GetSelectedWaifu(ctx.FromID) + + count, err := strconv.Atoi(ctx.Args[0]) + if err != nil { + ctx.Error(err) + return + } + messages, err := _getChatHistory(ctx, db) + if err != nil { + ctx.Error(err) + return + } + + chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId) + if err != nil { + ctx.Error(err) + return + } + history, err := mdb.GetChatHistory(db, chatId) + if err != nil { + ctx.Error(err) + return + } + + // 0(system), 1, 2, ..., N-2(user, count-3), N-1(agent, count-2), N(system, count-1) + answerToDelete := history[count-2] + err = mdb.DeleteChatEntry(db, answerToDelete) + if err != nil { + ctx.Error(err) + return + } + + psqlRep := psql.NewRPRepository(db) + rpUser, err := psqlRep.GetOrCreateUser(int64(ctx.FromID)) + if err != nil { + ctx.Error(err) + return + } + + userReq := messages[count-2] + api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key) + defer api.Close() + + messages = utils.PopSlice(messages, count-1) + messages = utils.PopSlice(messages, count-2) + + res, err := api.CreateCompletion(messages, userReq.Content, 1.0) + if err != nil { + ctx.Error(err) + return + } + + m.Delete() + kb := laniakea.NewInlineKeyboard(1) + kb.AddCallbackButton("Перегенерировать", "rp.regenerate", count) + ctx.EditCallback(laniakea.EscapeMarkdown(res.Choices[0].Message.Content), kb) +} + func compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { m := ctx.EditCallback("Запущено сжатие чата…", nil) _compress(ctx, db) @@ -667,6 +751,7 @@ func _compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { compressedHistory = strings.ReplaceAll(compressedHistory, "*", "") ctx.Answer(compressedHistory) + tokens := len(compressModel) chatId = uuid.New().String() err = redisRpRep.SetChatId(ctx.FromID, waifuId, chatId) @@ -681,11 +766,21 @@ func _compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { } offset := utils.Min(len(history), 20) for _, m := range history[len(history)-offset:] { + tokens += len(m.Message) err = mdb.UpdateChatHistory(db, chatId, m.Role, m.Message) if err != nil { ctx.Error(err) } } + + err = redisRpRep.SetCounter(ctx.FromID, waifuId, offset+1) + if err != nil { + ctx.Error(err) + } + err = redisRpRep.SetChatTokens(ctx.FromID, waifuId, tokens) + if err != nil { + ctx.Error(err) + } } var messagesMethodCount = []int{ @@ -709,7 +804,7 @@ func compressSettingStage1(ctx *laniakea.MsgContext, _ *laniakea.DatabaseContext ctx.EditCallback(strings.Join(out, "\n"), kb) } -func compressSettingStage2(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { +func compressSettingStage2(ctx *laniakea.MsgContext, _ *laniakea.DatabaseContext) { if len(ctx.Args) == 0 { return } diff --git a/utils/utils.go b/utils/utils.go index a02a96f..0728365 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -17,3 +17,14 @@ func Max(a, b int) int { } return b } + +func PopSlice[S any](s []S, index int) []S { + out := make([]S, 0) + for i, e := range s { + if i == index { + continue + } + out = append(out, e) + } + return out +}