helper.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package openai
  2. import (
  3. "encoding/json"
  4. "one-api/common"
  5. "one-api/dto"
  6. relaycommon "one-api/relay/common"
  7. relayconstant "one-api/relay/constant"
  8. "one-api/relay/helper"
  9. "one-api/service"
  10. "strings"
  11. "github.com/gin-gonic/gin"
  12. )
  // Helper functions (stream formatting, token accumulation, final-chunk handling).
  14. func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  15. info.SendResponseCount++
  16. switch info.RelayFormat {
  17. case relaycommon.RelayFormatOpenAI:
  18. return sendStreamData(c, info, data, forceFormat, thinkToContent)
  19. case relaycommon.RelayFormatClaude:
  20. return handleClaudeFormat(c, data, info)
  21. }
  22. return nil
  23. }
  24. func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
  25. var streamResponse dto.ChatCompletionsStreamResponse
  26. if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
  27. return err
  28. }
  29. if streamResponse.Usage != nil {
  30. info.ClaudeConvertInfo.Usage = streamResponse.Usage
  31. }
  32. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  33. for _, resp := range claudeResponses {
  34. helper.ClaudeData(c, *resp)
  35. }
  36. return nil
  37. }
  38. func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
  39. var streamResponse dto.ChatCompletionsStreamResponse
  40. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  41. return err
  42. }
  43. for _, choice := range streamResponse.Choices {
  44. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  45. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  46. if choice.Delta.ToolCalls != nil {
  47. if len(choice.Delta.ToolCalls) > *toolCount {
  48. *toolCount = len(choice.Delta.ToolCalls)
  49. }
  50. for _, tool := range choice.Delta.ToolCalls {
  51. responseTextBuilder.WriteString(tool.Function.Name)
  52. responseTextBuilder.WriteString(tool.Function.Arguments)
  53. }
  54. }
  55. }
  56. return nil
  57. }
  58. func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  59. streamResp := "[" + strings.Join(streamItems, ",") + "]"
  60. switch relayMode {
  61. case relayconstant.RelayModeChatCompletions:
  62. return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
  63. case relayconstant.RelayModeCompletions:
  64. return processCompletions(streamResp, streamItems, responseTextBuilder)
  65. }
  66. return nil
  67. }
  68. func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  69. var streamResponses []dto.ChatCompletionsStreamResponse
  70. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  71. // 一次性解析失败,逐个解析
  72. common.SysError("error unmarshalling stream response: " + err.Error())
  73. for _, item := range streamItems {
  74. if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
  75. common.SysError("error processing stream response: " + err.Error())
  76. }
  77. }
  78. return nil
  79. }
  80. // 批量处理所有响应
  81. for _, streamResponse := range streamResponses {
  82. for _, choice := range streamResponse.Choices {
  83. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  84. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  85. if choice.Delta.ToolCalls != nil {
  86. if len(choice.Delta.ToolCalls) > *toolCount {
  87. *toolCount = len(choice.Delta.ToolCalls)
  88. }
  89. for _, tool := range choice.Delta.ToolCalls {
  90. responseTextBuilder.WriteString(tool.Function.Name)
  91. responseTextBuilder.WriteString(tool.Function.Arguments)
  92. }
  93. }
  94. }
  95. }
  96. return nil
  97. }
  98. func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
  99. var streamResponses []dto.CompletionsStreamResponse
  100. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  101. // 一次性解析失败,逐个解析
  102. common.SysError("error unmarshalling stream response: " + err.Error())
  103. for _, item := range streamItems {
  104. var streamResponse dto.CompletionsStreamResponse
  105. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  106. continue
  107. }
  108. for _, choice := range streamResponse.Choices {
  109. responseTextBuilder.WriteString(choice.Text)
  110. }
  111. }
  112. return nil
  113. }
  114. // 批量处理所有响应
  115. for _, streamResponse := range streamResponses {
  116. for _, choice := range streamResponse.Choices {
  117. responseTextBuilder.WriteString(choice.Text)
  118. }
  119. }
  120. return nil
  121. }
  122. func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
  123. systemFingerprint *string, model *string, usage **dto.Usage,
  124. containStreamUsage *bool, info *relaycommon.RelayInfo,
  125. shouldSendLastResp *bool) error {
  126. var lastStreamResponse dto.ChatCompletionsStreamResponse
  127. if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
  128. return err
  129. }
  130. *responseId = lastStreamResponse.Id
  131. *createAt = lastStreamResponse.Created
  132. *systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  133. *model = lastStreamResponse.Model
  134. if service.ValidUsage(lastStreamResponse.Usage) {
  135. *containStreamUsage = true
  136. *usage = lastStreamResponse.Usage
  137. if !info.ShouldIncludeUsage {
  138. *shouldSendLastResp = false
  139. }
  140. }
  141. return nil
  142. }
  143. func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
  144. responseId string, createAt int64, model string, systemFingerprint string,
  145. usage *dto.Usage, containStreamUsage bool) {
  146. switch info.RelayFormat {
  147. case relaycommon.RelayFormatOpenAI:
  148. if info.ShouldIncludeUsage && !containStreamUsage {
  149. response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
  150. response.SetSystemFingerprint(systemFingerprint)
  151. helper.ObjectData(c, response)
  152. }
  153. helper.Done(c)
  154. case relaycommon.RelayFormatClaude:
  155. info.ClaudeConvertInfo.Done = true
  156. var streamResponse dto.ChatCompletionsStreamResponse
  157. if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
  158. common.SysError("error unmarshalling stream response: " + err.Error())
  159. return
  160. }
  161. info.ClaudeConvertInfo.Usage = usage
  162. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  163. for _, resp := range claudeResponses {
  164. helper.ClaudeData(c, *resp)
  165. }
  166. }
  167. }