Просмотр исходного кода

feat: add xAI handling and response processing

CaIon 11 месяцев назад
Родитель
Сommit
8723e3f239
6 измененных файлов с 137 добавлено и 9 удалено
  1. 4 0
      common/json.go
  2. 4 3
      dto/realtime.go
  3. 5 6
      relay/channel/xai/adaptor.go
  4. 14 0
      relay/channel/xai/dto.go
  5. 107 0
      relay/channel/xai/text.go
  6. 3 0
      relay/helper/common.go

+ 4 - 0
common/json.go

@@ -12,3 +12,7 @@ func DecodeJson(data []byte, v any) error {
 func DecodeJsonStr(data string, v any) error {
 	return DecodeJson(StringToByteSlice(data), v)
 }
+
+func EncodeJson(v any) ([]byte, error) {
+	return json.Marshal(v)
+}

+ 4 - 3
dto/realtime.go

@@ -45,15 +45,16 @@ type RealtimeUsage struct {
 
 type InputTokenDetails struct {
 	CachedTokens         int `json:"cached_tokens"`
-	CachedCreationTokens int
+	CachedCreationTokens int `json:"-"`
 	TextTokens           int `json:"text_tokens"`
 	AudioTokens          int `json:"audio_tokens"`
 	ImageTokens          int `json:"image_tokens"`
 }
 
 type OutputTokenDetails struct {
-	TextTokens  int `json:"text_tokens"`
-	AudioTokens int `json:"audio_tokens"`
+	TextTokens      int `json:"text_tokens"`
+	AudioTokens     int `json:"audio_tokens"`
+	ReasoningTokens int `json:"reasoning_tokens"`
 }
 
 type RealtimeSession struct {

+ 5 - 6
relay/channel/xai/adaptor.go

@@ -8,7 +8,6 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
-	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"strings"
 )
@@ -86,13 +85,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = openai.OaiStreamHandler(c, resp, info)
+		err, usage = xAIStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info)
-	}
-	if _, ok := usage.(*dto.Usage); ok && usage != nil {
-		usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
+		err, usage = xAIHandler(c, resp, info)
 	}
+	//if _, ok := usage.(*dto.Usage); ok && usage != nil {
+	//	usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
+	//}
 
 	return
 }

+ 14 - 0
relay/channel/xai/dto.go

@@ -0,0 +1,14 @@
+package xai
+
+import "one-api/dto"
+
+// ChatCompletionResponse represents the response from XAI chat completion API
+type ChatCompletionResponse struct {
+	Id                string `json:"id"`
+	Object            string `json:"object"`
+	Created           int64  `json:"created"`
+	Model             string `json:"model"`
+	Choices           []dto.ChatCompletionsStreamResponseChoice
+	Usage             *dto.Usage `json:"usage"`
+	SystemFingerprint string     `json:"system_fingerprint"`
+}

+ 107 - 0
relay/channel/xai/text.go

@@ -0,0 +1,107 @@
+package xai
+
+import (
+	"bytes"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+)
+
+func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
+	if xAIResp == nil {
+		return nil
+	}
+	if xAIResp.Usage != nil {
+		xAIResp.Usage.CompletionTokens = usage.CompletionTokens
+	}
+	openAIResp := &dto.ChatCompletionsStreamResponse{
+		Id:      xAIResp.Id,
+		Object:  xAIResp.Object,
+		Created: xAIResp.Created,
+		Model:   xAIResp.Model,
+		Choices: xAIResp.Choices,
+		Usage:   xAIResp.Usage,
+	}
+
+	return openAIResp
+}
+
+func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	usage := &dto.Usage{}
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var xAIResp *dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &xAIResp)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			return true
+		}
+
+		// 把 xAI 的usage转换为 OpenAI 的usage
+		if xAIResp.Usage != nil {
+			usage.PromptTokens = xAIResp.Usage.PromptTokens
+			usage.TotalTokens = xAIResp.Usage.TotalTokens
+			usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+		}
+
+		openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
+		err = helper.ObjectData(c, openaiResponse)
+		if err != nil {
+			common.SysError(err.Error())
+		}
+		return true
+	})
+
+	helper.Done(c)
+	err := resp.Body.Close()
+	if err != nil {
+		//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		common.SysError("close_response_body_failed: " + err.Error())
+	}
+	return nil, usage
+}
+
+func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	var response *dto.TextResponse
+	err = common.DecodeJson(responseBody, &response)
+	if err != nil {
+		common.SysError("error unmarshalling stream response: " + err.Error())
+		return nil, nil
+	}
+	response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
+	response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
+
+	// new body
+	encodeJson, err := common.EncodeJson(response)
+	if err != nil {
+		common.SysError("error marshalling stream response: " + err.Error())
+		return nil, nil
+	}
+
+	// set new body
+	resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, &response.Usage
+}

+ 3 - 0
relay/helper/common.go

@@ -56,6 +56,9 @@ func StringData(c *gin.Context, str string) error {
 }
 
 func ObjectData(c *gin.Context, object interface{}) error {
+	if object == nil {
+		return errors.New("object is nil")
+	}
 	jsonData, err := json.Marshal(object)
 	if err != nil {
 		return fmt.Errorf("error marshalling object: %w", err)