package gemini

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)

const (
	// GeminiFileAPIBaseURL is the scheme and host that file-operation paths
	// are appended to.
	GeminiFileAPIBaseURL = "https://generativelanguage.googleapis.com"
)

// BuildGeminiFileURL returns GeminiFileAPIBaseURL with path appended verbatim.
// No separator is inserted, so path is expected to begin with "/".
func BuildGeminiFileURL(path string) string {
	return GeminiFileAPIBaseURL + path
}

// ExtractAPIKey returns the API key supplied with the incoming request.
// The x-goog-api-key header takes precedence; when it is absent or empty the
// "key" query parameter is used. The result is "" when neither is present.
func ExtractAPIKey(c *gin.Context) string {
	if apiKey := c.GetHeader("x-goog-api-key"); apiKey != "" {
		return apiKey
	}
	return c.Query("key")
}

// ForwardGeminiFileRequest sends a request to url with the given method, body
// and headers, then relays the upstream status code, headers and body to the
// client through c.
//
// The upstream request is bound to the inbound request's context (when one is
// attached to c), so it is cancelled if the client disconnects. Failures are
// logged via logger.LogError and returned; an error from the body copy can
// occur after the status line and part of the body have already been written.
func ForwardGeminiFileRequest(c *gin.Context, method, url string, body io.Reader, headers map[string]string) error {
	ctx := context.Background()
	if c.Request != nil {
		ctx = c.Request.Context()
	}

	req, err := http.NewRequestWithContext(ctx, method, url, body)
	if err != nil {
		logger.LogError(c, fmt.Sprintf("failed to create request: %s", err.Error()))
		return err
	}
	for key, value := range headers {
		req.Header.Set(key, value)
	}

	// A zero-value Client uses http.DefaultTransport, so connections are still
	// pooled across calls. No overall timeout is set here; cancellation comes
	// from ctx.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		logger.LogError(c, fmt.Sprintf("failed to send request: %s", err.Error()))
		return err
	}
	defer resp.Body.Close()

	// Relay every value of every upstream header. Set (which c.Header uses)
	// would keep only the last value of a repeated header such as Set-Cookie,
	// so drop whatever was there and Add each upstream value instead.
	dst := c.Writer.Header()
	for key, values := range resp.Header {
		dst.Del(key)
		for _, value := range values {
			dst.Add(key, value)
		}
	}

	c.Status(resp.StatusCode)

	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
		logger.LogError(c, fmt.Sprintf("failed to stream response: %s", err.Error()))
		return err
	}
	return nil
}

// RebuildMultipartForm re-encodes form as a new multipart/form-data body for
// forwarding upstream. It returns a reader over the encoded body and the
// Content-Type value (including the new boundary) to send with it.
//
// The whole body, file contents included, is buffered in memory. A nil form
// yields an error. Map iteration order is unspecified, so the order of fields
// and files with different names is not stable between calls.
func RebuildMultipartForm(form *multipart.Form) (io.Reader, string, error) {
	if form == nil {
		return nil, "", errors.New("multipart form is nil")
	}

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	for key, values := range form.Value {
		for _, value := range values {
			if err := writer.WriteField(key, value); err != nil {
				return nil, "", err
			}
		}
	}

	for key, files := range form.File {
		for _, fileHeader := range files {
			if err := writeMultipartFile(writer, key, fileHeader); err != nil {
				return nil, "", err
			}
		}
	}

	contentType := writer.FormDataContentType()
	// Close writes the trailing boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, "", err
	}
	return body, contentType, nil
}

// writeMultipartFile copies one uploaded file into writer as a file part named
// fieldName. It lives in its own function so the deferred Close runs as soon
// as this file is copied rather than when the whole form has been rebuilt.
func writeMultipartFile(writer *multipart.Writer, fieldName string, fileHeader *multipart.FileHeader) error {
	file, err := fileHeader.Open()
	if err != nil {
		return err
	}
	defer file.Close()

	// NOTE(review): CreateFormFile labels the part application/octet-stream,
	// so the Content-Type of the original part is not forwarded — confirm the
	// upstream does not rely on it.
	part, err := writer.CreateFormFile(fieldName, fileHeader.Filename)
	if err != nil {
		return err
	}
	_, err = io.Copy(part, file)
	return err
}

// getAPIKeyFromToken returns the key to forward to Gemini, read from the
// "token_key" entry of the request context (expected to be populated by
// earlier middleware, together with "token_id").
//
// An access-denied error flagged to skip retry is returned when either entry
// is missing or when "token_key" is not a string. A leading "Bearer " prefix
// and surrounding whitespace are stripped from the key before it is returned.
func getAPIKeyFromToken(c *gin.Context) (string, error) {
	tokenID, exists := common.GetContextKey(c, "token_id")
	if !exists {
		return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
	}

	tokenKey, exists := common.GetContextKey(c, "token_key")
	if !exists {
		return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
	}

	// The token key is used directly; no lookup by token ID is performed.
	// The type assertion also fails for a nil value, so it covers both the
	// "stored as nil" and "stored as a non-string" cases.
	key, ok := tokenKey.(string)
	if !ok {
		return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
	}

	key = strings.TrimSpace(strings.TrimPrefix(key, "Bearer "))

	// Only the token ID is logged; the key itself is kept out of the log.
	logger.LogDebug(c, fmt.Sprintf("token_id: %v, using API key for Gemini", tokenID))
	return key, nil
}