// relay-gemini-native.go

  1. package gemini
import (
	"encoding/json"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"

	"one-api/common"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
)
  13. func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
  14. // 读取响应体
  15. responseBody, err := io.ReadAll(resp.Body)
  16. if err != nil {
  17. return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  18. }
  19. err = resp.Body.Close()
  20. if err != nil {
  21. return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
  22. }
  23. if common.DebugEnabled {
  24. println(string(responseBody))
  25. }
  26. // 解析为 Gemini 原生响应格式
  27. var geminiResponse GeminiChatResponse
  28. err = common.DecodeJson(responseBody, &geminiResponse)
  29. if err != nil {
  30. return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
  31. }
  32. // 检查是否有候选响应
  33. if len(geminiResponse.Candidates) == 0 {
  34. return nil, &dto.OpenAIErrorWithStatusCode{
  35. Error: dto.OpenAIError{
  36. Message: "No candidates returned",
  37. Type: "server_error",
  38. Param: "",
  39. Code: 500,
  40. },
  41. StatusCode: resp.StatusCode,
  42. }
  43. }
  44. // 计算使用量(基于 UsageMetadata)
  45. usage := dto.Usage{
  46. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  47. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
  48. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  49. }
  50. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  51. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  52. if detail.Modality == "AUDIO" {
  53. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  54. } else if detail.Modality == "TEXT" {
  55. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  56. }
  57. }
  58. // 直接返回 Gemini 原生格式的 JSON 响应
  59. jsonResponse, err := json.Marshal(geminiResponse)
  60. if err != nil {
  61. return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
  62. }
  63. // 设置响应头并写入响应
  64. c.Writer.Header().Set("Content-Type", "application/json")
  65. c.Writer.WriteHeader(resp.StatusCode)
  66. _, err = c.Writer.Write(jsonResponse)
  67. if err != nil {
  68. return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
  69. }
  70. return &usage, nil
  71. }
  72. func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
  73. var usage = &dto.Usage{}
  74. var imageCount int
  75. helper.SetEventStreamHeaders(c)
  76. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  77. var geminiResponse GeminiChatResponse
  78. err := common.DecodeJsonStr(data, &geminiResponse)
  79. if err != nil {
  80. common.LogError(c, "error unmarshalling stream response: "+err.Error())
  81. return false
  82. }
  83. // 统计图片数量
  84. for _, candidate := range geminiResponse.Candidates {
  85. for _, part := range candidate.Content.Parts {
  86. if part.InlineData != nil && part.InlineData.MimeType != "" {
  87. imageCount++
  88. }
  89. }
  90. }
  91. // 更新使用量统计
  92. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  93. usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
  94. usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
  95. usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
  96. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  97. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  98. if detail.Modality == "AUDIO" {
  99. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  100. } else if detail.Modality == "TEXT" {
  101. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  102. }
  103. }
  104. }
  105. // 直接发送 GeminiChatResponse 响应
  106. err = helper.ObjectData(c, geminiResponse)
  107. if err != nil {
  108. common.LogError(c, err.Error())
  109. }
  110. return true
  111. })
  112. if imageCount != 0 {
  113. if usage.CompletionTokens == 0 {
  114. usage.CompletionTokens = imageCount * 258
  115. }
  116. }
  117. // 计算最终使用量
  118. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  119. // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
  120. //helper.Done(c)
  121. return usage, nil
  122. }