// relay-gemini-native.go
  1. package gemini
  2. import (
  3. "encoding/json"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. relaycommon "one-api/relay/common"
  9. "one-api/relay/helper"
  10. "one-api/service"
  11. "strings"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
  15. // 读取响应体
  16. responseBody, err := io.ReadAll(resp.Body)
  17. if err != nil {
  18. return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  19. }
  20. common.CloseResponseBodyGracefully(resp)
  21. if common.DebugEnabled {
  22. println(string(responseBody))
  23. }
  24. // 解析为 Gemini 原生响应格式
  25. var geminiResponse GeminiChatResponse
  26. err = common.DecodeJson(responseBody, &geminiResponse)
  27. if err != nil {
  28. return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
  29. }
  30. // 计算使用量(基于 UsageMetadata)
  31. usage := dto.Usage{
  32. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  33. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
  34. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  35. }
  36. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  37. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  38. if detail.Modality == "AUDIO" {
  39. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  40. } else if detail.Modality == "TEXT" {
  41. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  42. }
  43. }
  44. // 直接返回 Gemini 原生格式的 JSON 响应
  45. jsonResponse, err := json.Marshal(geminiResponse)
  46. if err != nil {
  47. return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
  48. }
  49. // 设置响应头并写入响应
  50. c.Writer.Header().Set("Content-Type", "application/json")
  51. c.Writer.WriteHeader(resp.StatusCode)
  52. _, err = c.Writer.Write(jsonResponse)
  53. if err != nil {
  54. return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
  55. }
  56. return &usage, nil
  57. }
  58. func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
  59. var usage = &dto.Usage{}
  60. var imageCount int
  61. helper.SetEventStreamHeaders(c)
  62. responseText := strings.Builder{}
  63. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  64. var geminiResponse GeminiChatResponse
  65. err := common.DecodeJsonStr(data, &geminiResponse)
  66. if err != nil {
  67. common.LogError(c, "error unmarshalling stream response: "+err.Error())
  68. return false
  69. }
  70. // 统计图片数量
  71. for _, candidate := range geminiResponse.Candidates {
  72. for _, part := range candidate.Content.Parts {
  73. if part.InlineData != nil && part.InlineData.MimeType != "" {
  74. imageCount++
  75. }
  76. if part.Text != "" {
  77. responseText.WriteString(part.Text)
  78. }
  79. }
  80. }
  81. // 更新使用量统计
  82. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  83. usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
  84. usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
  85. usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
  86. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  87. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  88. if detail.Modality == "AUDIO" {
  89. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  90. } else if detail.Modality == "TEXT" {
  91. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  92. }
  93. }
  94. }
  95. // 直接发送 GeminiChatResponse 响应
  96. err = helper.StringData(c, data)
  97. if err != nil {
  98. common.LogError(c, err.Error())
  99. }
  100. return true
  101. })
  102. if imageCount != 0 {
  103. if usage.CompletionTokens == 0 {
  104. usage.CompletionTokens = imageCount * 258
  105. }
  106. }
  107. // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
  108. if usage.CompletionTokens == 0 {
  109. str := responseText.String()
  110. if len(str) > 0 {
  111. usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
  112. } else {
  113. // 空补全,不需要使用量
  114. usage = &dto.Usage{}
  115. }
  116. }
  117. // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
  118. //helper.Done(c)
  119. return usage, nil
  120. }