218 lines
5.8 KiB
Go
218 lines
5.8 KiB
Go
package ai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"time"
|
|
|
|
"git.nix13.pw/scuroneko/slog"
|
|
)
|
|
|
|
// OpenAIResponse is the top-level payload returned by the chat
// completions endpoint (see CreateCompletion / CompressChat).
type OpenAIResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"` // Unix timestamp of when the completion was created.
	Model   string `json:"model"`
	// Choices holds the generated completions; callers typically read
	// Choices[0].Message.Content.
	Choices     []Choice `json:"choices"`
	Usage       Usage    `json:"usage"`
	ServiceTier string   `json:"service_tier"`
}
|
// Choice is a single generated completion within an OpenAIResponse.
type Choice struct {
	Index   int64   `json:"index"`
	Message Message `json:"message"`
	// Logprobs is kept opaque; this client never inspects it.
	Logprobs     interface{} `json:"logprobs"`
	FinishReason string      `json:"finish_reason"` // e.g. why generation stopped; exact values come from the API.
}
|
// Message is a single chat message, used both in requests (conversation
// history plus the new user message) and in responses (the assistant
// reply inside a Choice).
type Message struct {
	Role    string `json:"role"` // e.g. "user" (set by this client) or "assistant".
	Content string `json:"content"`
	// Refusal and Annotations are response-only fields this client does
	// not interpret; they are kept opaque.
	Refusal     interface{}   `json:"refusal"`
	Annotations []interface{} `json:"annotations"`
}
|
// Usage reports token accounting for a completion request.
type Usage struct {
	PromptTokens            int64                   `json:"prompt_tokens"`
	CompletionTokens        int64                   `json:"completion_tokens"`
	TotalTokens             int64                   `json:"total_tokens"`
	PromptTokensDetails     PromptTokensDetails     `json:"prompt_tokens_details"`
	CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
}
|
// CompletionTokensDetails breaks down Usage.CompletionTokens by category.
type CompletionTokensDetails struct {
	ReasoningTokens          int64 `json:"reasoning_tokens"`
	AudioTokens              int64 `json:"audio_tokens"`
	AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
	RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
}
|
// PromptTokensDetails breaks down Usage.PromptTokens by category.
type PromptTokensDetails struct {
	CachedTokens int64 `json:"cached_tokens"`
	AudioTokens  int64 `json:"audio_tokens"`
}
|
// OpenAIAPI is a minimal client for an OpenAI-compatible chat completions
// API. Construct it with NewOpenAIAPI and release it with Close.
type OpenAIAPI struct {
	// Token is the bearer token sent in the Authorization header; an
	// empty Token sends no Authorization header.
	Token string
	// Model is the model identifier sent with every request.
	Model string
	// BaseURL is the API root; "/v1/chat/completions" is appended to it.
	BaseURL string
	Logger  *slog.Logger
	// client is the shared HTTP client (proxy-aware, 5 minute timeout).
	client *http.Client
}
|
|
|
func NewOpenAIAPI(baseURL, token, model string) *OpenAIAPI {
|
|
logger := slog.CreateLogger()
|
|
level := slog.FATAL
|
|
if os.Getenv("DEBUG") == "true" {
|
|
level = slog.DEBUG
|
|
}
|
|
logger = logger.Prefix("AI").Level(level)
|
|
proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
|
|
if err != nil {
|
|
logger.Error(err)
|
|
}
|
|
t := &http.Transport{}
|
|
if proxy.Host != "" {
|
|
t.Proxy = http.ProxyURL(proxy)
|
|
}
|
|
client := &http.Client{
|
|
Timeout: 5 * time.Minute,
|
|
Transport: t,
|
|
}
|
|
return &OpenAIAPI{
|
|
Token: token,
|
|
Model: model,
|
|
BaseURL: baseURL,
|
|
Logger: logger,
|
|
client: client,
|
|
}
|
|
}
|
|
// Close releases resources held by the client by closing its logger.
// It should be called once the client is no longer needed.
func (o *OpenAIAPI) Close() error {
	return o.Logger.Close()
}
|
|
|
type CreateCompletionReq struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Verbosity string `json:"verbosity,omitempty"`
|
|
Temperature float32 `json:"temperature,omitempty"`
|
|
PresencePenalty int `json:"presence_penalty,omitempty"`
|
|
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
|
|
TopP int `json:"top_p,omitempty"`
|
|
MaxCompletionTokens int `json:"max_completition_tokens,omitempty"`
|
|
}
|
|
|
|
// MaxRetriesErr is returned by DoRequest when a transient HTTP failure
// persists past the retry budget. (Go convention would name this
// ErrMaxRetries; the existing name is kept for caller compatibility.)
var MaxRetriesErr = errors.New("max retries exceeded")

// BadResponseErr is returned by DoRequest when the API keeps reporting a
// "bad_response_status_code" error payload past the retry budget.
var BadResponseErr = errors.New("bad_response_status_code")
|
|
|
func (o *OpenAIAPI) DoRequest(url string, params any, retries int) ([]byte, error) {
|
|
responseBody := make([]byte, 0)
|
|
data, err := json.Marshal(params)
|
|
if err != nil {
|
|
return responseBody, err
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
|
|
if err != nil {
|
|
return responseBody, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if o.Token != "" {
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token))
|
|
}
|
|
|
|
res, err := o.client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 {
|
|
o.Logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
|
|
if retries >= 3 {
|
|
return responseBody, MaxRetriesErr
|
|
}
|
|
time.Sleep(1 * time.Second)
|
|
return o.DoRequest(url, params, retries+1)
|
|
}
|
|
responseBody, err = io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return responseBody, err
|
|
}
|
|
|
|
tempData := make(map[string]any)
|
|
err = json.Unmarshal(responseBody, &tempData)
|
|
if err != nil {
|
|
return responseBody, err
|
|
}
|
|
// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
|
|
if errorData, ok := tempData["error"]; ok {
|
|
o.Logger.Error(errorData)
|
|
errorPayload := errorData.(map[string]interface{})
|
|
code := errorPayload["code"].(string)
|
|
if code == "bad_response_status_code" {
|
|
if retries >= 3 {
|
|
return responseBody, BadResponseErr
|
|
}
|
|
o.Logger.Debug("Retrying because of bad response status code")
|
|
return o.DoRequest(url, params, retries+1)
|
|
}
|
|
return nil, errors.New(code)
|
|
}
|
|
|
|
return responseBody, err
|
|
}
|
|
|
|
func (o *OpenAIAPI) CreateCompletion(history []Message, message string, temp float32) (*OpenAIResponse, error) {
|
|
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
|
|
request := CreateCompletionReq{
|
|
Model: o.Model,
|
|
Messages: append(history, Message{
|
|
Role: "user",
|
|
Content: message,
|
|
}),
|
|
Temperature: temp,
|
|
}
|
|
data, err := json.Marshal(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
o.Logger.Debug("REQ", u, string(data))
|
|
body, err := o.DoRequest(u, request, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
o.Logger.Debug("RES", u, string(body))
|
|
|
|
response := new(OpenAIResponse)
|
|
err = json.Unmarshal(body, response)
|
|
return response, err
|
|
}
|
|
|
|
func (o *OpenAIAPI) CompressChat(history []Message) (*OpenAIResponse, error) {
|
|
request := CreateCompletionReq{
|
|
Model: o.Model,
|
|
Messages: append(history, Message{
|
|
Role: "user",
|
|
Content: CompressPrompt,
|
|
}),
|
|
Temperature: 1.0,
|
|
}
|
|
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
|
|
data, err := json.Marshal(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
o.Logger.Debug("COMPRESS REQ", u, string(data))
|
|
body, err := o.DoRequest(u, request, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
o.Logger.Debug("COMPRESS RES", u, string(body))
|
|
|
|
response := new(OpenAIResponse)
|
|
err = json.Unmarshal(body, response)
|
|
return response, err
|
|
}
|