some changes
This commit is contained in:
@@ -1,232 +1,11 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.nix13.pw/scuroneko/slog"
|
||||
"ymgb/openai"
|
||||
)
|
||||
|
||||
type OpenAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
}
|
||||
type Choice struct {
|
||||
Index int64 `json:"index"`
|
||||
Message Message `json:"message"`
|
||||
Logprobs interface{} `json:"logprobs"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Refusal interface{} `json:"refusal"`
|
||||
Annotations []interface{} `json:"annotations"`
|
||||
}
|
||||
type Usage struct {
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
|
||||
}
|
||||
type CompletionTokensDetails struct {
|
||||
ReasoningTokens int64 `json:"reasoning_tokens"`
|
||||
AudioTokens int64 `json:"audio_tokens"`
|
||||
AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
type PromptTokensDetails struct {
|
||||
CachedTokens int64 `json:"cached_tokens"`
|
||||
AudioTokens int64 `json:"audio_tokens"`
|
||||
}
|
||||
type OpenAIAPI struct {
|
||||
Token string
|
||||
Model string
|
||||
BaseURL string
|
||||
Logger *slog.Logger
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewOpenAIAPI(baseURL, token, model string) *OpenAIAPI {
|
||||
logger := slog.CreateLogger()
|
||||
level := slog.FATAL
|
||||
if os.Getenv("DEBUG") == "true" {
|
||||
level = slog.DEBUG
|
||||
}
|
||||
logger = logger.Prefix("AI").Level(level)
|
||||
// FIXME Leak here
|
||||
//logger = logger.AddWriter(logger.CreateJsonStdoutWriter())
|
||||
proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
logger.Close()
|
||||
}
|
||||
t := &http.Transport{}
|
||||
if proxy.Host != "" {
|
||||
t.Proxy = http.ProxyURL(proxy)
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: t,
|
||||
}
|
||||
return &OpenAIAPI{
|
||||
Token: token,
|
||||
Model: model,
|
||||
BaseURL: baseURL,
|
||||
Logger: logger,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
func (o *OpenAIAPI) Close() error {
|
||||
return o.Logger.Close()
|
||||
}
|
||||
|
||||
type CreateCompletionReq struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Verbosity string `json:"verbosity,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
PresencePenalty int `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
|
||||
TopP int `json:"top_p,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completition_tokens,omitempty"`
|
||||
}
|
||||
|
||||
var MaxRetriesErr = fmt.Errorf("max retries exceeded")
|
||||
var BadResponseErr = fmt.Errorf("bad_response_status_code")
|
||||
|
||||
func (o *OpenAIAPI) DoRequest(url string, params any, retries int) ([]byte, error) {
|
||||
responseBody := make([]byte, 0)
|
||||
data, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
log.Println("json marshal failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
log.Println("create request failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if o.Token != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.Token))
|
||||
}
|
||||
|
||||
res, err := o.client.Do(req)
|
||||
if err != nil {
|
||||
log.Println("do request failed:", err)
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 {
|
||||
o.Logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
|
||||
if retries >= 3 {
|
||||
return responseBody, MaxRetriesErr
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
return o.DoRequest(url, params, retries+1)
|
||||
}
|
||||
responseBody, err = io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
log.Println("read response failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
|
||||
var tempData any
|
||||
err = json.Unmarshal(responseBody, &tempData)
|
||||
if err != nil {
|
||||
log.Println("json unmarshal failed:", err)
|
||||
return responseBody, err
|
||||
}
|
||||
// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
|
||||
if eData, ok := tempData.(map[string]any); ok {
|
||||
if errorData, ok := eData["error"]; ok {
|
||||
if errorPayload, ok := errorData.(map[string]any); ok {
|
||||
code := errorPayload["code"].(string)
|
||||
if code == "bad_response_status_code" {
|
||||
if retries >= 3 {
|
||||
return responseBody, BadResponseErr
|
||||
}
|
||||
o.Logger.Warnln("Retrying because of bad response status code")
|
||||
return o.DoRequest(url, params, retries+1)
|
||||
}
|
||||
return nil, errors.New(code)
|
||||
}
|
||||
o.Logger.Errorln("Unknown error", errorData)
|
||||
return nil, errors.New(string(responseBody))
|
||||
}
|
||||
} else if eData, ok := tempData.(string); ok {
|
||||
return responseBody, errors.New(eData)
|
||||
}
|
||||
|
||||
return responseBody, err
|
||||
}
|
||||
|
||||
func (o *OpenAIAPI) CreateCompletion(history []Message, message string, temp float32) (*OpenAIResponse, error) {
|
||||
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
|
||||
request := CreateCompletionReq{
|
||||
Model: o.Model,
|
||||
Messages: append(history, Message{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
}),
|
||||
Temperature: temp,
|
||||
}
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.Logger.Debug("REQ", u, string(data))
|
||||
body, err := o.DoRequest(u, request, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Logger.Debug("RES", u, string(body))
|
||||
|
||||
response := new(OpenAIResponse)
|
||||
err = json.Unmarshal(body, response)
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (o *OpenAIAPI) CompressChat(history []Message) (*OpenAIResponse, error) {
|
||||
request := CreateCompletionReq{
|
||||
Model: o.Model,
|
||||
Messages: append(history, Message{
|
||||
Role: "user",
|
||||
Content: CompressPrompt,
|
||||
}),
|
||||
Temperature: 1.0,
|
||||
}
|
||||
u := fmt.Sprintf("%s/v1/chat/completions", o.BaseURL)
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.Logger.Debug("COMPRESS REQ", u, string(data))
|
||||
body, err := o.DoRequest(u, request, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Logger.Debug("COMPRESS RES", u, string(body))
|
||||
|
||||
response := new(OpenAIResponse)
|
||||
err = json.Unmarshal(body, response)
|
||||
return response, err
|
||||
//https://github.com/sashabaranov/go-openai
|
||||
|
||||
func CompressChat(api *openai.API, history []openai.Message) (openai.AIResponse, error) {
|
||||
return api.CreateCompletion(history, CompressPrompt, 0.0)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user