// Package ai contains a small HTTP client for OpenAI-style chat completion
// endpoints.
package ai

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"time"

	"git.nix13.pw/scuroneko/laniakea"
)

// OpenAIResponse is the decoded JSON body of a chat completion response.
type OpenAIResponse struct {
	ID          string   `json:"id"`
	Object      string   `json:"object"`
	Created     int64    `json:"created"`
	Model       string   `json:"model"`
	Choices     []Choice `json:"choices"`
	Usage       Usage    `json:"usage"`
	ServiceTier string   `json:"service_tier"`
}

// Choice is a single entry of OpenAIResponse.Choices.
type Choice struct {
	Index        int64   `json:"index"`
	Message      Message `json:"message"`
	Logprobs     any     `json:"logprobs"`
	FinishReason string  `json:"finish_reason"`
}

// Message is one chat message. It is used both for the request history and
// for the message carried by a response Choice.
type Message struct {
	Role        string `json:"role"`
	Content     string `json:"content"`
	Refusal     any    `json:"refusal"`
	Annotations []any  `json:"annotations"`
}

// Usage holds the token counters reported in a response.
type Usage struct {
	PromptTokens            int64                   `json:"prompt_tokens"`
	CompletionTokens        int64                   `json:"completion_tokens"`
	TotalTokens             int64                   `json:"total_tokens"`
	PromptTokensDetails     PromptTokensDetails     `json:"prompt_tokens_details"`
	CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
}

// CompletionTokensDetails is the breakdown found under
// "completion_tokens_details" in Usage.
type CompletionTokensDetails struct {
	ReasoningTokens          int64 `json:"reasoning_tokens"`
	AudioTokens              int64 `json:"audio_tokens"`
	AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
	RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
}

// PromptTokensDetails is the breakdown found under "prompt_tokens_details"
// in Usage.
type PromptTokensDetails struct {
	CachedTokens int64 `json:"cached_tokens"`
	AudioTokens  int64 `json:"audio_tokens"`
}

// OpenAIAPI is a client bound to one base URL, bearer token and model.
// Construct it with NewOpenAIAPI so the logger and HTTP client are set.
type OpenAIAPI struct {
	Token   string
	Model   string
	BaseURL string
	Logger  *laniakea.Logger
	client  *http.Client
}

// NewOpenAIAPI builds a client for the given base URL, token and model.
//
// The logger is prefixed with "AI" and runs at FATAL level unless the DEBUG
// environment variable equals "true", in which case it runs at DEBUG level.
// When HTTPS_PROXY holds a URL with a host, requests are sent through that
// proxy. The HTTP client uses a five-minute timeout.
func NewOpenAIAPI(baseURL, token, model string) *OpenAIAPI {
	level := laniakea.FATAL
	if os.Getenv("DEBUG") == "true" {
		level = laniakea.DEBUG
	}
	logger := laniakea.CreateLogger().Prefix("AI").Level(level)

	transport := &http.Transport{}
	proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
	switch {
	case err != nil:
		// url.Parse returns a nil URL on failure, so the proxy must not be
		// inspected here; log and continue without a proxy.
		logger.Error(err)
	case proxy.Host != "":
		transport.Proxy = http.ProxyURL(proxy)
	}

	return &OpenAIAPI{
		Token:   token,
		Model:   model,
		BaseURL: baseURL,
		Logger:  logger,
		client: &http.Client{
			Timeout:   5 * time.Minute,
			Transport: transport,
		},
	}
}

// CreateCompletionReq is the JSON payload posted to /v1/chat/completions.
// Fields tagged omitempty are left out of the payload when they hold their
// zero value.
//
// NOTE(review): PresencePenalty, FrequencyPenalty and TopP are declared as
// int; confirm against the target API whether fractional values are needed.
type CreateCompletionReq struct {
	Model               string    `json:"model"`
	Messages            []Message `json:"messages"`
	Verbosity           string    `json:"verbosity,omitempty"`
	Temperature         float32   `json:"temperature,omitempty"`
	PresencePenalty     int       `json:"presence_penalty,omitempty"`
	FrequencyPenalty    int       `json:"frequency_penalty,omitempty"`
	TopP                int       `json:"top_p,omitempty"`
	MaxCompletionTokens int       `json:"max_completion_tokens,omitempty"`
}

const (
	// openAIMaxRetries is the retry counter value at which DoRequest gives up.
	openAIMaxRetries = 3
	// openAIRetryDelay is the pause before retrying after a retryable HTTP
	// status code.
	openAIRetryDelay = time.Second
	// openAIBadResponseCode is the API error code that triggers a retry.
	openAIBadResponseCode = "bad_response_status_code"
)

var (
	// MaxRetriesErr is returned by DoRequest when a retryable HTTP status is
	// still being received once the retry limit has been reached.
	MaxRetriesErr = errors.New("max retries exceeded")
	// BadResponseErr is returned by DoRequest when the API keeps answering
	// with the "bad_response_status_code" error once the retry limit has been
	// reached.
	BadResponseErr = errors.New("bad_response_status_code")
)

// DoRequest posts params as JSON to url and returns the raw response body.
//
// retries is the starting value of the retry counter; callers normally pass
// 0. The request is retried when the server answers with HTTP 504, 400 or
// 502 (after a one-second pause) or with an error object whose code is
// "bad_response_status_code" (immediately). Once the counter reaches 3 the
// call fails with MaxRetriesErr or BadResponseErr respectively; in the
// BadResponseErr case the last response body is returned alongside the error.
// Any other error object in the response is returned as an error whose
// message is the API error code.
func (o *OpenAIAPI) DoRequest(url string, params any, retries int) ([]byte, error) {
	payload, err := json.Marshal(params)
	if err != nil {
		return nil, err
	}

	// Each attempt runs in its own call so that its response body is closed
	// before sleeping or retrying.
	for attempt := retries; ; attempt++ {
		body, wait, retry, err := o.attemptRequest(url, payload, attempt)
		if !retry {
			return body, err
		}
		time.Sleep(wait)
	}
}

// attemptRequest performs one POST of payload to endpoint. When retry is
// true the caller should wait for the returned duration and try again with
// the counter incremented; otherwise body and err are the final result.
func (o *OpenAIAPI) attemptRequest(endpoint string, payload []byte, attempt int) (body []byte, wait time.Duration, retry bool, err error) {
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, 0, false, err
	}
	req.Header.Set("Content-Type", "application/json")
	if o.Token != "" {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token))
	}

	res, err := o.client.Do(req)
	if err != nil {
		return nil, 0, false, err
	}
	defer res.Body.Close()

	switch res.StatusCode {
	case http.StatusGatewayTimeout, http.StatusBadRequest, http.StatusBadGateway:
		o.Logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
		if attempt >= openAIMaxRetries {
			return nil, 0, false, MaxRetriesErr
		}
		// Drain the unread body so the transport can reuse the connection.
		_, _ = io.Copy(io.Discard, res.Body)
		return nil, openAIRetryDelay, true, nil
	}

	body, err = io.ReadAll(res.Body)
	if err != nil {
		return body, 0, false, err
	}

	envelope := make(map[string]any)
	if err = json.Unmarshal(body, &envelope); err != nil {
		return body, 0, false, err
	}

	// Example error payload:
	// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
	errorData, hasError := envelope["error"]
	if !hasError || errorData == nil {
		return body, 0, false, nil
	}
	o.Logger.Error(errorData)

	code := apiErrorCode(errorData)
	if code != openAIBadResponseCode {
		return nil, 0, false, errors.New(code)
	}
	if attempt >= openAIMaxRetries {
		return body, 0, false, BadResponseErr
	}
	o.Logger.Debug("Retrying because of bad response status code")
	return nil, 0, true, nil
}

// apiErrorCode extracts a short identifier from the decoded "error" value of
// a response. It never panics on unexpected shapes: a string value is used
// as is, an object yields its first non-empty string among "code", "type"
// and "message", and anything else yields "unknown_error".
func apiErrorCode(errorData any) string {
	switch v := errorData.(type) {
	case string:
		if v != "" {
			return v
		}
	case map[string]any:
		for _, key := range []string{"code", "type", "message"} {
			if s, ok := v[key].(string); ok && s != "" {
				return s
			}
		}
	}
	return "unknown_error"
}

// CreateCompletion sends history followed by message (as a "user" message)
// to the chat completions endpoint with the given temperature and returns
// the decoded response. A temperature of 0 is omitted from the payload.
// history is not modified.
func (o *OpenAIAPI) CreateCompletion(history []Message, message string, temp float32) (*OpenAIResponse, error) {
	return o.chatCompletion("", history, message, temp)
}

// CompressChat sends history followed by CompressPrompt (as a "user"
// message) to the chat completions endpoint with temperature 1.0 and returns
// the decoded response. history is not modified.
func (o *OpenAIAPI) CompressChat(history []Message) (*OpenAIResponse, error) {
	return o.chatCompletion("COMPRESS ", history, CompressPrompt, 1.0)
}

// chatCompletion is the shared implementation of CreateCompletion and
// CompressChat. logPrefix is prepended to the "REQ"/"RES" debug log labels.
func (o *OpenAIAPI) chatCompletion(logPrefix string, history []Message, prompt string, temp float32) (*OpenAIResponse, error) {
	// Copy the history: appending to the caller's slice directly could write
	// into its backing array when it has spare capacity.
	messages := make([]Message, 0, len(history)+1)
	messages = append(messages, history...)
	messages = append(messages, Message{Role: "user", Content: prompt})

	request := CreateCompletionReq{
		Model:       o.Model,
		Messages:    messages,
		Temperature: temp,
	}
	u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)

	// Marshalled here only for the debug log; DoRequest encodes the request
	// itself.
	data, err := json.Marshal(request)
	if err != nil {
		return nil, err
	}
	o.Logger.Debug(logPrefix+"REQ", u, string(data))

	body, err := o.DoRequest(u, request, 0)
	if err != nil {
		return nil, err
	}
	o.Logger.Debug(logPrefix+"RES", u, string(body))

	response := new(OpenAIResponse)
	err = json.Unmarshal(body, response)
	return response, err
}