package gemini

import (
	"fmt"
	"io"
	"net/http"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"

	"github.com/gin-gonic/gin"
)

// GeminiTextGenerationHandler processes a non-streaming Gemini response in its
// native format.
//
// It reads the whole upstream body, decodes it as a dto.GeminiChatResponse,
// derives token usage from the response's UsageMetadata, and then hands the
// original, unmodified bytes to service.IOCopyBytesGracefully together with c
// and resp.
//
// A read or decode failure yields a nil usage and an error built with
// types.NewOpenAIError (ErrorCodeBadResponseBody, HTTP 500). The info
// parameter is not used by this function.
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	// Read the full response body.
	raw, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	if common.DebugEnabled {
		println(string(raw))
	}

	// Decode into the Gemini native response shape.
	var parsed dto.GeminiChatResponse
	if decodeErr := common.Unmarshal(raw, &parsed); decodeErr != nil {
		return nil, types.NewOpenAIError(decodeErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	// When there are no candidates and a block reason is present, store that
	// reason in the request context under ContextKeyAdminRejectReason.
	if feedback := parsed.PromptFeedback; len(parsed.Candidates) == 0 && feedback != nil && feedback.BlockReason != nil {
		reason := fmt.Sprintf("gemini_block_reason=%s", *feedback.BlockReason)
		common.SetContextKey(c, constant.ContextKeyAdminRejectReason, reason)
	}

	// Compute usage from UsageMetadata. Thought tokens are added to the
	// completion count and also reported separately as reasoning tokens.
	meta := parsed.UsageMetadata

	var usage dto.Usage
	usage.PromptTokens = meta.PromptTokenCount
	usage.CompletionTokens = meta.CandidatesTokenCount + meta.ThoughtsTokenCount
	usage.TotalTokens = meta.TotalTokenCount
	usage.CompletionTokenDetails.ReasoningTokens = meta.ThoughtsTokenCount
	usage.PromptTokensDetails.CachedTokens = meta.CachedContentTokenCount

	// Only the AUDIO and TEXT modalities are copied into the prompt details;
	// entries with any other modality are skipped.
	for _, entry := range meta.PromptTokensDetails {
		switch entry.Modality {
		case "AUDIO":
			usage.PromptTokensDetails.AudioTokens = entry.TokenCount
		case "TEXT":
			usage.PromptTokensDetails.TextTokens = entry.TokenCount
		}
	}

	service.IOCopyBytesGracefully(c, resp, raw)

	return &usage, nil
}

// NativeGeminiEmbeddingHandler processes a Gemini embedding response in its
// native format.
//
// Usage comes from service.ResponseText2Usage, called with an empty response
// text, info.UpstreamModelName and info.GetEstimatePromptTokens(). The body is
// decoded as dto.GeminiBatchEmbeddingResponse when info.IsGeminiBatchEmbedding
// is set and as dto.GeminiEmbeddingResponse otherwise. The decoded value is
// only used to check that the payload parses; the original bytes are what get
// handed to service.IOCopyBytesGracefully.
//
// A read or decode failure yields a nil usage and an error built with
// types.NewOpenAIError (ErrorCodeBadResponseBody, HTTP 500).
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	raw, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	if common.DebugEnabled {
		println(string(raw))
	}

	usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())

	// Pick the decode target by request kind; both paths share one error check.
	var decodeErr error
	if info.IsGeminiBatchEmbedding {
		var batch dto.GeminiBatchEmbeddingResponse
		decodeErr = common.Unmarshal(raw, &batch)
	} else {
		var single dto.GeminiEmbeddingResponse
		decodeErr = common.Unmarshal(raw, &single)
	}
	if decodeErr != nil {
		return nil, types.NewOpenAIError(decodeErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	service.IOCopyBytesGracefully(c, resp, raw)

	return usage, nil
}

// GeminiTextGenerationStreamHandler processes a streaming Gemini response in
// its native format.
//
// It sets the event-stream headers on c and delegates to geminiStreamHandler
// with a callback that writes each data string through helper.StringData. The
// callback returns false after logging when a write fails; otherwise it
// increments info.SendResponseCount and returns true.
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	helper.SetEventStreamHeaders(c)

	// The decoded response argument is not needed here: the raw data string
	// is written out as-is.
	forward := func(data string, _ *dto.GeminiChatResponse) bool {
		if writeErr := helper.StringData(c, data); writeErr != nil {
			logger.LogError(c, "failed to write stream data: "+writeErr.Error())
			return false
		}
		info.SendResponseCount++
		return true
	}

	return geminiStreamHandler(c, info, resp, forward)
}