// relay-gemini-native.go — Gemini native-format relay handlers.
  1. package gemini
  2. import (
  3. "encoding/json"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. relaycommon "one-api/relay/common"
  9. "one-api/relay/helper"
  10. "one-api/service"
  11. "strings"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
  15. // 读取响应体
  16. responseBody, err := io.ReadAll(resp.Body)
  17. if err != nil {
  18. return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  19. }
  20. err = resp.Body.Close()
  21. if err != nil {
  22. return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
  23. }
  24. if common.DebugEnabled {
  25. println(string(responseBody))
  26. }
  27. // 解析为 Gemini 原生响应格式
  28. var geminiResponse GeminiChatResponse
  29. err = common.DecodeJson(responseBody, &geminiResponse)
  30. if err != nil {
  31. return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
  32. }
  33. // 计算使用量(基于 UsageMetadata)
  34. usage := dto.Usage{
  35. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  36. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
  37. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  38. }
  39. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  40. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  41. if detail.Modality == "AUDIO" {
  42. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  43. } else if detail.Modality == "TEXT" {
  44. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  45. }
  46. }
  47. // 直接返回 Gemini 原生格式的 JSON 响应
  48. jsonResponse, err := json.Marshal(geminiResponse)
  49. if err != nil {
  50. return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
  51. }
  52. // 设置响应头并写入响应
  53. c.Writer.Header().Set("Content-Type", "application/json")
  54. c.Writer.WriteHeader(resp.StatusCode)
  55. _, err = c.Writer.Write(jsonResponse)
  56. if err != nil {
  57. return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
  58. }
  59. return &usage, nil
  60. }
  61. func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
  62. var usage = &dto.Usage{}
  63. var imageCount int
  64. helper.SetEventStreamHeaders(c)
  65. responseText := strings.Builder{}
  66. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  67. var geminiResponse GeminiChatResponse
  68. err := common.DecodeJsonStr(data, &geminiResponse)
  69. if err != nil {
  70. common.LogError(c, "error unmarshalling stream response: "+err.Error())
  71. return false
  72. }
  73. // 统计图片数量
  74. for _, candidate := range geminiResponse.Candidates {
  75. for _, part := range candidate.Content.Parts {
  76. if part.InlineData != nil && part.InlineData.MimeType != "" {
  77. imageCount++
  78. }
  79. if part.Text != "" {
  80. responseText.WriteString(part.Text)
  81. }
  82. }
  83. }
  84. // 更新使用量统计
  85. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  86. usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
  87. usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
  88. usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
  89. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  90. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  91. if detail.Modality == "AUDIO" {
  92. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  93. } else if detail.Modality == "TEXT" {
  94. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  95. }
  96. }
  97. }
  98. // 直接发送 GeminiChatResponse 响应
  99. err = helper.StringData(c, data)
  100. if err != nil {
  101. common.LogError(c, err.Error())
  102. }
  103. return true
  104. })
  105. if imageCount != 0 {
  106. if usage.CompletionTokens == 0 {
  107. usage.CompletionTokens = imageCount * 258
  108. }
  109. }
  110. // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
  111. if usage.CompletionTokens == 0 {
  112. str := responseText.String()
  113. if len(str) > 0 {
  114. usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
  115. } else {
  116. // 空补全,不需要使用量
  117. usage = &dto.Usage{}
  118. }
  119. }
  120. // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
  121. //helper.Done(c)
  122. return usage, nil
  123. }