some changes
This commit is contained in:
@@ -8,6 +8,7 @@ RUN --mount=type=cache,target=/go/pkg/mod go mod download -x
|
||||
COPY ./database ./database
|
||||
COPY ./plugins ./plugins
|
||||
COPY ./utils ./utils
|
||||
COPY ./openai ./openai
|
||||
COPY ./main.go ./
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
--mount=type=cache,target=/go/pkg/mod \
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
go mod tidy
|
||||
docker build --build-arg GIT_COMMIT="DEV" --build-arg BUILD_TIME="DEV" -t git.nix13.pw/scuroneko/kurumibotgo:dev -f ./Dockerfile .
|
||||
docker push git.nix13.pw/scuroneko/kurumibotgo:dev
|
||||
docker build --build-arg GIT_COMMIT="DEV" --build-arg BUILD_TIME="DEV" -t git.nix13.pw/scuroneko/yaemikobot:dev -f ./Dockerfile .
|
||||
docker push git.nix13.pw/scuroneko/yaemikobot:dev
|
||||
4
go.mod
4
go.mod
@@ -4,7 +4,7 @@ go 1.26.0
|
||||
|
||||
require (
|
||||
git.nix13.pw/scuroneko/extypes v1.2.1
|
||||
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3
|
||||
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6
|
||||
git.nix13.pw/scuroneko/slog v1.0.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
@@ -15,7 +15,7 @@ require (
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0
|
||||
)
|
||||
|
||||
//replace git.nix13.pw/scuroneko/laniakea v1.0.0-beta.2 => ./laniakea
|
||||
//replace git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6 => ./laniakea
|
||||
//replace git.nix13.pw/scuroneko/extypes v1.2.1 => ../go-extypes
|
||||
//replace git.nix13.pw/scuroneko/slog v1.0.2 => ../slog
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -2,8 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
git.nix13.pw/scuroneko/extypes v1.2.1 h1:IYrOjnWKL2EAuJYtYNa+luB1vBe6paE8VY/YD+5/RpQ=
|
||||
git.nix13.pw/scuroneko/extypes v1.2.1/go.mod h1:uZVs8Yo3RrYAG9dMad6qR6lsYY67t+459D9c65QAYAw=
|
||||
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3 h1:FtEpeJ6Hi8/RGyT3m7Ysf2AIkwVLflc75HMSQzxPvnc=
|
||||
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.3/go.mod h1:DZgCqOazRzoa+f/GSNuKnTB2wIZ1eJD3cGf34Qya31U=
|
||||
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6 h1:4XgYsXgx68/UkXORku2245yLaT+NnKnipanKpUKsuoI=
|
||||
git.nix13.pw/scuroneko/laniakea v1.0.0-beta.6/go.mod h1:DZgCqOazRzoa+f/GSNuKnTB2wIZ1eJD3cGf34Qya31U=
|
||||
git.nix13.pw/scuroneko/slog v1.0.2 h1:vZyUROygxC2d5FJHUQM/30xFEHY1JT/aweDZXA4rm2g=
|
||||
git.nix13.pw/scuroneko/slog v1.0.2/go.mod h1:3Qm2wzkR5KjwOponMfG7TcGSDjmYaFqRAmLvSPTuWJI=
|
||||
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
|
||||
|
||||
204
openai/api.go
Normal file
204
openai/api.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.nix13.pw/scuroneko/slog"
|
||||
)
|
||||
|
||||
type API struct {
|
||||
token string
|
||||
model string
|
||||
baseUrl string
|
||||
logger *slog.Logger
|
||||
client *http.Client
|
||||
stream bool
|
||||
}
|
||||
|
||||
func NewOpenAIAPI(baseURL, token, model string) *API {
|
||||
logger := slog.CreateLogger()
|
||||
level := slog.FATAL
|
||||
if os.Getenv("DEBUG") == "true" {
|
||||
level = slog.DEBUG
|
||||
}
|
||||
logger = logger.Prefix("AI").Level(level)
|
||||
// FIXME Leak here
|
||||
//logger = logger.AddWriter(logger.CreateJsonStdoutWriter())
|
||||
proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
logger.Close()
|
||||
return nil
|
||||
}
|
||||
t := &http.Transport{}
|
||||
if proxy.Host != "" {
|
||||
t.Proxy = http.ProxyURL(proxy)
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: t,
|
||||
}
|
||||
return &API{
|
||||
token: token,
|
||||
model: model,
|
||||
baseUrl: baseURL,
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
func (api *API) Close() error {
|
||||
return api.logger.Close()
|
||||
}
|
||||
func (api *API) SetStream(stream bool) *API {
|
||||
api.stream = stream
|
||||
return api
|
||||
}
|
||||
func (api *API) GetModel() string { return api.model }
|
||||
func (api *API) GetBaseURL() string { return api.baseUrl }
|
||||
|
||||
type Request[P any] struct {
|
||||
params P
|
||||
method string
|
||||
}
|
||||
|
||||
func NewRequest[P any](method string, params P) *Request[P] {
|
||||
return &Request[P]{params, method}
|
||||
}
|
||||
|
||||
func (r *Request[P]) doWithContext(ctx context.Context, api *API) (io.ReadCloser, error) {
|
||||
data, err := json.Marshal(r.params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u := fmt.Sprintf("%s/v1/%s", api.baseUrl, r.method)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if api.token != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", api.token))
|
||||
}
|
||||
|
||||
res, err := api.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 {
|
||||
api.logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
|
||||
res.Body.Close()
|
||||
return nil, fmt.Errorf("[%d] %s", res.StatusCode, res.Status)
|
||||
}
|
||||
return res.Body, nil
|
||||
}
|
||||
func (r *Request[P]) do(api *API) (io.ReadCloser, error) {
|
||||
ctx := context.Background()
|
||||
return r.doWithContext(ctx, api)
|
||||
}
|
||||
|
||||
func (r *Request[P]) DoWithContext(ctx context.Context, api *API) (AIResponse, error) {
|
||||
var zero AIResponse
|
||||
body, err := r.doWithContext(ctx, api)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
defer body.Close()
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
err = api.handleAIError(data)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, &zero)
|
||||
return zero, err
|
||||
}
|
||||
func (r *Request[P]) Do(api *API) (AIResponse, error) {
|
||||
ctx := context.Background()
|
||||
return r.DoWithContext(ctx, api)
|
||||
}
|
||||
func (r *Request[P]) DoStreamWithContext(ctx context.Context, api *API) (iter.Seq2[AIResponse, error], error) {
|
||||
body, err := r.doWithContext(ctx, api)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(body)
|
||||
return func(yield func(AIResponse, error) bool) {
|
||||
defer body.Close()
|
||||
var zero AIResponse
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && err != io.EOF {
|
||||
yield(zero, err)
|
||||
return
|
||||
}
|
||||
if line == "" || line == "\n" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
line = line[len("data: "):]
|
||||
}
|
||||
line = strings.Trim(strings.Trim(strings.TrimSpace(line), "\r"), "\n")
|
||||
if strings.HasPrefix(line, "[DONE]") {
|
||||
return
|
||||
}
|
||||
|
||||
var resp AIResponse
|
||||
err = json.Unmarshal([]byte(line), &resp)
|
||||
if err != nil {
|
||||
yield(zero, fmt.Errorf("%v\n%s", err, line))
|
||||
return
|
||||
}
|
||||
if !yield(resp, nil) {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
func (r *Request[P]) DoStream(api *API) (iter.Seq2[AIResponse, error], error) {
|
||||
ctx := context.Background()
|
||||
return r.DoStreamWithContext(ctx, api)
|
||||
}
|
||||
|
||||
func (api *API) handleAIError(body []byte) error {
|
||||
var tempData any
|
||||
err := json.Unmarshal(body, &tempData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
|
||||
if eData, ok := tempData.(map[string]any); ok {
|
||||
if errorData, ok := eData["error"]; ok {
|
||||
if errorPayload, ok := errorData.(map[string]any); ok {
|
||||
code, ok := errorPayload["code"]
|
||||
if !ok {
|
||||
return errors.New("unknown error code")
|
||||
}
|
||||
return errors.New(fmt.Sprintf("%v", code))
|
||||
}
|
||||
return errors.New(string(body))
|
||||
}
|
||||
} else if eData, ok := tempData.(string); ok {
|
||||
return errors.New(eData)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
49
openai/completitions.go
Normal file
49
openai/completitions.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
)
|
||||
|
||||
var MaxRetriesErr = fmt.Errorf("max retries exceeded")
|
||||
var BadResponseErr = fmt.Errorf("bad_response_status_code")
|
||||
|
||||
type CreateCompletionReq struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Verbosity string `json:"verbosity,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
PresencePenalty int `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
|
||||
TopP int `json:"top_p,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completition_tokens,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
func (api *API) CreateCompletionStream(history []Message, message string, temp float32) (iter.Seq2[AIResponse, error], error) {
|
||||
params := CreateCompletionReq{
|
||||
Model: api.model,
|
||||
Messages: append(history, Message{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
}),
|
||||
Temperature: temp,
|
||||
Stream: true,
|
||||
}
|
||||
req := NewRequest("chat/completions", params)
|
||||
return req.DoStream(api)
|
||||
}
|
||||
|
||||
func (api *API) CreateCompletion(history []Message, message string, temp float32) (AIResponse, error) {
|
||||
params := CreateCompletionReq{
|
||||
Model: api.model,
|
||||
Messages: append(history, Message{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
}),
|
||||
Temperature: temp,
|
||||
Stream: false,
|
||||
}
|
||||
req := NewRequest("chat/completions", params)
|
||||
return req.Do(api)
|
||||
}
|
||||
37
openai/sse.go
Normal file
37
openai/sse.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"iter"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Server-sent event
|
||||
|
||||
func ReadSSE(r io.ReadCloser) iter.Seq[string] {
|
||||
reader := bufio.NewReader(r)
|
||||
return func(yield func(string) bool) {
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if line == "" || line == "\n" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
line = line[len("data: "):]
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.Trim(line, "\r")
|
||||
line = strings.Trim(line, "\n")
|
||||
if strings.HasPrefix(line, "[DONE]") {
|
||||
return
|
||||
}
|
||||
if !yield(line) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
41
openai/types.go
Normal file
41
openai/types.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package openai
|
||||
|
||||
type AIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
}
|
||||
type Choice struct {
|
||||
Index int64 `json:"index"`
|
||||
Message Message `json:"message"`
|
||||
Delta Message `json:"delta"`
|
||||
Logprobs interface{} `json:"logprobs"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Refusal interface{} `json:"refusal"`
|
||||
Annotations []interface{} `json:"annotations"`
|
||||
}
|
||||
type Usage struct {
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
|
||||
}
|
||||
type CompletionTokensDetails struct {
|
||||
ReasoningTokens int64 `json:"reasoning_tokens"`
|
||||
AudioTokens int64 `json:"audio_tokens"`
|
||||
AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
type PromptTokensDetails struct {
|
||||
CachedTokens int64 `json:"cached_tokens"`
|
||||
AudioTokens int64 `json:"audio_tokens"`
|
||||
}
|
||||
@@ -2,8 +2,10 @@ package plugins
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"ymgb/database"
|
||||
"ymgb/database/psql"
|
||||
|
||||
@@ -18,6 +20,7 @@ func RegisterAdmin() *laniakea.Plugin[database.Context] {
|
||||
p.AddCommand(laniakea.NewCommand(uploadPhoto, "uploadPhoto").SkipCommandAutoGen())
|
||||
p.AddCommand(laniakea.NewCommand(emojiId, "emojiId").SkipCommandAutoGen())
|
||||
p.AddCommand(laniakea.NewCommand(execSql, "sql").SkipCommandAutoGen())
|
||||
p.AddCommand(laniakea.NewCommand(test, "test").SkipCommandAutoGen())
|
||||
|
||||
p.AddMiddleware(AdminMiddleware())
|
||||
return p
|
||||
@@ -33,6 +36,17 @@ func AdminMiddleware() laniakea.Middleware[database.Context] {
|
||||
})
|
||||
return *m
|
||||
}
|
||||
func test(ctx *laniakea.MsgContext, _ *database.Context) {
|
||||
draft := ctx.NewDraft()
|
||||
for i := 0; i < 10; i++ {
|
||||
err := draft.Push(fmt.Sprintf("%d", i))
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func execSql(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
stmt := strings.Join(ctx.Args, " ")
|
||||
|
||||
122
plugins/ai.go
122
plugins/ai.go
@@ -1,10 +1,10 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"ymgb/database"
|
||||
"ymgb/database/mdb"
|
||||
"ymgb/database/red"
|
||||
"ymgb/openai"
|
||||
"ymgb/utils/ai"
|
||||
|
||||
"git.nix13.pw/scuroneko/laniakea"
|
||||
@@ -12,50 +12,96 @@ import (
|
||||
|
||||
func RegisterAi() *laniakea.Plugin[database.Context] {
|
||||
p := laniakea.NewPlugin[database.Context]("AI")
|
||||
p.AddCommand(p.NewCommand(gpt, "gpt").SkipCommandAutoGen())
|
||||
p.AddCommand(p.NewCommand(gptTest, "gpt").SkipCommandAutoGen())
|
||||
//p.AddCommand(p.NewCommand(gptTest, "gpt2").SkipCommandAutoGen())
|
||||
return p
|
||||
}
|
||||
|
||||
func gpt(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
func gptTest(ctx *laniakea.MsgContext, _ *database.Context) {
|
||||
q := strings.Join(ctx.Args, " ")
|
||||
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4")
|
||||
defer api.Close()
|
||||
|
||||
aiRedisRep := red.NewAiRepository(db)
|
||||
chatId, err := aiRedisRep.GetOrCreateChatId(ctx.FromID)
|
||||
m := ctx.Answer("Генерация запущена")
|
||||
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4")
|
||||
resp, err := api.CreateCompletionStream([]openai.Message{}, q, 1.0)
|
||||
if err != nil {
|
||||
m.Delete()
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
history, err := mdb.GetGptChatHistory(db, chatId)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
aiHistory := make([]ai.Message, len(history))
|
||||
for _, m := range history {
|
||||
aiHistory = append(aiHistory, ai.Message{
|
||||
Role: m.Role,
|
||||
Content: m.Message,
|
||||
})
|
||||
}
|
||||
|
||||
m := ctx.Answer("Генерация запущена...")
|
||||
res, err := api.CreateCompletion(aiHistory, q, 1.0)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
answer := res.Choices[0].Message.Content
|
||||
m.Delete()
|
||||
err = mdb.UpdateGptChatHistory(db, chatId, "user", q)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
}
|
||||
err = mdb.UpdateGptChatHistory(db, chatId, "assistant", answer)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
}
|
||||
draft := ctx.NewDraft()
|
||||
for r, err := range resp {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if m != nil {
|
||||
m.Delete()
|
||||
m = nil
|
||||
}
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Answer(answer)
|
||||
if len(r.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
content := r.Choices[0].Delta.Content
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
err = draft.Push(content)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
//draft.Flush()
|
||||
break
|
||||
}
|
||||
}
|
||||
err = draft.Flush()
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
//func gpt(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
// q := strings.Join(ctx.Args, " ")
|
||||
// api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", "anthropic/claude-sonnet-4")
|
||||
// defer api.Close()
|
||||
//
|
||||
// aiRedisRep := red.NewAiRepository(db)
|
||||
// chatId, err := aiRedisRep.GetOrCreateChatId(ctx.FromID)
|
||||
// if err != nil {
|
||||
// ctx.Error(err)
|
||||
// return
|
||||
// }
|
||||
// history, err := mdb.GetGptChatHistory(db, chatId)
|
||||
// if err != nil {
|
||||
// ctx.Error(err)
|
||||
// return
|
||||
// }
|
||||
// aiHistory := make([]openai.Message, len(history))
|
||||
// for _, m := range history {
|
||||
// aiHistory = append(aiHistory, openai.Message{
|
||||
// Role: m.Role,
|
||||
// Content: m.Message,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// m := ctx.Answer("Генерация запущена...")
|
||||
// res, err := api.CreateCompletion(aiHistory, q, 1.0)
|
||||
// if err != nil {
|
||||
// ctx.Error(err)
|
||||
// return
|
||||
// }
|
||||
// answer := res.Choices[0].Message.Content
|
||||
// m.Delete()
|
||||
// err = mdb.UpdateGptChatHistory(db, chatId, "user", q)
|
||||
// if err != nil {
|
||||
// ctx.Error(err)
|
||||
// }
|
||||
// err = mdb.UpdateGptChatHistory(db, chatId, "assistant", answer)
|
||||
// if err != nil {
|
||||
// ctx.Error(err)
|
||||
// }
|
||||
//
|
||||
// ctx.Answer(answer)
|
||||
//}
|
||||
|
||||
@@ -16,7 +16,7 @@ func RegisterProxy() *laniakea.Plugin[database.Context] {
|
||||
|
||||
func getH2Link(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
api := utils.NewHysteria2API()
|
||||
url, err := api.GetConnectLink(1, "K1321xt90RUS")
|
||||
url, err := api.GetConnectLink(1, "")
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"ymgb/database/mdb"
|
||||
"ymgb/database/psql"
|
||||
"ymgb/database/red"
|
||||
"ymgb/openai"
|
||||
"ymgb/utils"
|
||||
"ymgb/utils/ai"
|
||||
|
||||
@@ -557,12 +558,12 @@ func rpUserPromptSet(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Answer("Описание пользователя было обновлено")
|
||||
}
|
||||
|
||||
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Message, error) {
|
||||
func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]openai.Message, error) {
|
||||
redRep := red.NewRPRepository(db)
|
||||
psqlRep := psql.NewRPRepository(db)
|
||||
waifuRep := psql.NewWaifuRepository(db)
|
||||
|
||||
messages := make([]ai.Message, 0)
|
||||
messages := make([]openai.Message, 0)
|
||||
waifuId := redRep.GetSelectedWaifu(ctx.FromID)
|
||||
chatId, err := redRep.GetOrCreateChatId(ctx.FromID, waifuId)
|
||||
if err != nil {
|
||||
@@ -590,7 +591,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
|
||||
//if err != nil {
|
||||
// return messages, err
|
||||
//}
|
||||
beforeHistory := ai.Message{
|
||||
beforeHistory := openai.Message{
|
||||
Role: "system",
|
||||
Content: fmt.Sprintf(
|
||||
"%s %s %s %s",
|
||||
@@ -603,7 +604,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
|
||||
userPrompt,
|
||||
),
|
||||
}
|
||||
afterHistory := ai.Message{
|
||||
afterHistory := openai.Message{
|
||||
Role: "system",
|
||||
Content: ai.FormatPrompt(preset.PostHistory, waifu.Name, ctx.From.FirstName),
|
||||
}
|
||||
@@ -615,7 +616,7 @@ func _getChatHistory(ctx *laniakea.MsgContext, db *database.Context) ([]ai.Messa
|
||||
|
||||
messages = append(messages, beforeHistory)
|
||||
for _, m := range history {
|
||||
messages = append(messages, ai.Message{Role: m.Role, Content: m.Message})
|
||||
messages = append(messages, openai.Message{Role: m.Role, Content: m.Message})
|
||||
}
|
||||
messages = append(messages, afterHistory)
|
||||
return messages, nil
|
||||
@@ -653,16 +654,41 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
kb.AddCallbackButtonStyle("Отменить", laniakea.ButtonStyleDanger, "rp.cancel")
|
||||
m := ctx.Keyboard("Генерация запущена...", kb)
|
||||
ctx.SendAction(tgapi.ChatActionTyping)
|
||||
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
defer api.Close()
|
||||
res, err := api.CreateCompletion(messages, userMessage, 0.5)
|
||||
res, err := api.CreateCompletionStream(messages, userMessage, 0.5)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
if len(res.Choices) == 0 {
|
||||
m.Edit("Не удалось сгенерировать ответ. Попробуйте снова позже")
|
||||
return
|
||||
|
||||
answerContent := ""
|
||||
draft := ctx.NewDraft()
|
||||
for r, err := range res {
|
||||
if m != nil {
|
||||
m.Delete()
|
||||
m = nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(r.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
content := r.Choices[0].Delta.Content
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
answerContent += content
|
||||
err = draft.Push(content)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
//draft.Flush()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
counter := redisRpRep.GetCounter(ctx.FromID, waifuId)
|
||||
@@ -671,9 +697,7 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
agentAnswer := res.Choices[0].Message
|
||||
answerContent := strings.TrimSpace(agentAnswer.Content)
|
||||
err = mdb.UpdateRPChatHistory(db, chatId, agentAnswer.Role, answerContent, counter+2)
|
||||
err = mdb.UpdateRPChatHistory(db, chatId, "assistant", answerContent, counter+2)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
}
|
||||
@@ -695,7 +719,6 @@ func generate(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Error(err)
|
||||
}
|
||||
|
||||
m.Delete()
|
||||
kb = laniakea.NewInlineKeyboard(1)
|
||||
kb.AddCallbackButtonStyle("🔄 Перегенерировать 🔄", laniakea.ButtonStyleSuccess, "rp.regenerate", counter+2)
|
||||
//kb.AddButton(laniakea.NewInlineKbButton("Тест").SetStyle(laniakea.ButtonStyleSuccess).SetIconCustomEmojiId("5375155835846534814").SetCallbackData("rp.test"))
|
||||
@@ -729,7 +752,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
ctx.Error(err)
|
||||
return
|
||||
}
|
||||
var messages extypes.Slice[ai.Message]
|
||||
var messages extypes.Slice[openai.Message]
|
||||
messages, err = _getChatHistory(ctx, db)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
@@ -748,7 +771,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
}
|
||||
|
||||
//if messages.Len() == count {
|
||||
// ctx.Bot.Logger().Errorln("len(messages) == count")
|
||||
// ctx.Bot.logger().Errorln("len(messages) == count")
|
||||
// return
|
||||
//}
|
||||
|
||||
@@ -766,7 +789,7 @@ func regenerateResponse(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", rpUser.Model.Key)
|
||||
defer api.Close()
|
||||
|
||||
messages = messages.Pop(count - 2)
|
||||
@@ -818,9 +841,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
messages := make([]ai.Message, 0)
|
||||
messages := make([]openai.Message, 0)
|
||||
for _, m := range history {
|
||||
messages = append(messages, ai.Message{
|
||||
messages = append(messages, openai.Message{
|
||||
Role: m.Role,
|
||||
Content: m.Message,
|
||||
})
|
||||
@@ -833,9 +856,9 @@ func _compress(ctx *laniakea.MsgContext, db *database.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
api := ai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
|
||||
api := openai.NewOpenAIAPI(ai.GPTBaseUrl, "", user.Model.Key)
|
||||
defer api.Close()
|
||||
res, err := api.CompressChat(messages)
|
||||
res, err := ai.CompressChat(api, messages)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
return
|
||||
|
||||
@@ -1,232 +1,11 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.nix13.pw/scuroneko/slog"
|
||||
"ymgb/openai"
|
||||
)
|
||||
|
||||
type OpenAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
}
|
||||
type Choice struct {
|
||||
Index int64 `json:"index"`
|
||||
Message Message `json:"message"`
|
||||
Logprobs interface{} `json:"logprobs"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Refusal interface{} `json:"refusal"`
|
||||
Annotations []interface{} `json:"annotations"`
|
||||
}
|
||||
type Usage struct {
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
|
||||
}
|
||||
type CompletionTokensDetails struct {
|
||||
ReasoningTokens int64 `json:"reasoning_tokens"`
|
||||
AudioTokens int64 `json:"audio_tokens"`
|
||||
AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
type PromptTokensDetails struct {
|
||||
CachedTokens int64 `json:"cached_tokens"`
|
||||
AudioTokens int64 `json:"audio_tokens"`
|
||||
}
|
||||
type OpenAIAPI struct {
|
||||
Token string
|
||||
Model string
|
||||
BaseURL string
|
||||
Logger *slog.Logger
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewOpenAIAPI(baseURL, token, model string) *OpenAIAPI {
|
||||
logger := slog.CreateLogger()
|
||||
level := slog.FATAL
|
||||
if os.Getenv("DEBUG") == "true" {
|
||||
level = slog.DEBUG
|
||||
}
|
||||
logger = logger.Prefix("AI").Level(level)
|
||||
// FIXME Leak here
|
||||
//logger = logger.AddWriter(logger.CreateJsonStdoutWriter())
|
||||
proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
logger.Close()
|
||||
}
|
||||
t := &http.Transport{}
|
||||
if proxy.Host != "" {
|
||||
t.Proxy = http.ProxyURL(proxy)
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: t,
|
||||
}
|
||||
return &OpenAIAPI{
|
||||
Token: token,
|
||||
Model: model,
|
||||
BaseURL: baseURL,
|
||||
Logger: logger,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
func (o *OpenAIAPI) Close() error {
|
||||
return o.Logger.Close()
|
||||
}
|
||||
|
||||
type CreateCompletionReq struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Verbosity string `json:"verbosity,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
PresencePenalty int `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
|
||||
TopP int `json:"top_p,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completition_tokens,omitempty"`
|
||||
}
|
||||
|
||||
var MaxRetriesErr = fmt.Errorf("max retries exceeded")
|
||||
var BadResponseErr = fmt.Errorf("bad_response_status_code")
|
||||
|
||||
func (o *OpenAIAPI) DoRequest(url string, params any, retries int) ([]byte, error) {
|
||||
responseBody := make([]byte, 0)
|
||||
data, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
log.Println("json marshal failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
log.Println("create request failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if o.Token != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token))
|
||||
}
|
||||
|
||||
res, err := o.client.Do(req)
|
||||
if err != nil {
|
||||
log.Println("do request failed:", err)
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 {
|
||||
o.Logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
|
||||
if retries >= 3 {
|
||||
return responseBody, MaxRetriesErr
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
return o.DoRequest(url, params, retries+1)
|
||||
}
|
||||
responseBody, err = io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
log.Println("read response failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
|
||||
var tempData any
|
||||
err = json.Unmarshal(responseBody, &tempData)
|
||||
if err != nil {
|
||||
log.Println("json unmarshal failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
|
||||
if eData, ok := tempData.(map[string]any); ok {
|
||||
if errorData, ok := eData["error"]; ok {
|
||||
if errorPayload, ok := errorData.(map[string]any); ok {
|
||||
code := errorPayload["code"].(string)
|
||||
if code == "bad_response_status_code" {
|
||||
if retries >= 3 {
|
||||
return responseBody, BadResponseErr
|
||||
}
|
||||
o.Logger.Warnln("Retrying because of bad response status code")
|
||||
return o.DoRequest(url, params, retries+1)
|
||||
}
|
||||
return nil, errors.New(code)
|
||||
}
|
||||
o.Logger.Errorln("Unknown error", errorData)
|
||||
return nil, errors.New(string(responseBody))
|
||||
}
|
||||
} else if eData, ok := tempData.(string); ok {
|
||||
return responseBody, errors.New(eData)
|
||||
}
|
||||
|
||||
return responseBody, err
|
||||
}
|
||||
|
||||
func (o *OpenAIAPI) CreateCompletion(history []Message, message string, temp float32) (*OpenAIResponse, error) {
|
||||
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
|
||||
request := CreateCompletionReq{
|
||||
Model: o.Model,
|
||||
Messages: append(history, Message{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
}),
|
||||
Temperature: temp,
|
||||
}
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.Logger.Debug("REQ", u, string(data))
|
||||
body, err := o.DoRequest(u, request, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Logger.Debug("RES", u, string(body))
|
||||
|
||||
response := new(OpenAIResponse)
|
||||
err = json.Unmarshal(body, response)
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (o *OpenAIAPI) CompressChat(history []Message) (*OpenAIResponse, error) {
|
||||
request := CreateCompletionReq{
|
||||
Model: o.Model,
|
||||
Messages: append(history, Message{
|
||||
Role: "user",
|
||||
Content: CompressPrompt,
|
||||
}),
|
||||
Temperature: 1.0,
|
||||
}
|
||||
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.Logger.Debug("COMPRESS REQ", u, string(data))
|
||||
body, err := o.DoRequest(u, request, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Logger.Debug("COMPRESS RES", u, string(body))
|
||||
|
||||
response := new(OpenAIResponse)
|
||||
err = json.Unmarshal(body, response)
|
||||
return response, err
|
||||
//https://github.com/sashabaranov/go-openai
|
||||
|
||||
func CompressChat(api *openai.API, history []openai.Message) (openai.AIResponse, error) {
|
||||
return api.CreateCompletion(history, CompressPrompt, 0.0)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user