|
@@ -2,12 +2,14 @@ package coze
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"encoding/json"
|
|
"encoding/json"
|
|
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"io"
|
|
"io"
|
|
|
"net/http"
|
|
"net/http"
|
|
|
"one-api/dto"
|
|
"one-api/dto"
|
|
|
"one-api/relay/common"
|
|
"one-api/relay/common"
|
|
|
relaycommon "one-api/relay/common"
|
|
relaycommon "one-api/relay/common"
|
|
|
|
|
+ "one-api/relay/helper"
|
|
|
"one-api/service"
|
|
"one-api/service"
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin"
|
|
@@ -47,14 +49,47 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
|
}
|
|
}
|
|
|
// convert coze response to openai response
|
|
// convert coze response to openai response
|
|
|
var response dto.TextResponse
|
|
var response dto.TextResponse
|
|
|
- var cozeResponse CozeChatResponse
|
|
|
|
|
|
|
+ var cozeResponse CozeChatDetailResponse
|
|
|
|
|
+ response.Model = info.UpstreamModelName
|
|
|
err = json.Unmarshal(responseBody, &cozeResponse)
|
|
err = json.Unmarshal(responseBody, &cozeResponse)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
}
|
|
}
|
|
|
- response.Model = info.UpstreamModelName
|
|
|
|
|
- // TODO: 处理 cozeResponse
|
|
|
|
|
- return nil, nil
|
|
|
|
|
|
|
+ if cozeResponse.Code != 0 {
|
|
|
|
|
+ return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
|
|
|
|
|
+ }
|
|
|
|
|
+ // 从上下文获取 usage
|
|
|
|
|
+ var usage dto.Usage
|
|
|
|
|
+ usage.PromptTokens = c.GetInt("coze_input_count")
|
|
|
|
|
+ usage.CompletionTokens = c.GetInt("coze_output_count")
|
|
|
|
|
+ usage.TotalTokens = c.GetInt("coze_token_count")
|
|
|
|
|
+ response.Usage = usage
|
|
|
|
|
+ response.Id = helper.GetResponseID(c)
|
|
|
|
|
+
|
|
|
|
|
+ var responseContent json.RawMessage
|
|
|
|
|
+ for _, data := range cozeResponse.Data {
|
|
|
|
|
+ if data.Type == "answer" {
|
|
|
|
|
+ responseContent = data.Content
|
|
|
|
|
+ response.Created = data.CreatedAt
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ // 添加 response.Choices
|
|
|
|
|
+ response.Choices = []dto.OpenAITextResponseChoice{
|
|
|
|
|
+ {
|
|
|
|
|
+ Index: 0,
|
|
|
|
|
+ Message: dto.Message{Role: "assistant", Content: responseContent},
|
|
|
|
|
+ FinishReason: "stop",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+ jsonResponse, err := json.Marshal(response)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
|
|
+ }
|
|
|
|
|
+ c.Writer.Header().Set("Content-Type", "application/json")
|
|
|
|
|
+ c.Writer.WriteHeader(resp.StatusCode)
|
|
|
|
|
+ _, _ = c.Writer.Write(jsonResponse)
|
|
|
|
|
+
|
|
|
|
|
+ return nil, &usage
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
|
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|