helper.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package openai
  2. import (
  3. "encoding/json"
  4. "one-api/common"
  5. "one-api/dto"
  6. relaycommon "one-api/relay/common"
  7. relayconstant "one-api/relay/constant"
  8. "one-api/relay/helper"
  9. "one-api/service"
  10. "strings"
  11. "github.com/gin-gonic/gin"
  12. )
  13. // 辅助函数
  14. func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  15. info.SendResponseCount++
  16. switch info.RelayFormat {
  17. case relaycommon.RelayFormatOpenAI:
  18. return sendStreamData(c, info, data, forceFormat, thinkToContent)
  19. case relaycommon.RelayFormatClaude:
  20. return handleClaudeFormat(c, data, info)
  21. }
  22. return nil
  23. }
  24. func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
  25. var streamResponse dto.ChatCompletionsStreamResponse
  26. if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
  27. return err
  28. }
  29. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  30. for _, resp := range claudeResponses {
  31. helper.ClaudeData(c, *resp)
  32. }
  33. return nil
  34. }
  35. func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
  36. var streamResponse dto.ChatCompletionsStreamResponse
  37. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  38. return err
  39. }
  40. for _, choice := range streamResponse.Choices {
  41. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  42. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  43. if choice.Delta.ToolCalls != nil {
  44. if len(choice.Delta.ToolCalls) > *toolCount {
  45. *toolCount = len(choice.Delta.ToolCalls)
  46. }
  47. for _, tool := range choice.Delta.ToolCalls {
  48. responseTextBuilder.WriteString(tool.Function.Name)
  49. responseTextBuilder.WriteString(tool.Function.Arguments)
  50. }
  51. }
  52. }
  53. return nil
  54. }
  55. func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  56. streamResp := "[" + strings.Join(streamItems, ",") + "]"
  57. switch relayMode {
  58. case relayconstant.RelayModeChatCompletions:
  59. return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
  60. case relayconstant.RelayModeCompletions:
  61. return processCompletions(streamResp, streamItems, responseTextBuilder)
  62. }
  63. return nil
  64. }
  65. func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  66. var streamResponses []dto.ChatCompletionsStreamResponse
  67. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  68. // 一次性解析失败,逐个解析
  69. common.SysError("error unmarshalling stream response: " + err.Error())
  70. for _, item := range streamItems {
  71. if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
  72. common.SysError("error processing stream response: " + err.Error())
  73. }
  74. }
  75. return nil
  76. }
  77. // 批量处理所有响应
  78. for _, streamResponse := range streamResponses {
  79. for _, choice := range streamResponse.Choices {
  80. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  81. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  82. if choice.Delta.ToolCalls != nil {
  83. if len(choice.Delta.ToolCalls) > *toolCount {
  84. *toolCount = len(choice.Delta.ToolCalls)
  85. }
  86. for _, tool := range choice.Delta.ToolCalls {
  87. responseTextBuilder.WriteString(tool.Function.Name)
  88. responseTextBuilder.WriteString(tool.Function.Arguments)
  89. }
  90. }
  91. }
  92. }
  93. return nil
  94. }
  95. func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
  96. var streamResponses []dto.CompletionsStreamResponse
  97. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  98. // 一次性解析失败,逐个解析
  99. common.SysError("error unmarshalling stream response: " + err.Error())
  100. for _, item := range streamItems {
  101. var streamResponse dto.CompletionsStreamResponse
  102. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  103. continue
  104. }
  105. for _, choice := range streamResponse.Choices {
  106. responseTextBuilder.WriteString(choice.Text)
  107. }
  108. }
  109. return nil
  110. }
  111. // 批量处理所有响应
  112. for _, streamResponse := range streamResponses {
  113. for _, choice := range streamResponse.Choices {
  114. responseTextBuilder.WriteString(choice.Text)
  115. }
  116. }
  117. return nil
  118. }
  119. func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
  120. systemFingerprint *string, model *string, usage **dto.Usage,
  121. containStreamUsage *bool, info *relaycommon.RelayInfo,
  122. shouldSendLastResp *bool) error {
  123. var lastStreamResponse dto.ChatCompletionsStreamResponse
  124. if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
  125. return err
  126. }
  127. *responseId = lastStreamResponse.Id
  128. *createAt = lastStreamResponse.Created
  129. *systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  130. *model = lastStreamResponse.Model
  131. if service.ValidUsage(lastStreamResponse.Usage) {
  132. *containStreamUsage = true
  133. *usage = lastStreamResponse.Usage
  134. if !info.ShouldIncludeUsage {
  135. *shouldSendLastResp = false
  136. }
  137. }
  138. return nil
  139. }
  140. func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
  141. responseId string, createAt int64, model string, systemFingerprint string,
  142. usage *dto.Usage, containStreamUsage bool) {
  143. switch info.RelayFormat {
  144. case relaycommon.RelayFormatOpenAI:
  145. if info.ShouldIncludeUsage && !containStreamUsage {
  146. response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
  147. response.SetSystemFingerprint(systemFingerprint)
  148. helper.ObjectData(c, response)
  149. }
  150. helper.Done(c)
  151. case relaycommon.RelayFormatClaude:
  152. var streamResponse dto.ChatCompletionsStreamResponse
  153. if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
  154. common.SysError("error unmarshalling stream response: " + err.Error())
  155. return
  156. }
  157. if !containStreamUsage {
  158. streamResponse.Usage = usage
  159. }
  160. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  161. for _, resp := range claudeResponses {
  162. helper.ClaudeData(c, *resp)
  163. }
  164. }
  165. }