compress and regenerate

This commit is contained in:
2026-02-02 13:43:27 +03:00
parent 2ed1fc9f80
commit 2a71754c6f
3 changed files with 150 additions and 35 deletions

View File

@@ -10,9 +10,10 @@ import (
)
type RPChatMessage struct {
ChatID string `bson:"chat_id"`
Role string `bson:"role"`
Message string `bson:"message"`
Id bson.ObjectID `bson:"_id"`
ChatID string `bson:"chat_id"`
Role string `bson:"role"`
Message string `bson:"message"`
}
func GetChatHistory(db *laniakea.DatabaseContext, chatId string) ([]*RPChatMessage, error) {
@@ -32,6 +33,7 @@ func UpdateChatHistory(db *laniakea.DatabaseContext, chatId, role, message strin
defer cancel()
col := database.GetMongoCollection(db, "rp_chat_messages")
_, err := col.InsertOne(ctx, RPChatMessage{
bson.NewObjectID(),
chatId,
role,
message,
@@ -44,3 +46,10 @@ func GetChatHistorySize(db *laniakea.DatabaseContext, chatId string) (int64, err
col := database.GetMongoCollection(db, "rp_chat_messages")
return col.CountDocuments(ctx, bson.M{"chat_id": chatId})
}
func DeleteChatEntry(db *laniakea.DatabaseContext, entry *RPChatMessage) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
col := database.GetMongoCollection(db, "rp_chat_messages")
_, err := col.DeleteOne(ctx, bson.M{"chat_id": entry.ChatID})
return err
}

View File

@@ -35,6 +35,7 @@ func RegisterRP(bot *laniakea.Bot) {
rp.Payload(newChat, "rp.new_chat")
rp.Command(generate, "g", "gen", "г")
rp.Payload(compress, "rp.compress_chat")
rp.Payload(regenerateResponse, "rp.regenerate")
rp.Payload(compressSettingStage1, "rp.compress_setting_s1")
rp.Payload(compressSettingStage2, "rp.compress_setting_s2")
@@ -485,35 +486,34 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
ctx.Answer("Описание пользователя было обновлено")
}
func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
redisRpRep := red.NewRPRepository(db)
rpRep := psql.NewRPRepository(db)
waifuId := redisRpRep.GetSelectedWaifu(ctx.FromID)
if waifuId == 0 {
ctx.Answer("Не выбрана вайфу")
return
}
func _getChatHistory(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) ([]ai.Message, error) {
redRep := red.NewRPRepository(db)
psqlRep := psql.NewRPRepository(db)
waifuRep := psql.NewWaifuRepository(db)
waifu, err := waifuRep.GetById(waifuId)
messages := make([]ai.Message, 0)
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
if err != nil {
ctx.Error(err)
return
return messages, err
}
rpUser, err := rpRep.GetOrCreateUser(int64(ctx.FromID))
waifu, err := waifuRep.GetById(waifuId)
if err != nil {
ctx.Error(err)
return
return messages, err
}
preset, err := rpRep.GetUserPreset(rpUser)
user, err := psqlRep.GetUser(int64(ctx.FromID))
if err != nil {
ctx.Error(err)
return
return messages, err
}
preset, err := psqlRep.GetPreset(user.SelectedPreset)
if err != nil {
return messages, err
}
userPrompt := ""
if rpUser.UserPrompt != "" {
userPrompt = fmt.Sprintf("Вот описание моего персонажа %s.", rpUser.UserPrompt)
if user.UserPrompt != "" {
userPrompt = fmt.Sprintf("Вот описание моего персонажа %s.", user.UserPrompt)
}
beforeHistory := ai.Message{
Role: "system",
@@ -521,7 +521,7 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
"%s %s %s %s",
ai.FormatPrompt(preset.PreHistory, waifu.Name, ctx.From.FirstName),
fmt.Sprintf("Вот краткое описание твоего персонажа: %s.", waifu.RpPrompt),
redisRpRep.GetChatPrompt(ctx.FromID, waifuId),
redRep.GetChatPrompt(ctx.FromID, waifuId),
userPrompt,
),
}
@@ -530,26 +530,46 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName),
}
history, err := mdb.GetChatHistory(db, chatId)
if err != nil {
return messages, err
}
messages = append(messages, beforeHistory)
for _, m := range history {
messages = append(messages, ai.Message{Role: m.Role, Content: m.Message})
}
messages = append(messages, afterHistory)
return messages, nil
}
func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
redisRpRep := red.NewRPRepository(db)
rpRep := psql.NewRPRepository(db)
waifuId := redisRpRep.GetSelectedWaifu(ctx.FromID)
if waifuId == 0 {
ctx.Answer("Не выбрана вайфу")
return
}
rpUser, err := rpRep.GetOrCreateUser(int64(ctx.FromID))
if err != nil {
ctx.Error(err)
return
}
chatId, err := redisRpRep.GetOrCreateChatId(ctx.FromID, waifuId)
if err != nil {
ctx.Error(err)
return
}
history, err := mdb.GetChatHistory(db, chatId)
messages, err := _getChatHistory(ctx, db)
if err != nil {
ctx.Error(err)
return
}
messages := []ai.Message{beforeHistory}
for _, m := range history {
messages = append(messages, ai.Message{
Role: m.Role,
Content: strings.TrimSpace(m.Message),
})
}
userMessage := strings.TrimSpace(strings.Join(ctx.Args, " "))
messages = append(messages, afterHistory)
kb := laniakea.NewInlineKeyboard(1).AddCallbackButton("Отменить", "rp.cancel")
m := ctx.Keyboard("Генерация запущена...", kb)
@@ -596,7 +616,9 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
}
m.Delete()
ctx.Answer(laniakea.EscapeMarkdown(answerContent))
kb = laniakea.NewInlineKeyboard(1)
kb.AddCallbackButton("Перегенерировать", "rp.regenerate", counter+2)
ctx.Keyboard(laniakea.EscapeMarkdown(answerContent), kb)
// Auto compress
compressMethod := rpUser.CompressMethod
@@ -615,6 +637,68 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
}
}
func regenerateResponse(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
m := ctx.Answer("Запущена повторная генерация…")
redRep := red.NewRPRepository(db)
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
count, err := strconv.Atoi(ctx.Args[0])
if err != nil {
ctx.Error(err)
return
}
messages, err := _getChatHistory(ctx, db)
if err != nil {
ctx.Error(err)
return
}
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
if err != nil {
ctx.Error(err)
return
}
history, err := mdb.GetChatHistory(db, chatId)
if err != nil {
ctx.Error(err)
return
}
// 0(system), 1, 2, ..., N-2(user, count-3), N-1(agent, count-2), N(system, count-1)
answerToDelete := history[count-2]
err = mdb.DeleteChatEntry(db, answerToDelete)
if err != nil {
ctx.Error(err)
return
}
psqlRep := psql.NewRPRepository(db)
rpUser, err := psqlRep.GetOrCreateUser(int64(ctx.FromID))
if err != nil {
ctx.Error(err)
return
}
userReq := messages[count-2]
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
defer api.Close()
messages = utils.PopSlice(messages, count-1)
messages = utils.PopSlice(messages, count-2)
res, err := api.CreateCompletion(messages, userReq.Content, 1.0)
if err != nil {
ctx.Error(err)
return
}
m.Delete()
kb := laniakea.NewInlineKeyboard(1)
kb.AddCallbackButton("Перегенерировать", "rp.regenerate", count)
ctx.EditCallback(laniakea.EscapeMarkdown(res.Choices[0].Message.Content), kb)
}
func compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
m := ctx.EditCallback("Запущено сжатие чата…", nil)
_compress(ctx, db)
@@ -667,6 +751,7 @@ func _compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
compressedHistory = strings.ReplaceAll(compressedHistory, "*", "")
ctx.Answer(compressedHistory)
tokens := len(compressModel)
chatId = uuid.New().String()
err = redisRpRep.SetChatId(ctx.FromID, waifuId, chatId)
@@ -681,11 +766,21 @@ func _compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
}
offset := utils.Min(len(history), 20)
for _, m := range history[len(history)-offset:] {
tokens += len(m.Message)
err = mdb.UpdateChatHistory(db, chatId, m.Role, m.Message)
if err != nil {
ctx.Error(err)
}
}
err = redisRpRep.SetCounter(ctx.FromID, waifuId, offset+1)
if err != nil {
ctx.Error(err)
}
err = redisRpRep.SetChatTokens(ctx.FromID, waifuId, tokens)
if err != nil {
ctx.Error(err)
}
}
var messagesMethodCount = []int{
@@ -709,7 +804,7 @@ func compressSettingStage1(ctx *laniakea.MsgContext, _ *laniakea.DatabaseContext
ctx.EditCallback(strings.Join(out, "\n"), kb)
}
func compressSettingStage2(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
func compressSettingStage2(ctx *laniakea.MsgContext, _ *laniakea.DatabaseContext) {
if len(ctx.Args) == 0 {
return
}

View File

@@ -17,3 +17,14 @@ func Max(a, b int) int {
}
return b
}
func PopSlice[S any](s []S, index int) []S {
out := make([]S, 0)
for i, e := range s {
if i == index {
continue
}
out = append(out, e)
}
return out
}