gemini_file_channel.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package middleware
  2. import (
  3. "fmt"
  4. "net/http"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/constant"
  7. "github.com/QuantumNous/new-api/model"
  8. "github.com/QuantumNous/new-api/service"
  9. "github.com/gin-gonic/gin"
  10. )
  11. // SetupGeminiFileChannel selects a Gemini channel for File API operations
  12. // This middleware is used instead of Distribute() for File API endpoints
  13. // since they don't require model-based channel selection
  14. func SetupGeminiFileChannel() func(c *gin.Context) {
  15. return func(c *gin.Context) {
  16. // Get user's group
  17. usingGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
  18. if usingGroup == "" {
  19. usingGroup = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
  20. }
  21. // Try multiple common Gemini models to find an available channel
  22. // The actual File API doesn't require a model, but we need one to select a channel
  23. geminiModels := []string{
  24. "gemini-2.0-flash",
  25. "gemini-1.5-flash",
  26. "gemini-1.5-pro",
  27. "gemini-2.0-flash-exp",
  28. "gemini-pro",
  29. "gemini-1.0-pro",
  30. }
  31. var channel *model.Channel
  32. var err error
  33. var lastError error
  34. // Try each model until we find an available channel
  35. for _, modelName := range geminiModels {
  36. channel, _, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
  37. Ctx: c,
  38. ModelName: modelName,
  39. TokenGroup: usingGroup,
  40. Retry: common.GetPointer(0),
  41. })
  42. if err == nil && channel != nil {
  43. // Found a channel, setup context and continue
  44. newAPIError := SetupContextForSelectedChannel(c, channel, modelName)
  45. if newAPIError != nil {
  46. abortWithOpenAiMessage(c, http.StatusServiceUnavailable,
  47. fmt.Sprintf("设置 Gemini 渠道失败: %s", newAPIError.Error()))
  48. return
  49. }
  50. c.Next()
  51. return
  52. }
  53. lastError = err
  54. }
  55. // No channel found with any of the models
  56. errorMsg := "没有可用的 Gemini 文件 API 渠道"
  57. if lastError != nil {
  58. errorMsg = fmt.Sprintf("获取 Gemini 文件 API 渠道失败: %s", lastError.Error())
  59. }
  60. abortWithOpenAiMessage(c, http.StatusServiceUnavailable, errorMsg)
  61. }
  62. }