// relay-gemini-native.go
  1. package gemini
  2. import (
  3. "io"
  4. "net/http"
  5. "one-api/common"
  6. "one-api/dto"
  7. relaycommon "one-api/relay/common"
  8. "one-api/relay/helper"
  9. "one-api/service"
  10. "one-api/types"
  11. "strings"
  12. "github.com/pkg/errors"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  16. defer common.CloseResponseBodyGracefully(resp)
  17. // 读取响应体
  18. responseBody, err := io.ReadAll(resp.Body)
  19. if err != nil {
  20. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  21. }
  22. if common.DebugEnabled {
  23. println(string(responseBody))
  24. }
  25. // 解析为 Gemini 原生响应格式
  26. var geminiResponse dto.GeminiChatResponse
  27. err = common.Unmarshal(responseBody, &geminiResponse)
  28. if err != nil {
  29. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  30. }
  31. // 计算使用量(基于 UsageMetadata)
  32. usage := dto.Usage{
  33. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  34. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
  35. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  36. }
  37. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  38. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  39. if detail.Modality == "AUDIO" {
  40. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  41. } else if detail.Modality == "TEXT" {
  42. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  43. }
  44. }
  45. // 直接返回 Gemini 原生格式的 JSON 响应
  46. jsonResponse, err := common.Marshal(geminiResponse)
  47. if err != nil {
  48. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  49. }
  50. common.IOCopyBytesGracefully(c, resp, jsonResponse)
  51. return &usage, nil
  52. }
  53. func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
  54. defer common.CloseResponseBodyGracefully(resp)
  55. responseBody, err := io.ReadAll(resp.Body)
  56. if err != nil {
  57. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  58. }
  59. if common.DebugEnabled {
  60. println(string(responseBody))
  61. }
  62. usage := &dto.Usage{
  63. PromptTokens: info.PromptTokens,
  64. TotalTokens: info.PromptTokens,
  65. }
  66. if info.IsGeminiBatchEmbedding {
  67. var geminiResponse dto.GeminiBatchEmbeddingResponse
  68. err = common.Unmarshal(responseBody, &geminiResponse)
  69. if err != nil {
  70. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  71. }
  72. } else {
  73. var geminiResponse dto.GeminiEmbeddingResponse
  74. err = common.Unmarshal(responseBody, &geminiResponse)
  75. if err != nil {
  76. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  77. }
  78. }
  79. common.IOCopyBytesGracefully(c, resp, responseBody)
  80. return usage, nil
  81. }
// GeminiTextGenerationStreamHandler relays a streaming Gemini native
// generateContent (SSE) response to the client chunk by chunk, while
// accumulating token usage from the UsageMetadata embedded in the stream.
//
// If upstream never reports usage, completion tokens are estimated locally
// from the concatenated text (or from a per-image constant for image output).
// Returns an ErrorCodeEmptyResponse error when no chunks were received.
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	var usage = &dto.Usage{}
	var imageCount int
	helper.SetEventStreamHeaders(c)
	responseText := strings.Builder{}
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		var geminiResponse dto.GeminiChatResponse
		err := common.UnmarshalJsonStr(data, &geminiResponse)
		if err != nil {
			common.LogError(c, "error unmarshalling stream response: "+err.Error())
			// Abort the stream on a malformed chunk.
			return false
		}
		// Count inline images and collect text for the local usage fallback
		// applied after the stream ends.
		for _, candidate := range geminiResponse.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.InlineData != nil && part.InlineData.MimeType != "" {
					imageCount++
				}
				if part.Text != "" {
					responseText.WriteString(part.Text)
				}
			}
		}
		// Update usage whenever a chunk carries usage metadata; each update
		// overwrites the previous one, so the final chunk's totals win.
		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
				if detail.Modality == "AUDIO" {
					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
				} else if detail.Modality == "TEXT" {
					usage.PromptTokensDetails.TextTokens = detail.TokenCount
				}
			}
		}
		// Forward the raw chunk to the client unchanged (native passthrough).
		err = helper.StringData(c, data)
		if err != nil {
			// Best-effort delivery: log the send failure but keep streaming.
			common.LogError(c, err.Error())
		}
		info.SendResponseCount++
		return true
	})
	if info.SendResponseCount == 0 {
		return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
	}
	if imageCount != 0 {
		if usage.CompletionTokens == 0 {
			// 258 appears to be a per-image output token cost —
			// NOTE(review): confirm against current Gemini billing docs.
			usage.CompletionTokens = imageCount * 258
		}
	}
	// If upstream reported no completion tokens, estimate them locally from
	// the accumulated response text.
	if usage.CompletionTokens == 0 {
		str := responseText.String()
		if len(str) > 0 {
			usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
		} else {
			// Empty completion: nothing to bill.
			usage = &dto.Usage{}
		}
	}
	// Intentionally no trailing [DONE] event: the Gemini API never sends one.
	//helper.Done(c)
	return usage, nil
}