file_helper.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. package gemini
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "mime/multipart"
  7. "net/http"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/logger"
  11. "github.com/QuantumNous/new-api/types"
  12. "github.com/gin-gonic/gin"
  13. )
  14. const (
  15. GeminiFileAPIBaseURL = "https://generativelanguage.googleapis.com"
  16. )
  17. // BuildGeminiFileURL constructs the full URL for file operations
  18. func BuildGeminiFileURL(path string) string {
  19. return GeminiFileAPIBaseURL + path
  20. }
  21. // ExtractAPIKey gets API key from header or query param
  22. // Gemini supports both x-goog-api-key header and key query parameter
  23. func ExtractAPIKey(c *gin.Context) string {
  24. // First check header
  25. apiKey := c.GetHeader("x-goog-api-key")
  26. if apiKey != "" {
  27. return apiKey
  28. }
  29. // Then check query parameter
  30. return c.Query("key")
  31. }
  32. // ForwardGeminiFileRequest sends request to Gemini and streams response
  33. func ForwardGeminiFileRequest(c *gin.Context, method, url string, body io.Reader, headers map[string]string) error {
  34. req, err := http.NewRequest(method, url, body)
  35. if err != nil {
  36. logger.LogError(c, fmt.Sprintf("failed to create request: %s", err.Error()))
  37. return err
  38. }
  39. // Copy headers
  40. for key, value := range headers {
  41. req.Header.Set(key, value)
  42. }
  43. // Send request
  44. client := &http.Client{}
  45. resp, err := client.Do(req)
  46. if err != nil {
  47. logger.LogError(c, fmt.Sprintf("failed to send request: %s", err.Error()))
  48. return err
  49. }
  50. defer resp.Body.Close()
  51. // Copy response headers
  52. for key, values := range resp.Header {
  53. for _, value := range values {
  54. c.Header(key, value)
  55. }
  56. }
  57. // Set status code
  58. c.Status(resp.StatusCode)
  59. // Stream response body
  60. _, err = io.Copy(c.Writer, resp.Body)
  61. if err != nil {
  62. logger.LogError(c, fmt.Sprintf("failed to stream response: %s", err.Error()))
  63. return err
  64. }
  65. return nil
  66. }
  67. // RebuildMultipartForm rebuilds a multipart form for forwarding to upstream
  68. func RebuildMultipartForm(form *multipart.Form) (io.Reader, string, error) {
  69. body := &bytes.Buffer{}
  70. writer := multipart.NewWriter(body)
  71. // Add form fields
  72. for key, values := range form.Value {
  73. for _, value := range values {
  74. if err := writer.WriteField(key, value); err != nil {
  75. return nil, "", err
  76. }
  77. }
  78. }
  79. // Add files
  80. for key, files := range form.File {
  81. for _, fileHeader := range files {
  82. file, err := fileHeader.Open()
  83. if err != nil {
  84. return nil, "", err
  85. }
  86. defer file.Close()
  87. part, err := writer.CreateFormFile(key, fileHeader.Filename)
  88. if err != nil {
  89. return nil, "", err
  90. }
  91. if _, err := io.Copy(part, file); err != nil {
  92. return nil, "", err
  93. }
  94. }
  95. }
  96. contentType := writer.FormDataContentType()
  97. if err := writer.Close(); err != nil {
  98. return nil, "", err
  99. }
  100. return body, contentType, nil
  101. }
  102. // getAPIKeyFromToken extracts the Gemini API key from the token
  103. // This function retrieves the actual API key that should be forwarded to Gemini
  104. func getAPIKeyFromToken(c *gin.Context) (string, error) {
  105. // Get the token info from context (set by middleware)
  106. tokenId, exists := common.GetContextKey(c, "token_id")
  107. if !exists {
  108. return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
  109. }
  110. // Get the token key from context
  111. tokenKey, exists := common.GetContextKey(c, "token_key")
  112. if !exists {
  113. return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
  114. }
  115. // For now, we'll use the token key directly
  116. // In a production system, you might want to look up the actual Gemini API key
  117. // from the database using the token ID
  118. if tokenKey == nil {
  119. return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
  120. }
  121. key, ok := tokenKey.(string)
  122. if !ok {
  123. return "", types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
  124. }
  125. // Remove "Bearer " prefix if present
  126. key = strings.TrimPrefix(key, "Bearer ")
  127. key = strings.TrimSpace(key)
  128. logger.LogDebug(c, fmt.Sprintf("token_id: %v, using API key for Gemini", tokenId))
  129. return key, nil
  130. }