some changes
This commit is contained in:
204
openai/api.go
Normal file
204
openai/api.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.nix13.pw/scuroneko/slog"
|
||||
)
|
||||
|
||||
type API struct {
|
||||
token string
|
||||
model string
|
||||
baseUrl string
|
||||
logger *slog.Logger
|
||||
client *http.Client
|
||||
stream bool
|
||||
}
|
||||
|
||||
func NewOpenAIAPI(baseURL, token, model string) *API {
|
||||
logger := slog.CreateLogger()
|
||||
level := slog.FATAL
|
||||
if os.Getenv("DEBUG") == "true" {
|
||||
level = slog.DEBUG
|
||||
}
|
||||
logger = logger.Prefix("AI").Level(level)
|
||||
// FIXME Leak here
|
||||
//logger = logger.AddWriter(logger.CreateJsonStdoutWriter())
|
||||
proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
logger.Close()
|
||||
return nil
|
||||
}
|
||||
t := &http.Transport{}
|
||||
if proxy.Host != "" {
|
||||
t.Proxy = http.ProxyURL(proxy)
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: t,
|
||||
}
|
||||
return &API{
|
||||
token: token,
|
||||
model: model,
|
||||
baseUrl: baseURL,
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
func (api *API) Close() error {
|
||||
return api.logger.Close()
|
||||
}
|
||||
func (api *API) SetStream(stream bool) *API {
|
||||
api.stream = stream
|
||||
return api
|
||||
}
|
||||
func (api *API) GetModel() string { return api.model }
|
||||
func (api *API) GetBaseURL() string { return api.baseUrl }
|
||||
|
||||
type Request[P any] struct {
|
||||
params P
|
||||
method string
|
||||
}
|
||||
|
||||
func NewRequest[P any](method string, params P) *Request[P] {
|
||||
return &Request[P]{params, method}
|
||||
}
|
||||
|
||||
func (r *Request[P]) doWithContext(ctx context.Context, api *API) (io.ReadCloser, error) {
|
||||
data, err := json.Marshal(r.params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u := fmt.Sprintf("%s/v1/%s", api.baseUrl, r.method)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if api.token != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", api.token))
|
||||
}
|
||||
|
||||
res, err := api.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 {
|
||||
api.logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
|
||||
res.Body.Close()
|
||||
return nil, fmt.Errorf("[%d] %s", res.StatusCode, res.Status)
|
||||
}
|
||||
return res.Body, nil
|
||||
}
|
||||
func (r *Request[P]) do(api *API) (io.ReadCloser, error) {
|
||||
ctx := context.Background()
|
||||
return r.doWithContext(ctx, api)
|
||||
}
|
||||
|
||||
func (r *Request[P]) DoWithContext(ctx context.Context, api *API) (AIResponse, error) {
|
||||
var zero AIResponse
|
||||
body, err := r.doWithContext(ctx, api)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
defer body.Close()
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
err = api.handleAIError(data)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, &zero)
|
||||
return zero, err
|
||||
}
|
||||
func (r *Request[P]) Do(api *API) (AIResponse, error) {
|
||||
ctx := context.Background()
|
||||
return r.DoWithContext(ctx, api)
|
||||
}
|
||||
func (r *Request[P]) DoStreamWithContext(ctx context.Context, api *API) (iter.Seq2[AIResponse, error], error) {
|
||||
body, err := r.doWithContext(ctx, api)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(body)
|
||||
return func(yield func(AIResponse, error) bool) {
|
||||
defer body.Close()
|
||||
var zero AIResponse
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && err != io.EOF {
|
||||
yield(zero, err)
|
||||
return
|
||||
}
|
||||
if line == "" || line == "\n" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
line = line[len("data: "):]
|
||||
}
|
||||
line = strings.Trim(strings.Trim(strings.TrimSpace(line), "\r"), "\n")
|
||||
if strings.HasPrefix(line, "[DONE]") {
|
||||
return
|
||||
}
|
||||
|
||||
var resp AIResponse
|
||||
err = json.Unmarshal([]byte(line), &resp)
|
||||
if err != nil {
|
||||
yield(zero, fmt.Errorf("%v\n%s", err, line))
|
||||
return
|
||||
}
|
||||
if !yield(resp, nil) {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
func (r *Request[P]) DoStream(api *API) (iter.Seq2[AIResponse, error], error) {
|
||||
ctx := context.Background()
|
||||
return r.DoStreamWithContext(ctx, api)
|
||||
}
|
||||
|
||||
func (api *API) handleAIError(body []byte) error {
|
||||
var tempData any
|
||||
err := json.Unmarshal(body, &tempData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
|
||||
if eData, ok := tempData.(map[string]any); ok {
|
||||
if errorData, ok := eData["error"]; ok {
|
||||
if errorPayload, ok := errorData.(map[string]any); ok {
|
||||
code, ok := errorPayload["code"]
|
||||
if !ok {
|
||||
return errors.New("unknown error code")
|
||||
}
|
||||
return errors.New(fmt.Sprintf("%v", code))
|
||||
}
|
||||
return errors.New(string(body))
|
||||
}
|
||||
} else if eData, ok := tempData.(string); ok {
|
||||
return errors.New(eData)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user