package middleware

import (
	"fmt"
	"net/http"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/service"

	"github.com/gin-gonic/gin"
)

// geminiFileProbeModels lists, in priority order, the Gemini model names used
// to locate a channel for Gemini File API requests. The first model for which
// a channel is available in the caller's group wins.
var geminiFileProbeModels = []string{
	"gemini-2.0-flash",
	"gemini-1.5-flash",
	"gemini-1.5-pro",
	"gemini-2.0-flash-exp",
	"gemini-pro",
	"gemini-1.0-pro",
}

// GeminiFileAuth is a dedicated authentication middleware for the Gemini File API.
// It is deliberately isolated from the other authentication middlewares.
//
// On success it writes the cached user info into the context, records the
// effective group under constant.ContextKeyUsingGroup, selects a Gemini
// channel and sets up the channel context, stores the token fields used for
// quota tracking, and then calls c.Next().
//
// On any failure it responds with JSON of the form
// {"error": {"message": ..., "type": ..., "code": ...}} and aborts the chain.
func GeminiFileAuth() func(c *gin.Context) {
	return func(c *gin.Context) {
		apiKey := extractGeminiFileAPIKey(c)
		if apiKey == "" {
			abortGeminiFileAuth(c, http.StatusUnauthorized,
				"API key is required for Gemini File API", "authentication_error", "missing_api_key")
			return
		}

		// Accept an optional "sk-" prefix; only the part before the first '-'
		// is looked up as the token key.
		key := strings.TrimPrefix(apiKey, "sk-")
		key, _, _ = strings.Cut(key, "-")
		token, err := model.ValidateUserToken(key)
		if err != nil {
			abortGeminiFileAuth(c, http.StatusUnauthorized,
				fmt.Sprintf("Invalid API key: %s", err.Error()), "authentication_error", "invalid_api_key")
			return
		}

		userCache, err := model.GetUserCache(token.UserId)
		if err != nil {
			abortGeminiFileAuth(c, http.StatusInternalServerError,
				fmt.Sprintf("Failed to get user info: %s", err.Error()), "internal_error", "user_lookup_failed")
			return
		}
		if userCache.Status != common.UserStatusEnabled {
			abortGeminiFileAuth(c, http.StatusForbidden,
				"User account is disabled", "authentication_error", "account_disabled")
			return
		}
		userCache.WriteContext(c)

		// A token bound to a group overrides the user's own group, but only if
		// that group is among the groups the user is allowed to use.
		userGroup := userCache.Group
		if tokenGroup := token.Group; tokenGroup != "" {
			if _, ok := service.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
				abortGeminiFileAuth(c, http.StatusForbidden,
					fmt.Sprintf("No access to group: %s", tokenGroup), "authorization_error", "group_access_denied")
				return
			}
			userGroup = tokenGroup
		}
		common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)

		channel, modelName, err := findGeminiFileChannelWithModel(c, userGroup)
		if err != nil {
			abortGeminiFileAuth(c, http.StatusServiceUnavailable,
				fmt.Sprintf("No available Gemini channel: %s", err.Error()), "service_unavailable_error", "no_available_channel")
			return
		}

		// Set up the channel with the model name it was actually selected for,
		// rather than a fixed model the channel may not serve.
		if newAPIError := SetupContextForSelectedChannel(c, channel, modelName); newAPIError != nil {
			abortGeminiFileAuth(c, http.StatusServiceUnavailable,
				fmt.Sprintf("Failed to setup channel: %s", newAPIError.Error()), "service_unavailable_error", "channel_setup_failed")
			return
		}

		// Token context consumed by quota tracking further down the chain.
		c.Set("id", token.UserId)
		c.Set("token_id", token.Id)
		c.Set("token_key", token.Key)
		c.Set("token_name", token.Name)
		c.Set("token_unlimited_quota", token.UnlimitedQuota)
		if !token.UnlimitedQuota {
			c.Set("token_quota", token.RemainQuota)
		}

		c.Next()
	}
}

// abortGeminiFileAuth writes the Gemini File API error envelope with the given
// HTTP status and aborts the remaining handlers.
func abortGeminiFileAuth(c *gin.Context, status int, message, errType, code string) {
	c.JSON(status, gin.H{
		"error": gin.H{
			"message": message,
			"type":    errType,
			"code":    code,
		},
	})
	c.Abort()
}

// extractGeminiFileAPIKey returns the API key presented by the client, or ""
// if none is found. Sources are checked in this order:
//  1. Authorization header using the Bearer scheme (scheme matched
//     case-insensitively, per RFC 7235),
//  2. x-goog-api-key header (Gemini-specific),
//  3. x-api-key header (Claude-style),
//  4. "key" query parameter.
//
// Values are whitespace-trimmed. An empty value falls through to the next
// source instead of ending the search.
func extractGeminiFileAPIKey(c *gin.Context) string {
	const bearerPrefix = "Bearer "
	if auth := c.GetHeader("Authorization"); len(auth) >= len(bearerPrefix) &&
		strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) {
		if key := strings.TrimSpace(auth[len(bearerPrefix):]); key != "" {
			return key
		}
	}

	for _, header := range []string{"x-goog-api-key", "x-api-key"} {
		if key := strings.TrimSpace(c.GetHeader(header)); key != "" {
			return key
		}
	}

	return strings.TrimSpace(c.Query("key"))
}

// findGeminiFileChannel finds an available Gemini channel for file operations.
// It delegates to findGeminiFileChannelWithModel and drops the matched model name.
func findGeminiFileChannel(c *gin.Context, userGroup string) (*model.Channel, error) {
	channel, _, err := findGeminiFileChannelWithModel(c, userGroup)
	return channel, err
}

// findGeminiFileChannelWithModel probes geminiFileProbeModels in order. It
// returns the first channel available to userGroup together with the model
// name that matched it.
//
// If no channel is found, the returned error wraps the most recent non-nil
// lookup error, if any occurred.
func findGeminiFileChannelWithModel(c *gin.Context, userGroup string) (*model.Channel, string, error) {
	var lastErr error
	for _, modelName := range geminiFileProbeModels {
		channel, _, err := service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
			Ctx:        c,
			ModelName:  modelName,
			TokenGroup: userGroup,
			Retry:      common.GetPointer(0),
		})
		if err != nil {
			// Remember the failure but keep probing; a later model may match.
			// A later (nil, nil) result must not erase this error.
			lastErr = err
			continue
		}
		if channel != nil {
			logger.LogDebug(c, fmt.Sprintf("Found Gemini channel for file operations using model: %s", modelName))
			return channel, modelName, nil
		}
	}

	if lastErr != nil {
		return nil, "", fmt.Errorf("failed to find Gemini channel: %w", lastErr)
	}
	return nil, "", fmt.Errorf("no available Gemini channel found")
}