diff --git a/database/mdb/rp_chats.go b/database/mdb/rp_chats.go index bf326d6..4770b27 100644 --- a/database/mdb/rp_chats.go +++ b/database/mdb/rp_chats.go @@ -11,16 +11,17 @@ import ( type RPChatMessage struct { Id bson.ObjectID `bson:"_id"` - ChatID string `bson:"chat_id"` + ChatID string `bson:"chatId"` Role string `bson:"role"` Message string `bson:"message"` + Index int `bson:"index"` } func GetChatHistory(db *laniakea.DatabaseContext, chatId string) ([]*RPChatMessage, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() col := database.GetMongoCollection(db, "rp_chat_messages") - cursor, err := col.Find(ctx, bson.M{"chat_id": chatId}) + cursor, err := col.Find(ctx, bson.M{"chatId": chatId}) if err != nil { return nil, err } @@ -28,7 +29,7 @@ func GetChatHistory(db *laniakea.DatabaseContext, chatId string) ([]*RPChatMessa err = cursor.All(ctx, &result) return result, err } -func UpdateChatHistory(db *laniakea.DatabaseContext, chatId, role, message string) error { +func UpdateChatHistory(db *laniakea.DatabaseContext, chatId, role, message string, index int) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() col := database.GetMongoCollection(db, "rp_chat_messages") @@ -36,7 +37,7 @@ func UpdateChatHistory(db *laniakea.DatabaseContext, chatId, role, message strin bson.NewObjectID(), chatId, role, - message, + message, index, }) return err } @@ -44,12 +45,12 @@ func GetChatHistorySize(db *laniakea.DatabaseContext, chatId string) (int64, err ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() col := database.GetMongoCollection(db, "rp_chat_messages") - return col.CountDocuments(ctx, bson.M{"chat_id": chatId}) + return col.CountDocuments(ctx, bson.M{"chatId": chatId}) } func DeleteChatEntry(db *laniakea.DatabaseContext, entry *RPChatMessage) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() col := database.GetMongoCollection(db, "rp_chat_messages") - _, err := col.DeleteOne(ctx, bson.M{"chat_id": entry.ChatID}) + _, err := col.DeleteOne(ctx, bson.M{"chatId": entry.ChatID}) return err } diff --git a/plugins/rp.go b/plugins/rp.go index a0180f0..9878e8e 100644 --- a/plugins/rp.go +++ b/plugins/rp.go @@ -605,14 +605,15 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { return } - err = mdb.UpdateChatHistory(db, chatId, "user", userMessage) + counter := redisRpRep.GetCounter(ctx.FromID, waifuId) + err = mdb.UpdateChatHistory(db, chatId, "user", userMessage, counter+1) if err != nil { ctx.Error(err) return } agentAnswer := res.Choices[0].Message answerContent := strings.TrimSpace(agentAnswer.Content) - err = mdb.UpdateChatHistory(db, chatId, agentAnswer.Role, answerContent) + err = mdb.UpdateChatHistory(db, chatId, agentAnswer.Role, answerContent, counter+2) if err != nil { ctx.Error(err) } @@ -629,7 +630,6 @@ func generate(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { if err != nil { ctx.Error(err) } - counter := redisRpRep.GetCounter(ctx.FromID, waifuId) err = redisRpRep.SetCounter(ctx.FromID, waifuId, counter+2) if err != nil { ctx.Error(err) @@ -786,14 +786,14 @@ func _compress(ctx *laniakea.MsgContext, db *laniakea.DatabaseContext) { return } - err = mdb.UpdateChatHistory(db, chatId, "assistant", compressedHistory) + err = mdb.UpdateChatHistory(db, chatId, "assistant", compressedHistory, 0) if err != nil { ctx.Error(err) } offset := utils.Min(len(history), 20) - for _, m := range history[len(history)-offset:] { + for i, m := range history[len(history)-offset:] { tokens += len(m.Message) - err = mdb.UpdateChatHistory(db, chatId, m.Role, m.Message) + err = mdb.UpdateChatHistory(db, chatId, m.Role, m.Message, i+1) if err != nil { ctx.Error(err) }