| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- package coze
- import (
- "bufio"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
- "github.com/gin-gonic/gin"
- )
- func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
- var messages []CozeEnterMessage
- // 将 request的messages的role为user的content转换为CozeMessage
- for _, message := range request.Messages {
- if message.Role == "user" {
- messages = append(messages, CozeEnterMessage{
- Role: "user",
- Content: message.Content,
- // TODO: support more content type
- ContentType: "text",
- })
- }
- }
- user := request.User
- if user == "" {
- user = helper.GetResponseID(c)
- }
- cozeRequest := &CozeChatRequest{
- BotId: c.GetString("bot_id"),
- UserId: user,
- AdditionalMessages: messages,
- Stream: request.Stream,
- }
- return cozeRequest
- }
- func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- common.CloseResponseBodyGracefully(resp)
- // convert coze response to openai response
- var response dto.TextResponse
- var cozeResponse CozeChatDetailResponse
- response.Model = info.UpstreamModelName
- err = json.Unmarshal(responseBody, &cozeResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- if cozeResponse.Code != 0 {
- return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil
- }
- // 从上下文获取 usage
- var usage dto.Usage
- usage.PromptTokens = c.GetInt("coze_input_count")
- usage.CompletionTokens = c.GetInt("coze_output_count")
- usage.TotalTokens = c.GetInt("coze_token_count")
- response.Usage = usage
- response.Id = helper.GetResponseID(c)
- var responseContent json.RawMessage
- for _, data := range cozeResponse.Data {
- if data.Type == "answer" {
- responseContent = data.Content
- response.Created = data.CreatedAt
- }
- }
- // 添加 response.Choices
- response.Choices = []dto.OpenAITextResponseChoice{
- {
- Index: 0,
- Message: dto.Message{Role: "assistant", Content: responseContent},
- FinishReason: "stop",
- },
- }
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
- return nil, &usage
- }
- func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(bufio.ScanLines)
- helper.SetEventStreamHeaders(c)
- id := helper.GetResponseID(c)
- var responseText string
- var currentEvent string
- var currentData string
- var usage = &dto.Usage{}
- for scanner.Scan() {
- line := scanner.Text()
- if line == "" {
- if currentEvent != "" && currentData != "" {
- // handle last event
- handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
- currentEvent = ""
- currentData = ""
- }
- continue
- }
- if strings.HasPrefix(line, "event:") {
- currentEvent = strings.TrimSpace(line[6:])
- continue
- }
- if strings.HasPrefix(line, "data:") {
- currentData = strings.TrimSpace(line[5:])
- continue
- }
- }
- // Last event
- if currentEvent != "" && currentData != "" {
- handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
- }
- if err := scanner.Err(); err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- helper.Done(c)
- if usage.TotalTokens == 0 {
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
- }
- return nil, usage
- }
- func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
- switch event {
- case "conversation.chat.completed":
- // 将 data 解析为 CozeChatResponseData
- var chatData CozeChatResponseData
- err := json.Unmarshal([]byte(data), &chatData)
- if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
- return
- }
- usage.PromptTokens = chatData.Usage.InputCount
- usage.CompletionTokens = chatData.Usage.OutputCount
- usage.TotalTokens = chatData.Usage.TokenCount
- finishReason := "stop"
- stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
- helper.ObjectData(c, stopResponse)
- case "conversation.message.delta":
- // 将 data 解析为 CozeChatV3MessageDetail
- var messageData CozeChatV3MessageDetail
- err := json.Unmarshal([]byte(data), &messageData)
- if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
- return
- }
- var content string
- err = json.Unmarshal(messageData.Content, &content)
- if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
- return
- }
- *responseText += content
- openaiResponse := dto.ChatCompletionsStreamResponse{
- Id: id,
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- }
- choice := dto.ChatCompletionsStreamResponseChoice{
- Index: 0,
- }
- choice.Delta.SetContentString(content)
- openaiResponse.Choices = append(openaiResponse.Choices, choice)
- helper.ObjectData(c, openaiResponse)
- case "error":
- var errorData CozeError
- err := json.Unmarshal([]byte(data), &errorData)
- if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
- return
- }
- common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
- }
- }
- func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
- requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
- requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
- // 将 conversationId和chatId作为参数发送get请求
- req, err := http.NewRequest("GET", requestURL, nil)
- if err != nil {
- return err, false
- }
- err = a.SetupRequestHeader(c, &req.Header, info)
- if err != nil {
- return err, false
- }
- resp, err := doRequest(req, info) // 调用 doRequest
- if err != nil {
- return err, false
- }
- if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
- return fmt.Errorf("resp is nil"), false
- }
- defer resp.Body.Close() // 确保响应体被关闭
- // 解析 resp 到 CozeChatResponse
- var cozeResponse CozeChatResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("read response body failed: %w", err), false
- }
- err = json.Unmarshal(responseBody, &cozeResponse)
- if err != nil {
- return fmt.Errorf("unmarshal response body failed: %w", err), false
- }
- if cozeResponse.Data.Status == "completed" {
- // 在上下文设置 usage
- c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
- c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
- c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
- return nil, true
- } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
- return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
- } else {
- return nil, false
- }
- }
- func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
- requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
- requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
- req, err := http.NewRequest("GET", requestURL, nil)
- if err != nil {
- return nil, fmt.Errorf("new request failed: %w", err)
- }
- err = a.SetupRequestHeader(c, &req.Header, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
- resp, err := doRequest(req, info)
- if err != nil {
- return nil, fmt.Errorf("do request failed: %w", err)
- }
- return resp, nil
- }
- func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
- var client *http.Client
- var err error // 声明 err 变量
- if info.ChannelSetting.Proxy != "" {
- client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
- if err != nil {
- return nil, fmt.Errorf("new proxy http client failed: %w", err)
- }
- } else {
- client = service.GetHttpClient()
- }
- resp, err := client.Do(req)
- if err != nil { // 增加对 client.Do(req) 返回错误的检查
- return nil, fmt.Errorf("client.Do failed: %w", err)
- }
- // _ = resp.Body.Close()
- return resp, nil
- }
|