|
@@ -0,0 +1,204 @@
|
|
|
|
|
+package middleware
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "fmt"
|
|
|
|
|
+ "net/http"
|
|
|
|
|
+ "strings"
|
|
|
|
|
+
|
|
|
|
|
+ "github.com/QuantumNous/new-api/common"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/constant"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/logger"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/model"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/service"
|
|
|
|
|
+
|
|
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+// GeminiFileAuth is a dedicated authentication middleware for Gemini File API
|
|
|
|
|
+// This is completely isolated from other authentication logic
|
|
|
|
|
+func GeminiFileAuth() func(c *gin.Context) {
|
|
|
|
|
+ return func(c *gin.Context) {
|
|
|
|
|
+ // Extract API key from multiple sources
|
|
|
|
|
+ apiKey := extractGeminiFileAPIKey(c)
|
|
|
|
|
+ if apiKey == "" {
|
|
|
|
|
+ c.JSON(http.StatusUnauthorized, gin.H{
|
|
|
|
|
+ "error": gin.H{
|
|
|
|
|
+ "message": "API key is required for Gemini File API",
|
|
|
|
|
+ "type": "authentication_error",
|
|
|
|
|
+ "code": "missing_api_key",
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Validate token
|
|
|
|
|
+ key := strings.TrimPrefix(apiKey, "sk-")
|
|
|
|
|
+ parts := strings.Split(key, "-")
|
|
|
|
|
+ key = parts[0]
|
|
|
|
|
+
|
|
|
|
|
+ token, err := model.ValidateUserToken(key)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ c.JSON(http.StatusUnauthorized, gin.H{
|
|
|
|
|
+ "error": gin.H{
|
|
|
|
|
+ "message": fmt.Sprintf("Invalid API key: %s", err.Error()),
|
|
|
|
|
+ "type": "authentication_error",
|
|
|
|
|
+ "code": "invalid_api_key",
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check user status
|
|
|
|
|
+ userCache, err := model.GetUserCache(token.UserId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{
|
|
|
|
|
+ "error": gin.H{
|
|
|
|
|
+ "message": fmt.Sprintf("Failed to get user info: %s", err.Error()),
|
|
|
|
|
+ "type": "internal_error",
|
|
|
|
|
+ "code": "user_lookup_failed",
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if userCache.Status != common.UserStatusEnabled {
|
|
|
|
|
+ c.JSON(http.StatusForbidden, gin.H{
|
|
|
|
|
+ "error": gin.H{
|
|
|
|
|
+ "message": "User account is disabled",
|
|
|
|
|
+ "type": "authentication_error",
|
|
|
|
|
+ "code": "account_disabled",
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Set user context
|
|
|
|
|
+ userCache.WriteContext(c)
|
|
|
|
|
+
|
|
|
|
|
+ // Get user group
|
|
|
|
|
+ userGroup := userCache.Group
|
|
|
|
|
+ tokenGroup := token.Group
|
|
|
|
|
+ if tokenGroup != "" {
|
|
|
|
|
+ // Check if user has access to this group
|
|
|
|
|
+ if _, ok := service.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
|
|
|
|
+ c.JSON(http.StatusForbidden, gin.H{
|
|
|
|
|
+ "error": gin.H{
|
|
|
|
|
+ "message": fmt.Sprintf("No access to group: %s", tokenGroup),
|
|
|
|
|
+ "type": "authorization_error",
|
|
|
|
|
+ "code": "group_access_denied",
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ userGroup = tokenGroup
|
|
|
|
|
+ }
|
|
|
|
|
+ common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
|
|
|
|
+
|
|
|
|
|
+ // Find an available Gemini channel for file operations
|
|
|
|
|
+ channel, err := findGeminiFileChannel(c, userGroup)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ c.JSON(http.StatusServiceUnavailable, gin.H{
|
|
|
|
|
+ "error": gin.H{
|
|
|
|
|
+ "message": fmt.Sprintf("No available Gemini channel: %s", err.Error()),
|
|
|
|
|
+ "type": "service_unavailable_error",
|
|
|
|
|
+ "code": "no_available_channel",
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Setup channel context
|
|
|
|
|
+ newAPIError := SetupContextForSelectedChannel(c, channel, "gemini-2.0-flash")
|
|
|
|
|
+ if newAPIError != nil {
|
|
|
|
|
+ c.JSON(http.StatusServiceUnavailable, gin.H{
|
|
|
|
|
+ "error": gin.H{
|
|
|
|
|
+ "message": fmt.Sprintf("Failed to setup channel: %s", newAPIError.Error()),
|
|
|
|
|
+ "type": "service_unavailable_error",
|
|
|
|
|
+ "code": "channel_setup_failed",
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Set token context for quota tracking
|
|
|
|
|
+ c.Set("id", token.UserId)
|
|
|
|
|
+ c.Set("token_id", token.Id)
|
|
|
|
|
+ c.Set("token_key", token.Key)
|
|
|
|
|
+ c.Set("token_name", token.Name)
|
|
|
|
|
+ c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
|
|
|
|
+ if !token.UnlimitedQuota {
|
|
|
|
|
+ c.Set("token_quota", token.RemainQuota)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ c.Next()
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// extractGeminiFileAPIKey extracts API key from various sources
|
|
|
|
|
+func extractGeminiFileAPIKey(c *gin.Context) string {
|
|
|
|
|
+ // 1. Check Authorization header
|
|
|
|
|
+ auth := c.GetHeader("Authorization")
|
|
|
|
|
+ if auth != "" {
|
|
|
|
|
+ if strings.HasPrefix(auth, "Bearer ") || strings.HasPrefix(auth, "bearer ") {
|
|
|
|
|
+ return strings.TrimSpace(auth[7:])
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 2. Check x-goog-api-key header (Gemini-specific)
|
|
|
|
|
+ if key := c.GetHeader("x-goog-api-key"); key != "" {
|
|
|
|
|
+ return key
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 3. Check x-api-key header (Claude-style)
|
|
|
|
|
+ if key := c.GetHeader("x-api-key"); key != "" {
|
|
|
|
|
+ return key
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 4. Check query parameter
|
|
|
|
|
+ if key := c.Query("key"); key != "" {
|
|
|
|
|
+ return key
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return ""
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// findGeminiFileChannel finds an available Gemini channel for file operations
|
|
|
|
|
+func findGeminiFileChannel(c *gin.Context, userGroup string) (*model.Channel, error) {
|
|
|
|
|
+ // Try multiple common Gemini models to find an available channel
|
|
|
|
|
+ geminiModels := []string{
|
|
|
|
|
+ "gemini-2.0-flash",
|
|
|
|
|
+ "gemini-1.5-flash",
|
|
|
|
|
+ "gemini-1.5-pro",
|
|
|
|
|
+ "gemini-2.0-flash-exp",
|
|
|
|
|
+ "gemini-pro",
|
|
|
|
|
+ "gemini-1.0-pro",
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ var lastError error
|
|
|
|
|
+ for _, modelName := range geminiModels {
|
|
|
|
|
+ channel, _, err := service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
|
|
|
|
|
+ Ctx: c,
|
|
|
|
|
+ ModelName: modelName,
|
|
|
|
|
+ TokenGroup: userGroup,
|
|
|
|
|
+ Retry: common.GetPointer(0),
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ if err == nil && channel != nil {
|
|
|
|
|
+ logger.LogDebug(c, fmt.Sprintf("Found Gemini channel for file operations using model: %s", modelName))
|
|
|
|
|
+ return channel, nil
|
|
|
|
|
+ }
|
|
|
|
|
+ lastError = err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if lastError != nil {
|
|
|
|
|
+ return nil, fmt.Errorf("failed to find Gemini channel: %w", lastError)
|
|
|
|
|
+ }
|
|
|
|
|
+ return nil, fmt.Errorf("no available Gemini channel found")
|
|
|
|
|
+}
|