helper.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package openai
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "github.com/samber/lo"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. relayconstant "one-api/relay/constant"
  11. "one-api/relay/helper"
  12. "one-api/service"
  13. "strings"
  14. "github.com/gin-gonic/gin"
  15. )
  16. // 辅助函数
  17. func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  18. info.SendResponseCount++
  19. switch info.RelayFormat {
  20. case relaycommon.RelayFormatOpenAI:
  21. return sendStreamData(c, info, data, forceFormat, thinkToContent)
  22. case relaycommon.RelayFormatClaude:
  23. return handleClaudeFormat(c, data, info)
  24. case relaycommon.RelayFormatGemini:
  25. return handleGeminiFormat(c, data, info)
  26. }
  27. return nil
  28. }
  29. func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
  30. var streamResponse dto.ChatCompletionsStreamResponse
  31. if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
  32. return err
  33. }
  34. if streamResponse.Usage != nil {
  35. info.ClaudeConvertInfo.Usage = streamResponse.Usage
  36. }
  37. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  38. for _, resp := range claudeResponses {
  39. helper.ClaudeData(c, *resp)
  40. }
  41. return nil
  42. }
  43. func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
  44. var streamResponse dto.ChatCompletionsStreamResponse
  45. if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
  46. common.LogError(c, "failed to unmarshal stream response: "+err.Error())
  47. return err
  48. }
  49. geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
  50. // 如果返回 nil,表示没有实际内容,跳过发送
  51. if geminiResponse == nil {
  52. return nil
  53. }
  54. geminiResponseStr, err := common.Marshal(geminiResponse)
  55. if err != nil {
  56. common.LogError(c, "failed to marshal gemini response: "+err.Error())
  57. return err
  58. }
  59. // send gemini format response
  60. c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
  61. if flusher, ok := c.Writer.(http.Flusher); ok {
  62. flusher.Flush()
  63. } else {
  64. return errors.New("streaming error: flusher not found")
  65. }
  66. return nil
  67. }
  68. func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
  69. for _, choice := range streamResponse.Choices {
  70. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  71. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  72. if choice.Delta.ToolCalls != nil {
  73. if len(choice.Delta.ToolCalls) > *toolCount {
  74. *toolCount = len(choice.Delta.ToolCalls)
  75. }
  76. for _, tool := range choice.Delta.ToolCalls {
  77. responseTextBuilder.WriteString(tool.Function.Name)
  78. responseTextBuilder.WriteString(tool.Function.Arguments)
  79. }
  80. }
  81. }
  82. return nil
  83. }
  84. func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  85. streamResp := "[" + strings.Join(streamItems, ",") + "]"
  86. switch relayMode {
  87. case relayconstant.RelayModeChatCompletions:
  88. return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
  89. case relayconstant.RelayModeCompletions:
  90. return processCompletions(streamResp, streamItems, responseTextBuilder)
  91. }
  92. return nil
  93. }
  94. func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  95. var streamResponses []dto.ChatCompletionsStreamResponse
  96. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  97. // 一次性解析失败,逐个解析
  98. common.SysError("error unmarshalling stream response: " + err.Error())
  99. for _, item := range streamItems {
  100. var streamResponse dto.ChatCompletionsStreamResponse
  101. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  102. return err
  103. }
  104. if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
  105. common.SysError("error processing stream response: " + err.Error())
  106. }
  107. }
  108. return nil
  109. }
  110. // 批量处理所有响应
  111. for _, streamResponse := range streamResponses {
  112. for _, choice := range streamResponse.Choices {
  113. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  114. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  115. if choice.Delta.ToolCalls != nil {
  116. if len(choice.Delta.ToolCalls) > *toolCount {
  117. *toolCount = len(choice.Delta.ToolCalls)
  118. }
  119. for _, tool := range choice.Delta.ToolCalls {
  120. responseTextBuilder.WriteString(tool.Function.Name)
  121. responseTextBuilder.WriteString(tool.Function.Arguments)
  122. }
  123. }
  124. }
  125. }
  126. return nil
  127. }
  128. func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
  129. var streamResponses []dto.CompletionsStreamResponse
  130. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  131. // 一次性解析失败,逐个解析
  132. common.SysError("error unmarshalling stream response: " + err.Error())
  133. for _, item := range streamItems {
  134. var streamResponse dto.CompletionsStreamResponse
  135. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  136. continue
  137. }
  138. for _, choice := range streamResponse.Choices {
  139. responseTextBuilder.WriteString(choice.Text)
  140. }
  141. }
  142. return nil
  143. }
  144. // 批量处理所有响应
  145. for _, streamResponse := range streamResponses {
  146. for _, choice := range streamResponse.Choices {
  147. responseTextBuilder.WriteString(choice.Text)
  148. }
  149. }
  150. return nil
  151. }
  152. func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
  153. systemFingerprint *string, model *string, usage **dto.Usage,
  154. containStreamUsage *bool, info *relaycommon.RelayInfo,
  155. shouldSendLastResp *bool) error {
  156. var lastStreamResponse dto.ChatCompletionsStreamResponse
  157. if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
  158. return err
  159. }
  160. *responseId = lastStreamResponse.Id
  161. *createAt = lastStreamResponse.Created
  162. *systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  163. *model = lastStreamResponse.Model
  164. if service.ValidUsage(lastStreamResponse.Usage) {
  165. *containStreamUsage = true
  166. *usage = lastStreamResponse.Usage
  167. if !info.ShouldIncludeUsage {
  168. *shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool {
  169. return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
  170. })
  171. }
  172. }
  173. return nil
  174. }
  175. func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
  176. responseId string, createAt int64, model string, systemFingerprint string,
  177. usage *dto.Usage, containStreamUsage bool) {
  178. switch info.RelayFormat {
  179. case relaycommon.RelayFormatOpenAI:
  180. if info.ShouldIncludeUsage && !containStreamUsage {
  181. response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
  182. response.SetSystemFingerprint(systemFingerprint)
  183. helper.ObjectData(c, response)
  184. }
  185. helper.Done(c)
  186. case relaycommon.RelayFormatClaude:
  187. info.ClaudeConvertInfo.Done = true
  188. var streamResponse dto.ChatCompletionsStreamResponse
  189. if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
  190. common.SysError("error unmarshalling stream response: " + err.Error())
  191. return
  192. }
  193. info.ClaudeConvertInfo.Usage = usage
  194. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  195. for _, resp := range claudeResponses {
  196. _ = helper.ClaudeData(c, *resp)
  197. }
  198. case relaycommon.RelayFormatGemini:
  199. var streamResponse dto.ChatCompletionsStreamResponse
  200. if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
  201. common.SysError("error unmarshalling stream response: " + err.Error())
  202. return
  203. }
  204. // 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
  205. // 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
  206. // 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
  207. // 暂不知是否有程序会不兼容。
  208. geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
  209. // openai 流响应开头的空数据
  210. if geminiResponse == nil {
  211. return
  212. }
  213. geminiResponseStr, err := common.Marshal(geminiResponse)
  214. if err != nil {
  215. common.SysError("error marshalling gemini response: " + err.Error())
  216. return
  217. }
  218. // 发送最终的 Gemini 响应
  219. c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
  220. if flusher, ok := c.Writer.(http.Flusher); ok {
  221. flusher.Flush()
  222. }
  223. }
  224. }
  225. func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
  226. if data == "" {
  227. return
  228. }
  229. helper.ResponseChunkData(c, streamResponse, data)
  230. }