gemini_file_auth.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. package middleware
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strings"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/constant"
  8. "github.com/QuantumNous/new-api/logger"
  9. "github.com/QuantumNous/new-api/model"
  10. "github.com/QuantumNous/new-api/service"
  11. "github.com/gin-gonic/gin"
  12. )
  13. // GeminiFileAuth is a dedicated authentication middleware for Gemini File API
  14. // This is completely isolated from other authentication logic
  15. func GeminiFileAuth() func(c *gin.Context) {
  16. return func(c *gin.Context) {
  17. // Extract API key from multiple sources
  18. apiKey := extractGeminiFileAPIKey(c)
  19. if apiKey == "" {
  20. c.JSON(http.StatusUnauthorized, gin.H{
  21. "error": gin.H{
  22. "message": "API key is required for Gemini File API",
  23. "type": "authentication_error",
  24. "code": "missing_api_key",
  25. },
  26. })
  27. c.Abort()
  28. return
  29. }
  30. // Validate token
  31. key := strings.TrimPrefix(apiKey, "sk-")
  32. parts := strings.Split(key, "-")
  33. key = parts[0]
  34. token, err := model.ValidateUserToken(key)
  35. if err != nil {
  36. c.JSON(http.StatusUnauthorized, gin.H{
  37. "error": gin.H{
  38. "message": fmt.Sprintf("Invalid API key: %s", err.Error()),
  39. "type": "authentication_error",
  40. "code": "invalid_api_key",
  41. },
  42. })
  43. c.Abort()
  44. return
  45. }
  46. // Check user status
  47. userCache, err := model.GetUserCache(token.UserId)
  48. if err != nil {
  49. c.JSON(http.StatusInternalServerError, gin.H{
  50. "error": gin.H{
  51. "message": fmt.Sprintf("Failed to get user info: %s", err.Error()),
  52. "type": "internal_error",
  53. "code": "user_lookup_failed",
  54. },
  55. })
  56. c.Abort()
  57. return
  58. }
  59. if userCache.Status != common.UserStatusEnabled {
  60. c.JSON(http.StatusForbidden, gin.H{
  61. "error": gin.H{
  62. "message": "User account is disabled",
  63. "type": "authentication_error",
  64. "code": "account_disabled",
  65. },
  66. })
  67. c.Abort()
  68. return
  69. }
  70. // Set user context
  71. userCache.WriteContext(c)
  72. // Get user group
  73. userGroup := userCache.Group
  74. tokenGroup := token.Group
  75. if tokenGroup != "" {
  76. // Check if user has access to this group
  77. if _, ok := service.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
  78. c.JSON(http.StatusForbidden, gin.H{
  79. "error": gin.H{
  80. "message": fmt.Sprintf("No access to group: %s", tokenGroup),
  81. "type": "authorization_error",
  82. "code": "group_access_denied",
  83. },
  84. })
  85. c.Abort()
  86. return
  87. }
  88. userGroup = tokenGroup
  89. }
  90. common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
  91. // Find an available Gemini channel for file operations
  92. channel, err := findGeminiFileChannel(c, userGroup)
  93. if err != nil {
  94. c.JSON(http.StatusServiceUnavailable, gin.H{
  95. "error": gin.H{
  96. "message": fmt.Sprintf("No available Gemini channel: %s", err.Error()),
  97. "type": "service_unavailable_error",
  98. "code": "no_available_channel",
  99. },
  100. })
  101. c.Abort()
  102. return
  103. }
  104. // Setup channel context
  105. newAPIError := SetupContextForSelectedChannel(c, channel, "gemini-2.0-flash")
  106. if newAPIError != nil {
  107. c.JSON(http.StatusServiceUnavailable, gin.H{
  108. "error": gin.H{
  109. "message": fmt.Sprintf("Failed to setup channel: %s", newAPIError.Error()),
  110. "type": "service_unavailable_error",
  111. "code": "channel_setup_failed",
  112. },
  113. })
  114. c.Abort()
  115. return
  116. }
  117. // Set token context for quota tracking
  118. c.Set("id", token.UserId)
  119. c.Set("token_id", token.Id)
  120. c.Set("token_key", token.Key)
  121. c.Set("token_name", token.Name)
  122. c.Set("token_unlimited_quota", token.UnlimitedQuota)
  123. if !token.UnlimitedQuota {
  124. c.Set("token_quota", token.RemainQuota)
  125. }
  126. c.Next()
  127. }
  128. }
  129. // extractGeminiFileAPIKey extracts API key from various sources
  130. func extractGeminiFileAPIKey(c *gin.Context) string {
  131. // 1. Check Authorization header
  132. auth := c.GetHeader("Authorization")
  133. if auth != "" {
  134. if strings.HasPrefix(auth, "Bearer ") || strings.HasPrefix(auth, "bearer ") {
  135. return strings.TrimSpace(auth[7:])
  136. }
  137. }
  138. // 2. Check x-goog-api-key header (Gemini-specific)
  139. if key := c.GetHeader("x-goog-api-key"); key != "" {
  140. return key
  141. }
  142. // 3. Check x-api-key header (Claude-style)
  143. if key := c.GetHeader("x-api-key"); key != "" {
  144. return key
  145. }
  146. // 4. Check query parameter
  147. if key := c.Query("key"); key != "" {
  148. return key
  149. }
  150. return ""
  151. }
  152. // findGeminiFileChannel finds an available Gemini channel for file operations
  153. func findGeminiFileChannel(c *gin.Context, userGroup string) (*model.Channel, error) {
  154. // Try multiple common Gemini models to find an available channel
  155. geminiModels := []string{
  156. "gemini-2.0-flash",
  157. "gemini-1.5-flash",
  158. "gemini-1.5-pro",
  159. "gemini-2.0-flash-exp",
  160. "gemini-pro",
  161. "gemini-1.0-pro",
  162. }
  163. var lastError error
  164. for _, modelName := range geminiModels {
  165. channel, _, err := service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
  166. Ctx: c,
  167. ModelName: modelName,
  168. TokenGroup: userGroup,
  169. Retry: common.GetPointer(0),
  170. })
  171. if err == nil && channel != nil {
  172. logger.LogDebug(c, fmt.Sprintf("Found Gemini channel for file operations using model: %s", modelName))
  173. return channel, nil
  174. }
  175. lastError = err
  176. }
  177. if lastError != nil {
  178. return nil, fmt.Errorf("failed to find Gemini channel: %w", lastError)
  179. }
  180. return nil, fmt.Errorf("no available Gemini channel found")
  181. }