| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- package middleware
- import (
- "fmt"
- "net/http"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/service"
- "github.com/gin-gonic/gin"
- )
- // GeminiFileAuth is a dedicated authentication middleware for Gemini File API
- // This is completely isolated from other authentication logic
- func GeminiFileAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- // Extract API key from multiple sources
- apiKey := extractGeminiFileAPIKey(c)
- if apiKey == "" {
- c.JSON(http.StatusUnauthorized, gin.H{
- "error": gin.H{
- "message": "API key is required for Gemini File API",
- "type": "authentication_error",
- "code": "missing_api_key",
- },
- })
- c.Abort()
- return
- }
- // Validate token
- key := strings.TrimPrefix(apiKey, "sk-")
- parts := strings.Split(key, "-")
- key = parts[0]
- token, err := model.ValidateUserToken(key)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("Invalid API key: %s", err.Error()),
- "type": "authentication_error",
- "code": "invalid_api_key",
- },
- })
- c.Abort()
- return
- }
- // Check user status
- userCache, err := model.GetUserCache(token.UserId)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("Failed to get user info: %s", err.Error()),
- "type": "internal_error",
- "code": "user_lookup_failed",
- },
- })
- c.Abort()
- return
- }
- if userCache.Status != common.UserStatusEnabled {
- c.JSON(http.StatusForbidden, gin.H{
- "error": gin.H{
- "message": "User account is disabled",
- "type": "authentication_error",
- "code": "account_disabled",
- },
- })
- c.Abort()
- return
- }
- // Set user context
- userCache.WriteContext(c)
- // Get user group
- userGroup := userCache.Group
- tokenGroup := token.Group
- if tokenGroup != "" {
- // Check if user has access to this group
- if _, ok := service.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
- c.JSON(http.StatusForbidden, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("No access to group: %s", tokenGroup),
- "type": "authorization_error",
- "code": "group_access_denied",
- },
- })
- c.Abort()
- return
- }
- userGroup = tokenGroup
- }
- common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
- // Find an available Gemini channel for file operations
- channel, err := findGeminiFileChannel(c, userGroup)
- if err != nil {
- c.JSON(http.StatusServiceUnavailable, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("No available Gemini channel: %s", err.Error()),
- "type": "service_unavailable_error",
- "code": "no_available_channel",
- },
- })
- c.Abort()
- return
- }
- // Setup channel context
- newAPIError := SetupContextForSelectedChannel(c, channel, "gemini-2.0-flash")
- if newAPIError != nil {
- c.JSON(http.StatusServiceUnavailable, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("Failed to setup channel: %s", newAPIError.Error()),
- "type": "service_unavailable_error",
- "code": "channel_setup_failed",
- },
- })
- c.Abort()
- return
- }
- // Set token context for quota tracking
- c.Set("id", token.UserId)
- c.Set("token_id", token.Id)
- c.Set("token_key", token.Key)
- c.Set("token_name", token.Name)
- c.Set("token_unlimited_quota", token.UnlimitedQuota)
- if !token.UnlimitedQuota {
- c.Set("token_quota", token.RemainQuota)
- }
- c.Next()
- }
- }
- // extractGeminiFileAPIKey extracts API key from various sources
- func extractGeminiFileAPIKey(c *gin.Context) string {
- // 1. Check Authorization header
- auth := c.GetHeader("Authorization")
- if auth != "" {
- if strings.HasPrefix(auth, "Bearer ") || strings.HasPrefix(auth, "bearer ") {
- return strings.TrimSpace(auth[7:])
- }
- }
- // 2. Check x-goog-api-key header (Gemini-specific)
- if key := c.GetHeader("x-goog-api-key"); key != "" {
- return key
- }
- // 3. Check x-api-key header (Claude-style)
- if key := c.GetHeader("x-api-key"); key != "" {
- return key
- }
- // 4. Check query parameter
- if key := c.Query("key"); key != "" {
- return key
- }
- return ""
- }
- // findGeminiFileChannel finds an available Gemini channel for file operations
- func findGeminiFileChannel(c *gin.Context, userGroup string) (*model.Channel, error) {
- // Try multiple common Gemini models to find an available channel
- geminiModels := []string{
- "gemini-2.0-flash",
- "gemini-1.5-flash",
- "gemini-1.5-pro",
- "gemini-2.0-flash-exp",
- "gemini-pro",
- "gemini-1.0-pro",
- }
- var lastError error
- for _, modelName := range geminiModels {
- channel, _, err := service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
- Ctx: c,
- ModelName: modelName,
- TokenGroup: userGroup,
- Retry: common.GetPointer(0),
- })
- if err == nil && channel != nil {
- logger.LogDebug(c, fmt.Sprintf("Found Gemini channel for file operations using model: %s", modelName))
- return channel, nil
- }
- lastError = err
- }
- if lastError != nil {
- return nil, fmt.Errorf("failed to find Gemini channel: %w", lastError)
- }
- return nil, fmt.Errorf("no available Gemini channel found")
- }
|