diff --git a/database/red/rp_chats.go b/database/red/rp_chats.go index 0fb64e9..1525e83 100644 --- a/database/red/rp_chats.go +++ b/database/red/rp_chats.go @@ -3,6 +3,7 @@ package red import ( "context" "fmt" + "kurumibot/database/psql" "kurumibot/utils" "strings" @@ -15,10 +16,25 @@ var ctx = context.Background() type RPRepository struct { client *redis.Client + db *laniakea.DatabaseContext +} + +type RPChat struct { + ID uuid.UUID + UserID int + WaifuID int + Prompt string + Counter int + ChatTokens int64 + + SettingID int + Setting *psql.RPSetting + ScenariosIDs []int + Scenarios []psql.RPScenario } func NewRPRepository(db *laniakea.DatabaseContext) RPRepository { - return RPRepository{client: db.Redis} + return RPRepository{db.Redis, db} } func (rep RPRepository) SetSelectedWaifu(userId, waifuId int) error { @@ -55,6 +71,55 @@ func (rep RPRepository) GetOrCreateChatId(userId, waifuId int) (string, error) { err := rep.SetChatId(userId, waifuId, chatId) return chatId, err } +func (rep RPRepository) GetChat(userId, waifuId int) (RPChat, error) { + var chat RPChat + chatId, err := rep.GetOrCreateChatId(userId, waifuId) + if err != nil { + return chat, err + } + + chat.ID = uuid.MustParse(chatId) + chat.UserID = userId + chat.WaifuID = waifuId + chat.Prompt = rep.GetChatPrompt(userId, waifuId) + chat.Counter = rep.GetCounter(userId, waifuId) + chat.ChatTokens = int64(rep.GetChatTokens(userId, waifuId)) + + chat.SettingID = rep.GetChatSettingID(userId, waifuId) + psqlRep := psql.NewRPRepository(rep.db) + setting, err := psqlRep.GetSetting(chat.SettingID) + if err != nil { + return chat, err + } + chat.Setting = &setting + + chat.ScenariosIDs = rep.GetChatScenariosIDs(userId, waifuId) + chat.Scenarios = make([]psql.RPScenario, len(chat.ScenariosIDs)) + for i, id := range chat.ScenariosIDs { + chat.Scenarios[i], err = psqlRep.GetScenario(id) + if err != nil { + return chat, err + } + } + return chat, nil +} + +func (rep RPRepository) SaveChat(chat RPChat) error { + chatId := chat.ID.String() + waifuId := chat.WaifuID + userId := chat.UserID + var err error + + if err = rep.SetChatPrompt(userId, waifuId, chat.Prompt); err != nil { + return err + } + if err = rep.SetCounter(userId, waifuId, chat.Counter); err != nil { + return err + } + if err = rep.SetChatTokens(userId, waifuId, int(chat.ChatTokens)); err != nil { + return err + } +} func (rep RPRepository) SetChatPrompt(userId, waifuId int, prompt string) error { key := fmt.Sprintf("ai.chats.rp.%d.%d.prompt", userId, waifuId) diff --git a/go.mod b/go.mod index b556603..b14c8a4 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.11.2 - github.com/redis/go-redis/v9 v9.17.3 + github.com/redis/go-redis/v9 v9.18.0 github.com/shopspring/decimal v1.4.0 github.com/vinovest/sqlx v1.7.1 go.mongodb.org/mongo-driver/v2 v2.5.0 diff --git a/go.sum b/go.sum index c9c92f9..2a439b2 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4= github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/utils/utils.go b/utils/utils.go index e4f5d40..cafcf95 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,6 +3,7 @@ package utils import ( "fmt" "math/rand/v2" + "reflect" "strconv" ) @@ -54,3 +55,7 @@ func StringToInt(s string) int { func AnyToString[A any](a A) string { return fmt.Sprintf("%v", a) } + +func IsDirty[T any](a T, b T) bool { + return !reflect.DeepEqual(a, b) +}