relay-coze.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. package coze
import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"

	"github.com/gin-gonic/gin"
	"github.com/samber/lo"
)
  19. func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
  20. var messages []CozeEnterMessage
  21. // 将 request的messages的role为user的content转换为CozeMessage
  22. for _, message := range request.Messages {
  23. if message.Role == "user" {
  24. messages = append(messages, CozeEnterMessage{
  25. Role: "user",
  26. Content: message.Content,
  27. // TODO: support more content type
  28. ContentType: "text",
  29. })
  30. }
  31. }
  32. user := request.User
  33. if user == "" {
  34. user = helper.GetResponseID(c)
  35. }
  36. cozeRequest := &CozeChatRequest{
  37. BotId: c.GetString("bot_id"),
  38. UserId: user,
  39. AdditionalMessages: messages,
  40. Stream: lo.FromPtrOr(request.Stream, false),
  41. }
  42. return cozeRequest
  43. }
  44. func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  45. responseBody, err := io.ReadAll(resp.Body)
  46. if err != nil {
  47. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  48. }
  49. service.CloseResponseBodyGracefully(resp)
  50. // convert coze response to openai response
  51. var response dto.TextResponse
  52. var cozeResponse CozeChatDetailResponse
  53. response.Model = info.UpstreamModelName
  54. err = json.Unmarshal(responseBody, &cozeResponse)
  55. if err != nil {
  56. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  57. }
  58. if cozeResponse.Code != 0 {
  59. return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
  60. }
  61. // 从上下文获取 usage
  62. var usage dto.Usage
  63. usage.PromptTokens = c.GetInt("coze_input_count")
  64. usage.CompletionTokens = c.GetInt("coze_output_count")
  65. usage.TotalTokens = c.GetInt("coze_token_count")
  66. response.Usage = usage
  67. response.Id = helper.GetResponseID(c)
  68. var responseContent json.RawMessage
  69. for _, data := range cozeResponse.Data {
  70. if data.Type == "answer" {
  71. responseContent = data.Content
  72. response.Created = data.CreatedAt
  73. }
  74. }
  75. // 添加 response.Choices
  76. response.Choices = []dto.OpenAITextResponseChoice{
  77. {
  78. Index: 0,
  79. Message: dto.Message{Role: "assistant", Content: responseContent},
  80. FinishReason: "stop",
  81. },
  82. }
  83. jsonResponse, err := json.Marshal(response)
  84. if err != nil {
  85. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  86. }
  87. c.Writer.Header().Set("Content-Type", "application/json")
  88. c.Writer.WriteHeader(resp.StatusCode)
  89. _, _ = c.Writer.Write(jsonResponse)
  90. return &usage, nil
  91. }
  92. func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  93. scanner := bufio.NewScanner(resp.Body)
  94. scanner.Split(bufio.ScanLines)
  95. helper.SetEventStreamHeaders(c)
  96. id := helper.GetResponseID(c)
  97. var responseText string
  98. var currentEvent string
  99. var currentData string
  100. var usage = &dto.Usage{}
  101. for scanner.Scan() {
  102. line := scanner.Text()
  103. if line == "" {
  104. if currentEvent != "" && currentData != "" {
  105. // handle last event
  106. handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
  107. currentEvent = ""
  108. currentData = ""
  109. }
  110. continue
  111. }
  112. if strings.HasPrefix(line, "event:") {
  113. currentEvent = strings.TrimSpace(line[6:])
  114. continue
  115. }
  116. if strings.HasPrefix(line, "data:") {
  117. currentData = strings.TrimSpace(line[5:])
  118. continue
  119. }
  120. }
  121. // Last event
  122. if currentEvent != "" && currentData != "" {
  123. handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
  124. }
  125. if err := scanner.Err(); err != nil {
  126. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  127. }
  128. helper.Done(c)
  129. if usage.TotalTokens == 0 {
  130. usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
  131. }
  132. return usage, nil
  133. }
  134. func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
  135. switch event {
  136. case "conversation.chat.completed":
  137. // 将 data 解析为 CozeChatResponseData
  138. var chatData CozeChatResponseData
  139. err := json.Unmarshal([]byte(data), &chatData)
  140. if err != nil {
  141. common.SysLog("error_unmarshalling_stream_response: " + err.Error())
  142. return
  143. }
  144. usage.PromptTokens = chatData.Usage.InputCount
  145. usage.CompletionTokens = chatData.Usage.OutputCount
  146. usage.TotalTokens = chatData.Usage.TokenCount
  147. finishReason := "stop"
  148. stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
  149. helper.ObjectData(c, stopResponse)
  150. case "conversation.message.delta":
  151. // 将 data 解析为 CozeChatV3MessageDetail
  152. var messageData CozeChatV3MessageDetail
  153. err := json.Unmarshal([]byte(data), &messageData)
  154. if err != nil {
  155. common.SysLog("error_unmarshalling_stream_response: " + err.Error())
  156. return
  157. }
  158. var content string
  159. err = json.Unmarshal(messageData.Content, &content)
  160. if err != nil {
  161. common.SysLog("error_unmarshalling_stream_response: " + err.Error())
  162. return
  163. }
  164. *responseText += content
  165. openaiResponse := dto.ChatCompletionsStreamResponse{
  166. Id: id,
  167. Object: "chat.completion.chunk",
  168. Created: common.GetTimestamp(),
  169. Model: info.UpstreamModelName,
  170. }
  171. choice := dto.ChatCompletionsStreamResponseChoice{
  172. Index: 0,
  173. }
  174. choice.Delta.SetContentString(content)
  175. openaiResponse.Choices = append(openaiResponse.Choices, choice)
  176. helper.ObjectData(c, openaiResponse)
  177. case "error":
  178. var errorData CozeError
  179. err := json.Unmarshal([]byte(data), &errorData)
  180. if err != nil {
  181. common.SysLog("error_unmarshalling_stream_response: " + err.Error())
  182. return
  183. }
  184. common.SysLog(fmt.Sprintf("stream event error: %v %v", errorData.Code, errorData.Message))
  185. }
  186. }
  187. func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
  188. requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl)
  189. requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
  190. // 将 conversationId和chatId作为参数发送get请求
  191. req, err := http.NewRequest("GET", requestURL, nil)
  192. if err != nil {
  193. return err, false
  194. }
  195. err = a.SetupRequestHeader(c, &req.Header, info)
  196. if err != nil {
  197. return err, false
  198. }
  199. resp, err := doRequest(req, info) // 调用 doRequest
  200. if err != nil {
  201. return err, false
  202. }
  203. if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
  204. return fmt.Errorf("resp is nil"), false
  205. }
  206. defer resp.Body.Close() // 确保响应体被关闭
  207. // 解析 resp 到 CozeChatResponse
  208. var cozeResponse CozeChatResponse
  209. responseBody, err := io.ReadAll(resp.Body)
  210. if err != nil {
  211. return fmt.Errorf("read response body failed: %w", err), false
  212. }
  213. err = json.Unmarshal(responseBody, &cozeResponse)
  214. if err != nil {
  215. return fmt.Errorf("unmarshal response body failed: %w", err), false
  216. }
  217. if cozeResponse.Data.Status == "completed" {
  218. // 在上下文设置 usage
  219. c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
  220. c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
  221. c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
  222. return nil, true
  223. } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
  224. return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
  225. } else {
  226. return nil, false
  227. }
  228. }
  229. func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
  230. requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl)
  231. requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
  232. req, err := http.NewRequest("GET", requestURL, nil)
  233. if err != nil {
  234. return nil, fmt.Errorf("new request failed: %w", err)
  235. }
  236. err = a.SetupRequestHeader(c, &req.Header, info)
  237. if err != nil {
  238. return nil, fmt.Errorf("setup request header failed: %w", err)
  239. }
  240. resp, err := doRequest(req, info)
  241. if err != nil {
  242. return nil, fmt.Errorf("do request failed: %w", err)
  243. }
  244. return resp, nil
  245. }
  246. func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
  247. var client *http.Client
  248. var err error // 声明 err 变量
  249. if info.ChannelSetting.Proxy != "" {
  250. client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
  251. if err != nil {
  252. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  253. }
  254. } else {
  255. client = service.GetHttpClient()
  256. }
  257. resp, err := client.Do(req)
  258. if err != nil { // 增加对 client.Do(req) 返回错误的检查
  259. return nil, fmt.Errorf("client.Do failed: %w", err)
  260. }
  261. // _ = resp.Body.Close()
  262. return resp, nil
  263. }