relay-coze.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package coze
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/dto"
  11. relaycommon "one-api/relay/common"
  12. "one-api/relay/helper"
  13. "one-api/service"
  14. "one-api/types"
  15. "strings"
  16. "github.com/gin-gonic/gin"
  17. )
  18. func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
  19. var messages []CozeEnterMessage
  20. // 将 request的messages的role为user的content转换为CozeMessage
  21. for _, message := range request.Messages {
  22. if message.Role == "user" {
  23. messages = append(messages, CozeEnterMessage{
  24. Role: "user",
  25. Content: message.Content,
  26. // TODO: support more content type
  27. ContentType: "text",
  28. })
  29. }
  30. }
  31. user := request.User
  32. if user == "" {
  33. user = helper.GetResponseID(c)
  34. }
  35. cozeRequest := &CozeChatRequest{
  36. BotId: c.GetString("bot_id"),
  37. UserId: user,
  38. AdditionalMessages: messages,
  39. Stream: request.Stream,
  40. }
  41. return cozeRequest
  42. }
  43. func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  44. responseBody, err := io.ReadAll(resp.Body)
  45. if err != nil {
  46. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  47. }
  48. common.CloseResponseBodyGracefully(resp)
  49. // convert coze response to openai response
  50. var response dto.TextResponse
  51. var cozeResponse CozeChatDetailResponse
  52. response.Model = info.UpstreamModelName
  53. err = json.Unmarshal(responseBody, &cozeResponse)
  54. if err != nil {
  55. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  56. }
  57. if cozeResponse.Code != 0 {
  58. return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil
  59. }
  60. // 从上下文获取 usage
  61. var usage dto.Usage
  62. usage.PromptTokens = c.GetInt("coze_input_count")
  63. usage.CompletionTokens = c.GetInt("coze_output_count")
  64. usage.TotalTokens = c.GetInt("coze_token_count")
  65. response.Usage = usage
  66. response.Id = helper.GetResponseID(c)
  67. var responseContent json.RawMessage
  68. for _, data := range cozeResponse.Data {
  69. if data.Type == "answer" {
  70. responseContent = data.Content
  71. response.Created = data.CreatedAt
  72. }
  73. }
  74. // 添加 response.Choices
  75. response.Choices = []dto.OpenAITextResponseChoice{
  76. {
  77. Index: 0,
  78. Message: dto.Message{Role: "assistant", Content: responseContent},
  79. FinishReason: "stop",
  80. },
  81. }
  82. jsonResponse, err := json.Marshal(response)
  83. if err != nil {
  84. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  85. }
  86. c.Writer.Header().Set("Content-Type", "application/json")
  87. c.Writer.WriteHeader(resp.StatusCode)
  88. _, _ = c.Writer.Write(jsonResponse)
  89. return nil, &usage
  90. }
  91. func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  92. scanner := bufio.NewScanner(resp.Body)
  93. scanner.Split(bufio.ScanLines)
  94. helper.SetEventStreamHeaders(c)
  95. id := helper.GetResponseID(c)
  96. var responseText string
  97. var currentEvent string
  98. var currentData string
  99. var usage = &dto.Usage{}
  100. for scanner.Scan() {
  101. line := scanner.Text()
  102. if line == "" {
  103. if currentEvent != "" && currentData != "" {
  104. // handle last event
  105. handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
  106. currentEvent = ""
  107. currentData = ""
  108. }
  109. continue
  110. }
  111. if strings.HasPrefix(line, "event:") {
  112. currentEvent = strings.TrimSpace(line[6:])
  113. continue
  114. }
  115. if strings.HasPrefix(line, "data:") {
  116. currentData = strings.TrimSpace(line[5:])
  117. continue
  118. }
  119. }
  120. // Last event
  121. if currentEvent != "" && currentData != "" {
  122. handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
  123. }
  124. if err := scanner.Err(); err != nil {
  125. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  126. }
  127. helper.Done(c)
  128. if usage.TotalTokens == 0 {
  129. usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
  130. }
  131. return nil, usage
  132. }
  133. func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
  134. switch event {
  135. case "conversation.chat.completed":
  136. // 将 data 解析为 CozeChatResponseData
  137. var chatData CozeChatResponseData
  138. err := json.Unmarshal([]byte(data), &chatData)
  139. if err != nil {
  140. common.SysError("error_unmarshalling_stream_response: " + err.Error())
  141. return
  142. }
  143. usage.PromptTokens = chatData.Usage.InputCount
  144. usage.CompletionTokens = chatData.Usage.OutputCount
  145. usage.TotalTokens = chatData.Usage.TokenCount
  146. finishReason := "stop"
  147. stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
  148. helper.ObjectData(c, stopResponse)
  149. case "conversation.message.delta":
  150. // 将 data 解析为 CozeChatV3MessageDetail
  151. var messageData CozeChatV3MessageDetail
  152. err := json.Unmarshal([]byte(data), &messageData)
  153. if err != nil {
  154. common.SysError("error_unmarshalling_stream_response: " + err.Error())
  155. return
  156. }
  157. var content string
  158. err = json.Unmarshal(messageData.Content, &content)
  159. if err != nil {
  160. common.SysError("error_unmarshalling_stream_response: " + err.Error())
  161. return
  162. }
  163. *responseText += content
  164. openaiResponse := dto.ChatCompletionsStreamResponse{
  165. Id: id,
  166. Object: "chat.completion.chunk",
  167. Created: common.GetTimestamp(),
  168. Model: info.UpstreamModelName,
  169. }
  170. choice := dto.ChatCompletionsStreamResponseChoice{
  171. Index: 0,
  172. }
  173. choice.Delta.SetContentString(content)
  174. openaiResponse.Choices = append(openaiResponse.Choices, choice)
  175. helper.ObjectData(c, openaiResponse)
  176. case "error":
  177. var errorData CozeError
  178. err := json.Unmarshal([]byte(data), &errorData)
  179. if err != nil {
  180. common.SysError("error_unmarshalling_stream_response: " + err.Error())
  181. return
  182. }
  183. common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
  184. }
  185. }
  186. func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
  187. requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
  188. requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
  189. // 将 conversationId和chatId作为参数发送get请求
  190. req, err := http.NewRequest("GET", requestURL, nil)
  191. if err != nil {
  192. return err, false
  193. }
  194. err = a.SetupRequestHeader(c, &req.Header, info)
  195. if err != nil {
  196. return err, false
  197. }
  198. resp, err := doRequest(req, info) // 调用 doRequest
  199. if err != nil {
  200. return err, false
  201. }
  202. if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
  203. return fmt.Errorf("resp is nil"), false
  204. }
  205. defer resp.Body.Close() // 确保响应体被关闭
  206. // 解析 resp 到 CozeChatResponse
  207. var cozeResponse CozeChatResponse
  208. responseBody, err := io.ReadAll(resp.Body)
  209. if err != nil {
  210. return fmt.Errorf("read response body failed: %w", err), false
  211. }
  212. err = json.Unmarshal(responseBody, &cozeResponse)
  213. if err != nil {
  214. return fmt.Errorf("unmarshal response body failed: %w", err), false
  215. }
  216. if cozeResponse.Data.Status == "completed" {
  217. // 在上下文设置 usage
  218. c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
  219. c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
  220. c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
  221. return nil, true
  222. } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
  223. return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
  224. } else {
  225. return nil, false
  226. }
  227. }
  228. func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
  229. requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
  230. requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
  231. req, err := http.NewRequest("GET", requestURL, nil)
  232. if err != nil {
  233. return nil, fmt.Errorf("new request failed: %w", err)
  234. }
  235. err = a.SetupRequestHeader(c, &req.Header, info)
  236. if err != nil {
  237. return nil, fmt.Errorf("setup request header failed: %w", err)
  238. }
  239. resp, err := doRequest(req, info)
  240. if err != nil {
  241. return nil, fmt.Errorf("do request failed: %w", err)
  242. }
  243. return resp, nil
  244. }
  245. func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
  246. var client *http.Client
  247. var err error // 声明 err 变量
  248. if info.ChannelSetting.Proxy != "" {
  249. client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
  250. if err != nil {
  251. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  252. }
  253. } else {
  254. client = service.GetHttpClient()
  255. }
  256. resp, err := client.Do(req)
  257. if err != nil { // 增加对 client.Do(req) 返回错误的检查
  258. return nil, fmt.Errorf("client.Do failed: %w", err)
  259. }
  260. // _ = resp.Body.Close()
  261. return resp, nil
  262. }