package openai import ( "bufio" "bytes" "context" "encoding/json" "errors" "fmt" "io" "iter" "net/http" "net/url" "os" "strings" "time" "git.nix13.pw/scuroneko/slog" ) type API struct { token string model string baseUrl string logger *slog.Logger client *http.Client stream bool } func NewOpenAIAPI(baseURL, token, model string) *API { logger := slog.CreateLogger() level := slog.FATAL if os.Getenv("DEBUG") == "true" { level = slog.DEBUG } logger = logger.Prefix("AI").Level(level) // FIXME Leak here //logger = logger.AddWriter(logger.CreateJsonStdoutWriter()) proxy, err := url.Parse(os.Getenv("HTTPS_PROXY")) if err != nil { logger.Error(err) logger.Close() return nil } t := &http.Transport{} if proxy.Host != "" { t.Proxy = http.ProxyURL(proxy) } client := &http.Client{ Timeout: 5 * time.Minute, Transport: t, } return &API{ token: token, model: model, baseUrl: baseURL, logger: logger, client: client, } } func (api *API) Close() error { return api.logger.Close() } func (api *API) SetStream(stream bool) *API { api.stream = stream return api } func (api *API) GetModel() string { return api.model } func (api *API) GetBaseURL() string { return api.baseUrl } type Request[P any] struct { params P method string } func NewRequest[P any](method string, params P) *Request[P] { return &Request[P]{params, method} } func (r *Request[P]) doWithContext(ctx context.Context, api *API) (io.ReadCloser, error) { data, err := json.Marshal(r.params) if err != nil { return nil, err } u := fmt.Sprintf("%s/v1/%s", api.baseUrl, r.method) req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewBuffer(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") if api.token != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", api.token)) } res, err := api.client.Do(req) if err != nil { return nil, err } if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 { api.logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status)) res.Body.Close() return nil, fmt.Errorf("[%d] %s", res.StatusCode, res.Status) } return res.Body, nil } func (r *Request[P]) do(api *API) (io.ReadCloser, error) { ctx := context.Background() return r.doWithContext(ctx, api) } func (r *Request[P]) DoWithContext(ctx context.Context, api *API) (AIResponse, error) { var zero AIResponse body, err := r.doWithContext(ctx, api) if err != nil { return zero, err } defer body.Close() data, err := io.ReadAll(body) if err != nil { return zero, err } err = api.handleAIError(data) if err != nil { return zero, err } err = json.Unmarshal(data, &zero) return zero, err } func (r *Request[P]) Do(api *API) (AIResponse, error) { ctx := context.Background() return r.DoWithContext(ctx, api) } func (r *Request[P]) DoStreamWithContext(ctx context.Context, api *API) (iter.Seq2[AIResponse, error], error) { body, err := r.doWithContext(ctx, api) if err != nil { return nil, err } reader := bufio.NewReader(body) return func(yield func(AIResponse, error) bool) { defer body.Close() var zero AIResponse for { line, err := reader.ReadString('\n') if err != nil && err != io.EOF { yield(zero, err) return } if line == "" || line == "\n" { continue } if strings.HasPrefix(line, "data: ") { line = line[len("data: "):] } line = strings.Trim(strings.Trim(strings.TrimSpace(line), "\r"), "\n") if strings.HasPrefix(line, "[DONE]") { return } var resp AIResponse err = json.Unmarshal([]byte(line), &resp) if err != nil { yield(zero, fmt.Errorf("%v\n%s", err, line)) return } if !yield(resp, nil) { return } } }, nil } func (r *Request[P]) DoStream(api *API) (iter.Seq2[AIResponse, error], error) { ctx := context.Background() return r.DoStreamWithContext(ctx, api) } func (api *API) handleAIError(body []byte) error { var tempData any err := json.Unmarshal(body, &tempData) if err != nil { return err } // {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}} if eData, ok := tempData.(map[string]any); ok { if errorData, ok := eData["error"]; ok { if errorPayload, ok := errorData.(map[string]any); ok { code, ok := errorPayload["code"] if !ok { return errors.New("unknown error code") } return errors.New(fmt.Sprintf("%v", code)) } return errors.New(string(body)) } } else if eData, ok := tempData.(string); ok { return errors.New(eData) } return nil }