// relay-gemini-native.go
  1. package gemini
  2. import (
  3. "github.com/pkg/errors"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. relaycommon "one-api/relay/common"
  9. "one-api/relay/helper"
  10. "one-api/service"
  11. "one-api/types"
  12. "strings"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  16. defer common.CloseResponseBodyGracefully(resp)
  17. // 读取响应体
  18. responseBody, err := io.ReadAll(resp.Body)
  19. if err != nil {
  20. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  21. }
  22. if common.DebugEnabled {
  23. println(string(responseBody))
  24. }
  25. // 解析为 Gemini 原生响应格式
  26. var geminiResponse GeminiChatResponse
  27. err = common.Unmarshal(responseBody, &geminiResponse)
  28. if err != nil {
  29. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  30. }
  31. // 计算使用量(基于 UsageMetadata)
  32. usage := dto.Usage{
  33. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  34. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
  35. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  36. }
  37. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  38. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  39. if detail.Modality == "AUDIO" {
  40. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  41. } else if detail.Modality == "TEXT" {
  42. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  43. }
  44. }
  45. // 直接返回 Gemini 原生格式的 JSON 响应
  46. jsonResponse, err := common.Marshal(geminiResponse)
  47. if err != nil {
  48. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  49. }
  50. common.IOCopyBytesGracefully(c, resp, jsonResponse)
  51. return &usage, nil
  52. }
  53. func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  54. var usage = &dto.Usage{}
  55. var imageCount int
  56. helper.SetEventStreamHeaders(c)
  57. responseText := strings.Builder{}
  58. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  59. var geminiResponse GeminiChatResponse
  60. err := common.UnmarshalJsonStr(data, &geminiResponse)
  61. if err != nil {
  62. common.LogError(c, "error unmarshalling stream response: "+err.Error())
  63. return false
  64. }
  65. // 统计图片数量
  66. for _, candidate := range geminiResponse.Candidates {
  67. for _, part := range candidate.Content.Parts {
  68. if part.InlineData != nil && part.InlineData.MimeType != "" {
  69. imageCount++
  70. }
  71. if part.Text != "" {
  72. responseText.WriteString(part.Text)
  73. }
  74. }
  75. }
  76. // 更新使用量统计
  77. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  78. usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
  79. usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
  80. usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
  81. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  82. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  83. if detail.Modality == "AUDIO" {
  84. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  85. } else if detail.Modality == "TEXT" {
  86. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  87. }
  88. }
  89. }
  90. // 直接发送 GeminiChatResponse 响应
  91. err = helper.StringData(c, data)
  92. if err != nil {
  93. common.LogError(c, err.Error())
  94. }
  95. info.SendResponseCount++
  96. return true
  97. })
  98. if info.SendResponseCount == 0 {
  99. return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
  100. }
  101. if imageCount != 0 {
  102. if usage.CompletionTokens == 0 {
  103. usage.CompletionTokens = imageCount * 258
  104. }
  105. }
  106. // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
  107. if usage.CompletionTokens == 0 {
  108. str := responseText.String()
  109. if len(str) > 0 {
  110. usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
  111. } else {
  112. // 空补全,不需要使用量
  113. usage = &dto.Usage{}
  114. }
  115. }
  116. // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
  117. //helper.Done(c)
  118. return usage, nil
  119. }