Files
YaeMikoBot/utils/ai/openai.go
2026-01-19 21:56:41 +03:00

184 lines
4.7 KiB
Go

package ai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"kurumibot/laniakea"
"net/http"
"net/url"
"os"
"time"
)
// OpenAIResponse is the top-level payload returned by the
// /v1/chat/completions endpoint.
type OpenAIResponse struct {
ID string `json:"id"`
Object string `json:"object"`
// Created is a Unix timestamp (seconds) of when the completion was made.
Created int64 `json:"created"`
Model string `json:"model"`
// Choices holds the generated completions; callers typically read Choices[0].
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
ServiceTier string `json:"service_tier"`
}
// Choice is a single generated completion within an OpenAIResponse.
type Choice struct {
Index int64 `json:"index"`
Message Message `json:"message"`
// Logprobs is left untyped; the shape is not consumed anywhere in this file.
Logprobs interface{} `json:"logprobs"`
// FinishReason e.g. "stop", "length" — why generation ended.
FinishReason string `json:"finish_reason"`
}
// Message is one chat turn, used both in requests (via
// CreateCompletionReq.Messages) and in responses (via Choice.Message).
type Message struct {
// Role is the speaker: "system", "user", or "assistant".
Role string `json:"role"`
Content string `json:"content"`
// Refusal/Annotations are response-only fields; left untyped as this
// file never inspects their contents.
Refusal interface{} `json:"refusal"`
Annotations []interface{} `json:"annotations"`
}
// Usage reports token accounting for a completion.
type Usage struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"`
CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
}
// CompletionTokensDetails breaks down Usage.CompletionTokens by category.
type CompletionTokensDetails struct {
ReasoningTokens int64 `json:"reasoning_tokens"`
AudioTokens int64 `json:"audio_tokens"`
AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
}
// PromptTokensDetails breaks down Usage.PromptTokens by category.
type PromptTokensDetails struct {
CachedTokens int64 `json:"cached_tokens"`
AudioTokens int64 `json:"audio_tokens"`
}
// OpenAIAPI is a minimal client for an OpenAI-compatible chat-completions
// endpoint. Construct it with NewOpenAIAPI; the zero value has no logger
// or HTTP client and is not usable.
type OpenAIAPI struct {
// Token is the bearer token; when empty, no Authorization header is sent.
Token string
// Model overrides the Model field of every request sent via CreateCompletion.
Model string
// BaseURL is the API root, e.g. "https://api.openai.com" (no trailing slash).
BaseURL string
Logger *laniakea.Logger
// client carries the proxy and timeout configured in NewOpenAIAPI.
client *http.Client
}
// NewOpenAIAPI builds a client for an OpenAI-compatible endpoint rooted at
// baseURL. token may be empty (no Authorization header will be sent); model
// is stamped onto every completion request. If the HTTPS_PROXY environment
// variable is set to a valid URL, requests are routed through that proxy.
func NewOpenAIAPI(baseURL, token, model string) *OpenAIAPI {
	logger := laniakea.CreateLogger()
	logger = logger.Prefix("AI").Level(laniakea.DEBUG)

	transport := &http.Transport{}
	// Only install a proxy when HTTPS_PROXY is actually set: url.Parse("")
	// succeeds and returns an empty URL, so parsing unconditionally would
	// silently configure a broken (empty) proxy on the transport.
	if raw := os.Getenv("HTTPS_PROXY"); raw != "" {
		proxy, err := url.Parse(raw)
		if err != nil {
			// Log and fall back to a direct connection rather than
			// shipping a half-parsed proxy URL.
			logger.Error(err)
		} else {
			transport.Proxy = http.ProxyURL(proxy)
		}
	}

	client := &http.Client{
		Timeout:   15 * time.Second,
		Transport: transport,
	}
	return &OpenAIAPI{
		Token:   token,
		Model:   model,
		BaseURL: baseURL,
		Logger:  logger,
		client:  client,
	}
}
// CreateCompletionReq is the request body for /v1/chat/completions.
// Zero-valued optional fields are omitted from the JSON payload.
type CreateCompletionReq struct {
	Model    string    `json:"model"`
	Messages []Message `json:"messages"`
	// Verbosity is a non-standard extension accepted by some providers.
	Verbosity        string  `json:"verbosity,omitempty"`
	Temperature      float32 `json:"temperature,omitempty"`
	PresencePenalty  int     `json:"presence_penalty,omitempty"`
	FrequencyPenalty int     `json:"frequency_penalty,omitempty"`
	TopP             int     `json:"top_p,omitempty"`
	// Fixed tag typo: the API parameter is "max_completion_tokens";
	// the previous "max_completition_tokens" tag was silently ignored
	// by the server, so the limit never took effect.
	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
// CreateCompletion POSTs the request to {BaseURL}/v1/chat/completions and
// decodes the JSON response. The request's Model field is always overwritten
// with o.Model. Returns the decoded response, or an error from marshaling,
// transport, reading, or unmarshaling. Note: non-2xx HTTP statuses are not
// treated as errors; the body is decoded as-is.
func (o *OpenAIAPI) CreateCompletion(request CreateCompletionReq) (*OpenAIResponse, error) {
	u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
	request.Model = o.Model
	data, err := json.Marshal(request)
	// Check the marshal error before logging/using data.
	if err != nil {
		return nil, err
	}
	o.Logger.Debug("REQ", u, string(data))
	req, err := http.NewRequest("POST", u, bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	if o.Token != "" {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token))
	}
	// Use the configured client (proxy + 15s timeout) rather than
	// http.DefaultClient, which has neither and ignores NewOpenAIAPI's setup.
	resp, err := o.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	o.Logger.Debug("RES", u, string(body))
	response := new(OpenAIResponse)
	err = json.Unmarshal(body, response)
	return response, err
}
// CompressChat asks the model to summarize history by appending
// CompressPrompt as a final user message. On HTTP 504 or 400 it waits 5s
// and retries the request exactly once. Note: non-2xx statuses are still
// not errors; the (possibly error) body is decoded as an OpenAIResponse.
func (o *OpenAIAPI) CompressChat(history []Message) (*OpenAIResponse, error) {
	// Build the message list in a fresh slice: append(history, ...) could
	// write into spare capacity of the caller's backing array.
	msgs := make([]Message, 0, len(history)+1)
	msgs = append(msgs, history...)
	msgs = append(msgs, Message{
		Role:    "user",
		Content: CompressPrompt,
	})
	request := CreateCompletionReq{
		Model:       o.Model,
		Messages:    msgs,
		Verbosity:   "low",
		Temperature: 1.0,
	}
	u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
	data, err := json.Marshal(request)
	if err != nil {
		return nil, err
	}
	o.Logger.Debug("COMPRESS REQ", u, string(data))

	// send builds a fresh *http.Request per attempt: a request body reader
	// is consumed by Do, so re-Doing the same request (as the old code did
	// on retry) would send an empty body.
	send := func() (*http.Response, error) {
		req, err := http.NewRequest("POST", u, bytes.NewReader(data))
		if err != nil {
			return nil, err
		}
		req.Header.Set("Content-Type", "application/json")
		if o.Token != "" {
			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token))
		}
		// Use the configured client (proxy + timeout), not http.DefaultClient.
		return o.client.Do(req)
	}

	resp, err := send()
	if err != nil {
		return nil, err
	}
	if resp.StatusCode == http.StatusGatewayTimeout || resp.StatusCode == http.StatusBadRequest {
		// Drain and close the failed response so the connection can be
		// reused, then retry once after a short backoff.
		io.Copy(io.Discard, resp.Body)
		resp.Body.Close()
		time.Sleep(5 * time.Second)
		resp, err = send()
		if err != nil {
			return nil, err
		}
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	o.Logger.Debug("COMPRESS RES", u, string(body))
	response := new(OpenAIResponse)
	err = json.Unmarshal(body, response)
	return response, err
}