compress and regenerate
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RPChatMessage struct {
|
type RPChatMessage struct {
|
||||||
|
Id bson.ObjectID `bson:"_id"`
|
||||||
ChatID string `bson:"chat_id"`
|
ChatID string `bson:"chat_id"`
|
||||||
Role string `bson:"role"`
|
Role string `bson:"role"`
|
||||||
Message string `bson:"message"`
|
Message string `bson:"message"`
|
||||||
@@ -32,6 +33,7 @@ func UpdateChatHistory(db *laniakea.DatabaseContext, chatId, role, message strin
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
col := database.GetMongoCollection(db, "rp_chat_messages")
|
col := database.GetMongoCollection(db, "rp_chat_messages")
|
||||||
_, err := col.InsertOne(ctx, RPChatMessage{
|
_, err := col.InsertOne(ctx, RPChatMessage{
|
||||||
|
bson.NewObjectID(),
|
||||||
chatId,
|
chatId,
|
||||||
role,
|
role,
|
||||||
message,
|
message,
|
||||||
@@ -44,3 +46,10 @@ func GetChatHistorySize(db *laniakea.DatabaseContext, chatId string) (int64, err
|
|||||||
col := database.GetMongoCollection(db, "rp_chat_messages")
|
col := database.GetMongoCollection(db, "rp_chat_messages")
|
||||||
return col.CountDocuments(ctx, bson.M{"chat_id": chatId})
|
return col.CountDocuments(ctx, bson.M{"chat_id": chatId})
|
||||||
}
|
}
|
||||||
|
func DeleteChatEntry(db *laniakea.DatabaseContext, entry *RPChatMessage) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
col := database.GetMongoCollection(db, "rp_chat_messages")
|
||||||
|
_, err := col.DeleteOne(ctx, bson.M{"chat_id": entry.ChatID})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
159
plugins/rp.go
159
plugins/rp.go
@@ -35,6 +35,7 @@ func RegisterRP(bot *laniakea.Bot) {
|
|||||||
rp.Payload(newChat, "rp.new_chat")
|
rp.Payload(newChat, "rp.new_chat")
|
||||||
rp.Command(generate, "g", "gen", "г")
|
rp.Command(generate, "g", "gen", "г")
|
||||||
rp.Payload(compress, "rp.compress_chat")
|
rp.Payload(compress, "rp.compress_chat")
|
||||||
|
rp.Payload(regenerateResponse, "rp.regenerate")
|
||||||
|
|
||||||
rp.Payload(compressSettingStage1, "rp.compress_setting_s1")
|
rp.Payload(compressSettingStage1, "rp.compress_setting_s1")
|
||||||
rp.Payload(compressSettingStage2, "rp.compress_setting_s2")
|
rp.Payload(compressSettingStage2, "rp.compress_setting_s2")
|
||||||
@@ -485,35 +486,34 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
|||||||
ctx.Answer("Описание пользователя было обновлено")
|
ctx.Answer("Описание пользователя было обновлено")
|
||||||
}
|
}
|
||||||
|
|
||||||
func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
func _getChatHistory(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) ([]ai.Message, error) {
|
||||||
redisRpRep := red.NewRPRepository(db)
|
redRep := red.NewRPRepository(db)
|
||||||
rpRep := psql.NewRPRepository(db)
|
psqlRep := psql.NewRPRepository(db)
|
||||||
waifuId := redisRpRep.GetSelectedWaifu(ctx.FromID)
|
|
||||||
if waifuId == 0 {
|
|
||||||
ctx.Answer("Не выбрана вайфу")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
waifuRep := psql.NewWaifuRepository(db)
|
waifuRep := psql.NewWaifuRepository(db)
|
||||||
waifu, err := waifuRep.GetById(waifuId)
|
|
||||||
|
messages := make([]ai.Message, 0)
|
||||||
|
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
|
||||||
|
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
return messages, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rpUser, err := rpRep.GetOrCreateUser(int64(ctx.FromID))
|
waifu, err := waifuRep.GetById(waifuId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
return messages, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
preset, err := rpRep.GetUserPreset(rpUser)
|
user, err := psqlRep.GetUser(int64(ctx.FromID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
return messages, err
|
||||||
return
|
}
|
||||||
|
preset, err := psqlRep.GetPreset(user.SelectedPreset)
|
||||||
|
if err != nil {
|
||||||
|
return messages, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userPrompt := ""
|
userPrompt := ""
|
||||||
if rpUser.UserPrompt != "" {
|
if user.UserPrompt != "" {
|
||||||
userPrompt = fmt.Sprintf("Вот описание моего персонажа %s.", rpUser.UserPrompt)
|
userPrompt = fmt.Sprintf("Вот описание моего персонажа %s.", user.UserPrompt)
|
||||||
}
|
}
|
||||||
beforeHistory := ai.Message{
|
beforeHistory := ai.Message{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
@@ -521,7 +521,7 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
|||||||
"%s %s %s %s",
|
"%s %s %s %s",
|
||||||
ai.FormatPrompt(preset.PreHistory, waifu.Name, ctx.From.FirstName),
|
ai.FormatPrompt(preset.PreHistory, waifu.Name, ctx.From.FirstName),
|
||||||
fmt.Sprintf("Вот краткое описание твоего персонажа: %s.", waifu.RpPrompt),
|
fmt.Sprintf("Вот краткое описание твоего персонажа: %s.", waifu.RpPrompt),
|
||||||
redisRpRep.GetChatPrompt(ctx.FromID, waifuId),
|
redRep.GetChatPrompt(ctx.FromID, waifuId),
|
||||||
userPrompt,
|
userPrompt,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@@ -530,26 +530,46 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
|||||||
Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName),
|
Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
history, err := mdb.GetChatHistory(db, chatId)
|
||||||
|
if err != nil {
|
||||||
|
return messages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, beforeHistory)
|
||||||
|
for _, m := range history {
|
||||||
|
messages = append(messages, ai.Message{Role: m.Role, Content: m.Message})
|
||||||
|
}
|
||||||
|
messages = append(messages, afterHistory)
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
||||||
|
redisRpRep := red.NewRPRepository(db)
|
||||||
|
rpRep := psql.NewRPRepository(db)
|
||||||
|
waifuId := redisRpRep.GetSelectedWaifu(ctx.FromID)
|
||||||
|
if waifuId == 0 {
|
||||||
|
ctx.Answer("Не выбрана вайфу")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rpUser, err := rpRep.GetOrCreateUser(int64(ctx.FromID))
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
chatId, err := redisRpRep.GetOrCreateChatId(ctx.FromID, waifuId)
|
chatId, err := redisRpRep.GetOrCreateChatId(ctx.FromID, waifuId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
history, err := mdb.GetChatHistory(db, chatId)
|
|
||||||
|
messages, err := _getChatHistory(ctx, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
messages := []ai.Message{beforeHistory}
|
|
||||||
for _, m := range history {
|
|
||||||
messages = append(messages, ai.Message{
|
|
||||||
Role: m.Role,
|
|
||||||
Content: strings.TrimSpace(m.Message),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
userMessage := strings.TrimSpace(strings.Join(ctx.Args, " "))
|
userMessage := strings.TrimSpace(strings.Join(ctx.Args, " "))
|
||||||
messages = append(messages, afterHistory)
|
|
||||||
|
|
||||||
kb := laniakea.NewInlineKeyboard(1).AddCallbackButton("Отменить", "rp.cancel")
|
kb := laniakea.NewInlineKeyboard(1).AddCallbackButton("Отменить", "rp.cancel")
|
||||||
m := ctx.Keyboard("Генерация запущена...", kb)
|
m := ctx.Keyboard("Генерация запущена...", kb)
|
||||||
@@ -596,7 +616,9 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.Delete()
|
m.Delete()
|
||||||
ctx.Answer(laniakea.EscapeMarkdown(answerContent))
|
kb = laniakea.NewInlineKeyboard(1)
|
||||||
|
kb.AddCallbackButton("Перегенерировать", "rp.regenerate", counter+2)
|
||||||
|
ctx.Keyboard(laniakea.EscapeMarkdown(answerContent), kb)
|
||||||
|
|
||||||
// Auto compress
|
// Auto compress
|
||||||
compressMethod := rpUser.CompressMethod
|
compressMethod := rpUser.CompressMethod
|
||||||
@@ -615,6 +637,68 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func regenerateResponse(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
||||||
|
m := ctx.Answer("Запущена повторная генерация…")
|
||||||
|
|
||||||
|
redRep := red.NewRPRepository(db)
|
||||||
|
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
|
||||||
|
|
||||||
|
count, err := strconv.Atoi(ctx.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages, err := _getChatHistory(ctx, db)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
history, err := mdb.GetChatHistory(db, chatId)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0(system), 1, 2, ..., N-2(user, count-3), N-1(agent, count-2), N(system, count-1)
|
||||||
|
answerToDelete := history[count-2]
|
||||||
|
err = mdb.DeleteChatEntry(db, answerToDelete)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
psqlRep := psql.NewRPRepository(db)
|
||||||
|
rpUser, err := psqlRep.GetOrCreateUser(int64(ctx.FromID))
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userReq := messages[count-2]
|
||||||
|
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||||
|
defer api.Close()
|
||||||
|
|
||||||
|
messages = utils.PopSlice(messages, count-1)
|
||||||
|
messages = utils.PopSlice(messages, count-2)
|
||||||
|
|
||||||
|
res, err := api.CreateCompletion(messages, userReq.Content, 1.0)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Delete()
|
||||||
|
kb := laniakea.NewInlineKeyboard(1)
|
||||||
|
kb.AddCallbackButton("Перегенерировать", "rp.regenerate", count)
|
||||||
|
ctx.EditCallback(laniakea.EscapeMarkdown(res.Choices[0].Message.Content), kb)
|
||||||
|
}
|
||||||
|
|
||||||
func compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
func compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
||||||
m := ctx.EditCallback("Запущено сжатие чата…", nil)
|
m := ctx.EditCallback("Запущено сжатие чата…", nil)
|
||||||
_compress(ctx, db)
|
_compress(ctx, db)
|
||||||
@@ -667,6 +751,7 @@ func _compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
|||||||
compressedHistory = strings.ReplaceAll(compressedHistory, "*", "")
|
compressedHistory = strings.ReplaceAll(compressedHistory, "*", "")
|
||||||
|
|
||||||
ctx.Answer(compressedHistory)
|
ctx.Answer(compressedHistory)
|
||||||
|
tokens := len(compressModel)
|
||||||
|
|
||||||
chatId = uuid.New().String()
|
chatId = uuid.New().String()
|
||||||
err = redisRpRep.SetChatId(ctx.FromID, waifuId, chatId)
|
err = redisRpRep.SetChatId(ctx.FromID, waifuId, chatId)
|
||||||
@@ -681,11 +766,21 @@ func _compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
|||||||
}
|
}
|
||||||
offset := utils.Min(len(history), 20)
|
offset := utils.Min(len(history), 20)
|
||||||
for _, m := range history[len(history)-offset:] {
|
for _, m := range history[len(history)-offset:] {
|
||||||
|
tokens += len(m.Message)
|
||||||
err = mdb.UpdateChatHistory(db, chatId, m.Role, m.Message)
|
err = mdb.UpdateChatHistory(db, chatId, m.Role, m.Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
ctx.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = redisRpRep.SetCounter(ctx.FromID, waifuId, offset+1)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
}
|
||||||
|
err = redisRpRep.SetChatTokens(ctx.FromID, waifuId, tokens)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Error(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var messagesMethodCount = []int{
|
var messagesMethodCount = []int{
|
||||||
@@ -709,7 +804,7 @@ func compressSettingStage1(ctx *laniakea.MsgContext, _ *laniakea.DatabaseContext
|
|||||||
|
|
||||||
ctx.EditCallback(strings.Join(out, "\n"), kb)
|
ctx.EditCallback(strings.Join(out, "\n"), kb)
|
||||||
}
|
}
|
||||||
func compressSettingStage2(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) {
|
func compressSettingStage2(ctx *laniakea.MsgContext, _ *laniakea.DatabaseContext) {
|
||||||
if len(ctx.Args) == 0 {
|
if len(ctx.Args) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,3 +17,14 @@ func Max(a, b int) int {
|
|||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func PopSlice[S any](s []S, index int) []S {
|
||||||
|
out := make([]S, 0)
|
||||||
|
for i, e := range s {
|
||||||
|
if i == index {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, e)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user