155 lines
4.3 KiB
Go
155 lines
4.3 KiB
Go
package psql
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
|
|
"git.nix13.pw/scuroneko/laniakea"
|
|
"github.com/vinovest/sqlx"
|
|
)
|
|
|
|
// RPPreset is one row of the rp_presets table (see GetPreset / GetAllPresets).
// ID, Name and Description carry no db tag, so sqlx matches them through its
// name mapper (lower-cased field name under sqlx's default mapper).
type RPPreset struct {
	ID                string
	Name, Description string
	PreHistory        string `db:"pre_history"`
	PostHistory       string `db:"post_history"`
}
|
|
// RPScenario is one row of the rp_scenarios table, addressed by an integer ID
// (see GetScenario / GetAllScenarios).
type RPScenario struct {
	ID                        int
	Name, Description, Prompt string
}
|
|
// RPSetting is one row of the rp_settings table, addressed by an integer ID
// (see GetSetting / GetAllSettings).
type RPSetting struct {
	ID                        int
	Name, Description, Prompt string
}
|
|
type RPUser struct {
|
|
UserID int64 `db:"user_id"`
|
|
UserPrompt string `db:"user_prompt"`
|
|
UsedTokens int64 `db:"used_tokens"`
|
|
|
|
SelectedPreset string `db:"selected_preset"`
|
|
Preset *RPPreset
|
|
SelectedModel string `db:"selected_model"`
|
|
Model *AIModel
|
|
|
|
CompressMethod string `db:"compress_method"`
|
|
CompressLimit int `db:"compress_limit"`
|
|
}
|
|
|
|
type RPRepository struct {
|
|
db *sqlx.DB
|
|
}
|
|
|
|
func NewRPRepository(db *laniakea.DatabaseContext) *RPRepository {
|
|
return &RPRepository{db.PostgresSQL}
|
|
}
|
|
|
|
func (rep *RPRepository) GetOrCreateUser(id int64) (*RPUser, error) {
|
|
user, err := rep.GetUser(id)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return rep.CreateUser(id)
|
|
}
|
|
return user, err
|
|
}
|
|
func (rep *RPRepository) CreateUser(id int64) (*RPUser, error) {
|
|
user := new(RPUser)
|
|
err := rep.db.Get(user, "INSERT INTO rp_users(user_id) VALUES ($1) RETURNING *;", id)
|
|
if err != nil {
|
|
return user, err
|
|
}
|
|
user.Preset, err = rep.GetPreset(user.SelectedPreset)
|
|
if err != nil {
|
|
return user, err
|
|
}
|
|
aiRep := newAiRepository(rep.db)
|
|
user.Model, err = aiRep.GetModel(user.SelectedModel)
|
|
return user, err
|
|
}
|
|
func (rep *RPRepository) GetUser(id int64) (*RPUser, error) {
|
|
user := new(RPUser)
|
|
err := rep.db.Get(user, "SELECT * FROM rp_users WHERE user_id=$1", id)
|
|
if err != nil {
|
|
return user, err
|
|
}
|
|
user.Preset, err = rep.GetPreset(user.SelectedPreset)
|
|
if err != nil {
|
|
return user, err
|
|
}
|
|
aiRep := newAiRepository(rep.db)
|
|
user.Model, err = aiRep.GetModel(user.SelectedModel)
|
|
return user, err
|
|
}
|
|
func (rep *RPRepository) UpdateUser(user *RPUser) error {
|
|
_, err := rep.db.NamedExec(
|
|
"UPDATE rp_users SET selected_preset=:selected_preset, used_tokens=:used_tokens, user_prompt=:user_prompt, selected_model=:selected_model WHERE user_id=:user_id;",
|
|
user,
|
|
)
|
|
return err
|
|
}
|
|
func (rep *RPRepository) UpdateUserPreset(user *RPUser, presetId string) (*RPPreset, error) {
|
|
preset, err := rep.GetPreset(presetId)
|
|
if err != nil {
|
|
return preset, err
|
|
}
|
|
_, err = rep.db.Exec("UPDATE rp_users SET selected_preset=$1 WHERE user_id=$2;", presetId, user.UserID)
|
|
return preset, err
|
|
}
|
|
func (rep *RPRepository) GetUserPreset(user *RPUser) (*RPPreset, error) {
|
|
preset, err := rep.GetPreset(user.SelectedPreset)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return rep.UpdateUserPreset(user, "soft")
|
|
}
|
|
return preset, err
|
|
}
|
|
|
|
func (rep *RPRepository) GetAllPresets() ([]*RPPreset, error) {
|
|
presets := make([]*RPPreset, 0)
|
|
err := rep.db.Select(&presets, "SELECT * FROM rp_presets ORDER BY id;")
|
|
return presets, err
|
|
}
|
|
func (rep *RPRepository) GetPreset(id string) (*RPPreset, error) {
|
|
preset := new(RPPreset)
|
|
err := rep.db.Get(preset, "SELECT * FROM rp_presets WHERE id=$1;", id)
|
|
return preset, err
|
|
}
|
|
|
|
func (rep *RPRepository) GetAllScenarios() ([]*RPScenario, error) {
|
|
scenarios := make([]*RPScenario, 0)
|
|
err := rep.db.Select(&scenarios, "SELECT * FROM rp_scenarios ORDER BY id;")
|
|
return scenarios, err
|
|
}
|
|
func (rep *RPRepository) GetScenario(id int) (*RPScenario, error) {
|
|
scenario := new(RPScenario)
|
|
err := rep.db.Get(scenario, "SELECT * FROM rp_scenarios WHERE id=$1;", id)
|
|
return scenario, err
|
|
}
|
|
|
|
func (rep *RPRepository) GetAllSettings() ([]*RPSetting, error) {
|
|
settings := make([]*RPSetting, 0)
|
|
err := rep.db.Select(&settings, "SELECT * FROM rp_settings ORDER BY id;")
|
|
return settings, err
|
|
}
|
|
func (rep *RPRepository) GetSetting(id int) (*RPSetting, error) {
|
|
setting := new(RPSetting)
|
|
err := rep.db.Get(setting, "SELECT * FROM rp_settings WHERE id=$1;", id)
|
|
return setting, err
|
|
}
|
|
|
|
func (rep *RPRepository) UpdateUserCompressSettings(user *RPUser) (*RPUser, error) {
|
|
query, args, err := sqlx.In(
|
|
"UPDATE rp_users SET compress_method=?, compress_limit=? WHERE user_id=?;",
|
|
user.CompressMethod, user.CompressLimit, user.UserID,
|
|
)
|
|
if err != nil {
|
|
return user, err
|
|
}
|
|
query = rep.db.Rebind(query)
|
|
_, err = rep.db.Exec(query, args...)
|
|
return user, err
|
|
}
|