relay-gemini-native.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package gemini
  2. import (
  3. "io"
  4. "net/http"
  5. "one-api/common"
  6. "one-api/dto"
  7. relaycommon "one-api/relay/common"
  8. "one-api/relay/helper"
  9. "one-api/service"
  10. "one-api/types"
  11. "strings"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  15. defer common.CloseResponseBodyGracefully(resp)
  16. // 读取响应体
  17. responseBody, err := io.ReadAll(resp.Body)
  18. if err != nil {
  19. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  20. }
  21. if common.DebugEnabled {
  22. println(string(responseBody))
  23. }
  24. // 解析为 Gemini 原生响应格式
  25. var geminiResponse GeminiChatResponse
  26. err = common.Unmarshal(responseBody, &geminiResponse)
  27. if err != nil {
  28. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  29. }
  30. // 计算使用量(基于 UsageMetadata)
  31. usage := dto.Usage{
  32. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  33. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
  34. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  35. }
  36. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  37. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  38. if detail.Modality == "AUDIO" {
  39. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  40. } else if detail.Modality == "TEXT" {
  41. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  42. }
  43. }
  44. // 直接返回 Gemini 原生格式的 JSON 响应
  45. jsonResponse, err := common.Marshal(geminiResponse)
  46. if err != nil {
  47. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  48. }
  49. common.IOCopyBytesGracefully(c, resp, jsonResponse)
  50. return &usage, nil
  51. }
  52. func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  53. var usage = &dto.Usage{}
  54. var imageCount int
  55. helper.SetEventStreamHeaders(c)
  56. responseText := strings.Builder{}
  57. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  58. var geminiResponse GeminiChatResponse
  59. err := common.UnmarshalJsonStr(data, &geminiResponse)
  60. if err != nil {
  61. common.LogError(c, "error unmarshalling stream response: "+err.Error())
  62. return false
  63. }
  64. // 统计图片数量
  65. for _, candidate := range geminiResponse.Candidates {
  66. for _, part := range candidate.Content.Parts {
  67. if part.InlineData != nil && part.InlineData.MimeType != "" {
  68. imageCount++
  69. }
  70. if part.Text != "" {
  71. responseText.WriteString(part.Text)
  72. }
  73. }
  74. }
  75. // 更新使用量统计
  76. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  77. usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
  78. usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
  79. usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
  80. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  81. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  82. if detail.Modality == "AUDIO" {
  83. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  84. } else if detail.Modality == "TEXT" {
  85. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  86. }
  87. }
  88. }
  89. // 直接发送 GeminiChatResponse 响应
  90. err = helper.StringData(c, data)
  91. if err != nil {
  92. common.LogError(c, err.Error())
  93. }
  94. return true
  95. })
  96. if imageCount != 0 {
  97. if usage.CompletionTokens == 0 {
  98. usage.CompletionTokens = imageCount * 258
  99. }
  100. }
  101. // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
  102. if usage.CompletionTokens == 0 {
  103. str := responseText.String()
  104. if len(str) > 0 {
  105. usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
  106. } else {
  107. // 空补全,不需要使用量
  108. usage = &dto.Usage{}
  109. }
  110. }
  111. // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
  112. //helper.Done(c)
  113. return usage, nil
  114. }