Przeglądaj źródła

fix(gemini): fetch model list via native v1beta/models endpoint

Use the native Gemini Models API (/v1beta/models) instead of the OpenAI-compatible
path when listing models for Gemini channels, improving compatibility with
third-party Gemini-format providers that don't implement OpenAI routes.

- Add paginated model listing with timeout and optional proxy support
- Select an enabled key for multi-key Gemini channels
RedwindA 1 miesiąc temu
rodzic
commit
c2464fc877
2 zmienionych plików z 128 dodań i 3 usunięć
  1. 47 3
      controller/channel.go
  2. 81 0
      relay/channel/gemini/relay-gemini.go

+ 47 - 3
controller/channel.go

@@ -11,6 +11,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/relay/channel/gemini"
 	"github.com/QuantumNous/new-api/relay/channel/ollama"
 	"github.com/QuantumNous/new-api/service"
 
@@ -260,11 +261,37 @@ func FetchUpstreamModels(c *gin.Context) {
 		return
 	}
 
+	// 对于 Gemini 渠道,使用特殊处理
+	if channel.Type == constant.ChannelTypeGemini {
+		// 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥)
+		key, _, apiErr := channel.GetNextEnabledKey()
+		if apiErr != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
+			})
+			return
+		}
+		key = strings.TrimSpace(key)
+		models, err := gemini.FetchGeminiModels(baseURL, key, channel.GetSetting().Proxy)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
+			})
+			return
+		}
+
+		c.JSON(http.StatusOK, gin.H{
+			"success": true,
+			"message": "",
+			"data":    models,
+		})
+		return
+	}
+
 	var url string
 	switch channel.Type {
-	case constant.ChannelTypeGemini:
-		// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
-		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
 	case constant.ChannelTypeAli:
 		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	case constant.ChannelTypeZhipu_v4:
@@ -1072,6 +1099,23 @@ func FetchModels(c *gin.Context) {
 		return
 	}
 
+	if req.Type == constant.ChannelTypeGemini {
+		models, err := gemini.FetchGeminiModels(baseURL, key, "")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
+			})
+			return
+		}
+
+		c.JSON(http.StatusOK, gin.H{
+			"success": true,
+			"data":    models,
+		})
+		return
+	}
+
 	client := &http.Client{}
 	url := fmt.Sprintf("%s/v1/models", baseURL)
 

+ 81 - 0
relay/channel/gemini/relay-gemini.go

@@ -1,6 +1,7 @@
 package gemini
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -8,6 +9,7 @@ import (
 	"net/http"
 	"strconv"
 	"strings"
+	"time"
 	"unicode/utf8"
 
 	"github.com/QuantumNous/new-api/common"
@@ -1363,3 +1365,82 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 
 	return usage, nil
 }
+
+type GeminiModelInfo struct {
+	Name                       string   `json:"name"`
+	Version                    string   `json:"version"`
+	DisplayName                string   `json:"displayName"`
+	Description                string   `json:"description"`
+	InputTokenLimit            int      `json:"inputTokenLimit"`
+	OutputTokenLimit           int      `json:"outputTokenLimit"`
+	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
+}
+
+type GeminiModelsResponse struct {
+	Models        []GeminiModelInfo `json:"models"`
+	NextPageToken string            `json:"nextPageToken"`
+}
+
+func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
+	client, err := service.GetHttpClientWithProxy(proxyURL)
+	if err != nil {
+		return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
+	}
+
+	allModels := make([]string, 0)
+	nextPageToken := ""
+	maxPages := 100 // Safety limit to prevent infinite loops
+
+	for page := 0; page < maxPages; page++ {
+		url := fmt.Sprintf("%s/v1beta/models", baseURL)
+		if nextPageToken != "" {
+			url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
+		}
+
+		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+		if err != nil {
+			cancel()
+			return nil, fmt.Errorf("创建请求失败: %v", err)
+		}
+
+		request.Header.Set("x-goog-api-key", apiKey)
+
+		response, err := client.Do(request)
+		if err != nil {
+			cancel()
+			return nil, fmt.Errorf("请求失败: %v", err)
+		}
+
+		if response.StatusCode != http.StatusOK {
+			body, _ := io.ReadAll(response.Body)
+			response.Body.Close()
+			cancel()
+			return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
+		}
+
+		body, err := io.ReadAll(response.Body)
+		response.Body.Close()
+		cancel()
+		if err != nil {
+			return nil, fmt.Errorf("读取响应失败: %v", err)
+		}
+
+		var modelsResponse GeminiModelsResponse
+		if err = common.Unmarshal(body, &modelsResponse); err != nil {
+			return nil, fmt.Errorf("解析响应失败: %v", err)
+		}
+
+		for _, model := range modelsResponse.Models {
+			modelName := strings.TrimPrefix(model.Name, "models/")
+			allModels = append(allModels, modelName)
+		}
+
+		nextPageToken = modelsResponse.NextPageToken
+		if nextPageToken == "" {
+			break
+		}
+	}
+
+	return allModels, nil
+}