| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- package gemini
- import (
- "bytes"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- )
const (
	// GeminiFileAPIBaseURL is the scheme and host that Gemini File API paths
	// are appended to. It deliberately has no trailing slash.
	GeminiFileAPIBaseURL = "https://generativelanguage.googleapis.com"
)

// BuildGeminiFileURL returns the absolute upstream URL for a file operation
// by appending path to GeminiFileAPIBaseURL.
//
// A non-empty path that does not start with "/" gets one inserted. Without
// it the text would be glued straight onto the host name (for example
// "googleapis.com" + ".evil.example/x"), changing the authority of the
// resulting URL. An empty path yields the bare base URL.
func BuildGeminiFileURL(path string) string {
	if path != "" && !strings.HasPrefix(path, "/") {
		path = "/" + path
	}
	return GeminiFileAPIBaseURL + path
}
- // ExtractAPIKey gets API key from header or query param
- // Gemini supports both x-goog-api-key header and key query parameter
- func ExtractAPIKey(c *gin.Context) string {
- // First check header
- apiKey := c.GetHeader("x-goog-api-key")
- if apiKey != "" {
- return apiKey
- }
- // Then check query parameter
- return c.Query("key")
- }
- // ForwardGeminiFileRequest sends request to Gemini and streams response
- func ForwardGeminiFileRequest(c *gin.Context, method, url string, body io.Reader, headers map[string]string) error {
- req, err := http.NewRequest(method, url, body)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("failed to create request: %s", err.Error()))
- return err
- }
- // Copy headers
- for key, value := range headers {
- req.Header.Set(key, value)
- }
- // Send request
- client := &http.Client{}
- resp, err := client.Do(req)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("failed to send request: %s", err.Error()))
- return err
- }
- defer resp.Body.Close()
- // Copy response headers
- for key, values := range resp.Header {
- for _, value := range values {
- c.Header(key, value)
- }
- }
- // Set status code
- c.Status(resp.StatusCode)
- // Stream response body
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("failed to stream response: %s", err.Error()))
- return err
- }
- return nil
- }
// multipartQuoteEscaper escapes characters that would terminate or break a
// quoted Content-Disposition parameter value.
var multipartQuoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"", "\r", "%0D", "\n", "%0A")

// RebuildMultipartForm re-encodes a parsed multipart form so it can be sent
// to an upstream server.
//
// It returns a reader over the encoded body and the matching Content-Type
// header value (including the boundary). The whole form, file contents
// included, is buffered in memory.
//
// Every value of every field and every file is written. Each file part keeps
// its field name, filename and the Content-Type recorded in its header; a
// file without a Content-Type is sent as "application/octet-stream".
//
// An error is returned if form is nil, if a file cannot be opened, or if
// writing any part fails.
func RebuildMultipartForm(form *multipart.Form) (io.Reader, string, error) {
	if form == nil {
		return nil, "", fmt.Errorf("multipart form is nil")
	}

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	for key, values := range form.Value {
		for _, value := range values {
			if err := writer.WriteField(key, value); err != nil {
				return nil, "", err
			}
		}
	}

	for key, files := range form.File {
		for _, fileHeader := range files {
			if err := copyMultipartFile(writer, key, fileHeader); err != nil {
				return nil, "", err
			}
		}
	}

	contentType := writer.FormDataContentType()
	// Close writes the terminating boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, "", err
	}
	return body, contentType, nil
}

// copyMultipartFile writes one uploaded file to writer as a form-data part
// named key. It is a separate function so the source file is closed as soon
// as this part is written, not when the whole form has been rebuilt.
func copyMultipartFile(writer *multipart.Writer, key string, fileHeader *multipart.FileHeader) error {
	file, err := fileHeader.Open()
	if err != nil {
		return err
	}
	defer file.Close()

	// Forward the part's own media type so upstream sees what the client sent.
	partType := fileHeader.Header.Get("Content-Type")
	if partType == "" {
		partType = "application/octet-stream"
	}
	disposition := fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
		multipartQuoteEscaper.Replace(key),
		multipartQuoteEscaper.Replace(fileHeader.Filename))

	part, err := writer.CreatePart(map[string][]string{
		"Content-Disposition": {disposition},
		"Content-Type":        {partType},
	})
	if err != nil {
		return err
	}
	_, err = io.Copy(part, file)
	return err
}
- // getAPIKeyFromToken extracts the Gemini API key from the token
- // This function retrieves the actual API key that should be forwarded to Gemini
- func getAPIKeyFromToken(c *gin.Context) (string, error) {
- // Get the token info from context (set by middleware)
- tokenId, exists := common.GetContextKey(c, "token_id")
- if !exists {
- return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
- }
- // Get the token key from context
- tokenKey, exists := common.GetContextKey(c, "token_key")
- if !exists {
- return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
- }
- // For now, we'll use the token key directly
- // In a production system, you might want to look up the actual Gemini API key
- // from the database using the token ID
- if tokenKey == nil {
- return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
- }
- key, ok := tokenKey.(string)
- if !ok {
- return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
- }
- // Remove "Bearer " prefix if present
- key = strings.TrimPrefix(key, "Bearer ")
- key = strings.TrimSpace(key)
- logger.LogDebug(c, fmt.Sprintf("token_id: %v, using API key for Gemini", tokenId))
- return key, nil
- }
|