some changes

This commit is contained in:
2026-03-02 00:58:43 +03:00
parent b394c0be68
commit 3e0d3db47e
13 changed files with 486 additions and 292 deletions

204
openai/api.go Normal file
View File

@@ -0,0 +1,204 @@
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"iter"
"net/http"
"net/url"
"os"
"strings"
"time"
"git.nix13.pw/scuroneko/slog"
)
type API struct {
token string
model string
baseUrl string
logger *slog.Logger
client *http.Client
stream bool
}
func NewOpenAIAPI(baseURL, token, model string) *API {
logger := slog.CreateLogger()
level := slog.FATAL
if os.Getenv("DEBUG") == "true" {
level = slog.DEBUG
}
logger = logger.Prefix("AI").Level(level)
// FIXME Leak here
//logger = logger.AddWriter(logger.CreateJsonStdoutWriter())
proxy, err := url.Parse(os.Getenv("HTTPS_PROXY"))
if err != nil {
logger.Error(err)
logger.Close()
return nil
}
t := &http.Transport{}
if proxy.Host != "" {
t.Proxy = http.ProxyURL(proxy)
}
client := &http.Client{
Timeout: 5 * time.Minute,
Transport: t,
}
return &API{
token: token,
model: model,
baseUrl: baseURL,
logger: logger,
client: client,
}
}
func (api *API) Close() error {
return api.logger.Close()
}
func (api *API) SetStream(stream bool) *API {
api.stream = stream
return api
}
func (api *API) GetModel() string { return api.model }
func (api *API) GetBaseURL() string { return api.baseUrl }
type Request[P any] struct {
params P
method string
}
func NewRequest[P any](method string, params P) *Request[P] {
return &Request[P]{params, method}
}
func (r *Request[P]) doWithContext(ctx context.Context, api *API) (io.ReadCloser, error) {
data, err := json.Marshal(r.params)
if err != nil {
return nil, err
}
u := fmt.Sprintf("%s/v1/%s", api.baseUrl, r.method)
req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
if api.token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", api.token))
}
res, err := api.client.Do(req)
if err != nil {
return nil, err
}
if res.StatusCode == 504 || res.StatusCode == 400 || res.StatusCode == 502 {
api.logger.Warn(fmt.Sprintf("[%d] %s", res.StatusCode, res.Status))
res.Body.Close()
return nil, fmt.Errorf("[%d] %s", res.StatusCode, res.Status)
}
return res.Body, nil
}
func (r *Request[P]) do(api *API) (io.ReadCloser, error) {
ctx := context.Background()
return r.doWithContext(ctx, api)
}
func (r *Request[P]) DoWithContext(ctx context.Context, api *API) (AIResponse, error) {
var zero AIResponse
body, err := r.doWithContext(ctx, api)
if err != nil {
return zero, err
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
return zero, err
}
err = api.handleAIError(data)
if err != nil {
return zero, err
}
err = json.Unmarshal(data, &zero)
return zero, err
}
func (r *Request[P]) Do(api *API) (AIResponse, error) {
ctx := context.Background()
return r.DoWithContext(ctx, api)
}
func (r *Request[P]) DoStreamWithContext(ctx context.Context, api *API) (iter.Seq2[AIResponse, error], error) {
body, err := r.doWithContext(ctx, api)
if err != nil {
return nil, err
}
reader := bufio.NewReader(body)
return func(yield func(AIResponse, error) bool) {
defer body.Close()
var zero AIResponse
for {
line, err := reader.ReadString('\n')
if err != nil && err != io.EOF {
yield(zero, err)
return
}
if line == "" || line == "\n" {
continue
}
if strings.HasPrefix(line, "data: ") {
line = line[len("data: "):]
}
line = strings.Trim(strings.Trim(strings.TrimSpace(line), "\r"), "\n")
if strings.HasPrefix(line, "[DONE]") {
return
}
var resp AIResponse
err = json.Unmarshal([]byte(line), &resp)
if err != nil {
yield(zero, fmt.Errorf("%v\n%s", err, line))
return
}
if !yield(resp, nil) {
return
}
time.Sleep(time.Millisecond * 100)
}
}, nil
}
func (r *Request[P]) DoStream(api *API) (iter.Seq2[AIResponse, error], error) {
ctx := context.Background()
return r.DoStreamWithContext(ctx, api)
}
func (api *API) handleAIError(body []byte) error {
var tempData any
err := json.Unmarshal(body, &tempData)
if err != nil {
return err
}
// {"error":{"message":"openai_error","type":"bad_response_status_code","param":"","code":"bad_response_status_code"}}
if eData, ok := tempData.(map[string]any); ok {
if errorData, ok := eData["error"]; ok {
if errorPayload, ok := errorData.(map[string]any); ok {
code, ok := errorPayload["code"]
if !ok {
return errors.New("unknown error code")
}
return errors.New(fmt.Sprintf("%v", code))
}
return errors.New(string(body))
}
} else if eData, ok := tempData.(string); ok {
return errors.New(eData)
}
return nil
}

49
openai/completitions.go Normal file
View File

@@ -0,0 +1,49 @@
package openai
import (
"fmt"
"iter"
)
var MaxRetriesErr = fmt.Errorf("max retries exceeded")
var BadResponseErr = fmt.Errorf("bad_response_status_code")
type CreateCompletionReq struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Verbosity string `json:"verbosity,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
PresencePenalty int `json:"presence_penalty,omitempty"`
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
TopP int `json:"top_p,omitempty"`
MaxCompletionTokens int `json:"max_completition_tokens,omitempty"`
Stream bool `json:"stream,omitempty"`
}
func (api *API) CreateCompletionStream(history []Message, message string, temp float32) (iter.Seq2[AIResponse, error], error) {
params := CreateCompletionReq{
Model: api.model,
Messages: append(history, Message{
Role: "user",
Content: message,
}),
Temperature: temp,
Stream: true,
}
req := NewRequest("chat/completions", params)
return req.DoStream(api)
}
func (api *API) CreateCompletion(history []Message, message string, temp float32) (AIResponse, error) {
params := CreateCompletionReq{
Model: api.model,
Messages: append(history, Message{
Role: "user",
Content: message,
}),
Temperature: temp,
Stream: false,
}
req := NewRequest("chat/completions", params)
return req.Do(api)
}

37
openai/sse.go Normal file
View File

@@ -0,0 +1,37 @@
package openai
import (
"bufio"
"io"
"iter"
"strings"
)
// Server-sent event
func ReadSSE(r io.ReadCloser) iter.Seq[string] {
reader := bufio.NewReader(r)
return func(yield func(string) bool) {
for {
line, err := reader.ReadString('\n')
if err != nil {
return
}
if line == "" || line == "\n" {
continue
}
if strings.HasPrefix(line, "data: ") {
line = line[len("data: "):]
}
line = strings.TrimSpace(line)
line = strings.Trim(line, "\r")
line = strings.Trim(line, "\n")
if strings.HasPrefix(line, "[DONE]") {
return
}
if !yield(line) {
return
}
}
}
}

41
openai/types.go Normal file
View File

@@ -0,0 +1,41 @@
package openai
type AIResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
ServiceTier string `json:"service_tier"`
}
type Choice struct {
Index int64 `json:"index"`
Message Message `json:"message"`
Delta Message `json:"delta"`
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Refusal interface{} `json:"refusal"`
Annotations []interface{} `json:"annotations"`
}
type Usage struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
PromptTokensDetails PromptTokensDetails `json:"prompt_tokens_details"`
CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
}
type CompletionTokensDetails struct {
ReasoningTokens int64 `json:"reasoning_tokens"`
AudioTokens int64 `json:"audio_tokens"`
AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
}
type PromptTokensDetails struct {
CachedTokens int64 `json:"cached_tokens"`
AudioTokens int64 `json:"audio_tokens"`
}