Jelajahi Sumber

refactor(channel_select): enhance retry logic and context key usage for channel selection

CaIon 2 bulan lalu
induk
melakukan
c51936e068

+ 4 - 1
constant/context_key.go

@@ -21,7 +21,6 @@ const (
 	ContextKeyTokenCrossGroupRetry   ContextKey = "token_cross_group_retry"
 
 	/* channel related keys */
-	ContextKeyAutoGroupIndex           ContextKey = "auto_group_index"
 	ContextKeyChannelId                ContextKey = "channel_id"
 	ContextKeyChannelName              ContextKey = "channel_name"
 	ContextKeyChannelCreateTime        ContextKey = "channel_create_time"
@@ -39,6 +38,10 @@ const (
 	ContextKeyChannelMultiKeyIndex     ContextKey = "channel_multi_key_index"
 	ContextKeyChannelKey               ContextKey = "channel_key"
 
+	ContextKeyAutoGroup           ContextKey = "auto_group"
+	ContextKeyAutoGroupIndex      ContextKey = "auto_group_index"
+	ContextKeyAutoGroupRetryIndex ContextKey = "auto_group_retry_index"
+
 	/* user related keys */
 	ContextKeyUserId      ContextKey = "id"
 	ContextKeyUserSetting ContextKey = "user_setting"

+ 0 - 9
controller/playground.go

@@ -3,10 +3,7 @@ package controller
 import (
 	"errors"
 	"fmt"
-	"time"
 
-	"github.com/QuantumNous/new-api/common"
-	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/middleware"
 	"github.com/QuantumNous/new-api/model"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -54,12 +51,6 @@ func Playground(c *gin.Context) {
 		Group:  relayInfo.UsingGroup,
 	}
 	_ = middleware.SetupContextForToken(c, tempToken)
-	_, newAPIError = getChannel(c, relayInfo, 0)
-	if newAPIError != nil {
-		return
-	}
-	//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
-	common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
 
 	Relay(c, types.RelayFormatOpenAI)
 }

+ 24 - 11
controller/relay.go

@@ -157,8 +157,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		}
 	}()
 
-	for i := 0; i <= common.RetryTimes; i++ {
-		channel, err := getChannel(c, relayInfo, i)
+	retryParam := &service.RetryParam{
+		Ctx:        c,
+		TokenGroup: relayInfo.TokenGroup,
+		ModelName:  relayInfo.OriginModelName,
+		Retry:      common.GetPointer(0),
+	}
+
+	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
+		channel, err := getChannel(c, relayInfo, retryParam)
 		if err != nil {
 			logger.LogError(c, err.Error())
 			newAPIError = err
@@ -186,7 +193,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 
 		processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 
-		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
+		if !shouldRetry(c, newAPIError, common.RetryTimes-retryParam.GetRetry()) {
 			break
 		}
 	}
@@ -211,8 +218,8 @@ func addUsedChannel(c *gin.Context, channelId int) {
 	c.Set("use_channel", useChannel)
 }
 
-func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
-	if retryCount == 0 {
+func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) {
+	if info.ChannelMeta == nil {
 		autoBan := c.GetBool("auto_ban")
 		autoBanInt := 1
 		if !autoBan {
@@ -225,7 +232,7 @@ func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*m
 			AutoBan: &autoBanInt,
 		}, nil
 	}
-	channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
+	channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(retryParam)
 
 	info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
 
@@ -370,7 +377,7 @@ func RelayMidjourney(c *gin.Context) {
 }
 
 func RelayNotImplemented(c *gin.Context) {
-	err := dto.OpenAIError{
+	err := types.OpenAIError{
 		Message: "API not implemented",
 		Type:    "new_api_error",
 		Param:   "",
@@ -382,7 +389,7 @@ func RelayNotImplemented(c *gin.Context) {
 }
 
 func RelayNotFound(c *gin.Context) {
-	err := dto.OpenAIError{
+	err := types.OpenAIError{
 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
 		Type:    "invalid_request_error",
 		Param:   "",
@@ -405,8 +412,14 @@ func RelayTask(c *gin.Context) {
 	if taskErr == nil {
 		retryTimes = 0
 	}
-	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
-		channel, newAPIError := getChannel(c, relayInfo, i)
+	retryParam := &service.RetryParam{
+		Ctx:        c,
+		TokenGroup: relayInfo.TokenGroup,
+		ModelName:  relayInfo.OriginModelName,
+		Retry:      common.GetPointer(0),
+	}
+	for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
+		channel, newAPIError := getChannel(c, relayInfo, retryParam)
 		if newAPIError != nil {
 			logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
 			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
@@ -416,7 +429,7 @@ func RelayTask(c *gin.Context) {
 		useChannel := c.GetStringSlice("use_channel")
 		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 		c.Set("use_channel", useChannel)
-		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
 		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
 		requestBody, _ := common.GetRequestBody(c)

+ 1 - 1
middleware/auth.go

@@ -308,7 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
 		c.Set("token_model_limit_enabled", false)
 	}
 	common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group)
-	c.Set("token_cross_group_retry", token.CrossGroupRetry)
+	common.SetContextKey(c, constant.ContextKeyTokenCrossGroupRetry, token.CrossGroupRetry)
 	if len(parts) > 1 {
 		if model.IsAdmin(token.UserId) {
 			c.Set("specific_channel_id", parts[1])

+ 6 - 1
middleware/distributor.go

@@ -97,7 +97,12 @@ func Distribute() func(c *gin.Context) {
 						common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup)
 					}
 				}
-				channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0)
+				channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
+					Ctx:        c,
+					ModelName:  modelRequest.Model,
+					TokenGroup: usingGroup,
+					Retry:      common.GetPointer(0),
+				})
 				if err != nil {
 					showGroup := usingGroup
 					if usingGroup == "auto" {

+ 117 - 26
service/channel_select.go

@@ -11,50 +11,141 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+type RetryParam struct {
+	Ctx        *gin.Context
+	TokenGroup string
+	ModelName  string
+	Retry      *int
+}
+
+func (p *RetryParam) GetRetry() int {
+	if p.Retry == nil {
+		return 0
+	}
+	return *p.Retry
+}
+
+func (p *RetryParam) SetRetry(retry int) {
+	p.Retry = &retry
+}
+
+func (p *RetryParam) IncreaseRetry() {
+	if p.Retry == nil {
+		p.Retry = new(int)
+	}
+	*p.Retry++
+}
+
 // CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
-func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) {
+// 尝试获取一个满足要求的随机渠道。
+//
+// For "auto" tokenGroup with cross-group Retry enabled:
+// 对于启用了跨分组重试的 "auto" tokenGroup:
+//
+//   - Each group will exhaust all its priorities before moving to the next group.
+//     每个分组会用完所有优先级后才会切换到下一个分组。
+//
+//   - Uses ContextKeyAutoGroupIndex to track current group index.
+//     使用 ContextKeyAutoGroupIndex 跟踪当前分组索引。
+//
+//   - Resets ContextKeyAutoGroupRetryIndex to 0 when moving to the next group (note: in this
+//     function the key is currently write-only — verify whether any reader still depends on it).
+//     切换到下一个分组时将 ContextKeyAutoGroupRetryIndex 重置为 0(注意:本函数中该键目前只写不读——请确认是否仍有读取方依赖它)。
+//
+//   - The shared Retry counter is reset when switching groups, so priorityRetry equals the current
+//     Retry value within the group being searched (0 for any group after the starting one).
+//     切换分组时会重置共享的 Retry 计数器,因此 priorityRetry 在当前搜索的分组内等于当前的 Retry 值(起始分组之后的分组为 0)。
+//
+//   - When GetRandomSatisfiedChannel returns nil (priorities exhausted), moves to next group.
+//     当 GetRandomSatisfiedChannel 返回 nil(优先级用完)时,切换到下一个分组。
+//
+// Example flow (2 groups, each with 2 priorities, RetryTimes=3):
+// 示例流程(2个分组,每个有2个优先级,RetryTimes=3):
+//
+//	Retry=0: GroupA, priority0 (startRetryIndex=0, priorityRetry=0)
+//	         分组A, 优先级0
+//
+//	Retry=1: GroupA, priority1 (startRetryIndex=0, priorityRetry=1)
+//	         分组A, 优先级1
+//
+//	Retry=2: GroupA exhausted → GroupB, priority0 (startRetryIndex=2, priorityRetry=0)
+//	         分组A用完 → 分组B, 优先级0
+//
+//	Retry=3: GroupB, priority1 (startRetryIndex=2, priorityRetry=1)
+//	         分组B, 优先级1
+func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string, error) {
 	var channel *model.Channel
 	var err error
-	selectGroup := tokenGroup
-	userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
-	if tokenGroup == "auto" {
+	selectGroup := param.TokenGroup
+	userGroup := common.GetContextKeyString(param.Ctx, constant.ContextKeyUserGroup)
+
+	if param.TokenGroup == "auto" {
 		if len(setting.GetAutoGroups()) == 0 {
 			return nil, selectGroup, errors.New("auto groups is not enabled")
 		}
 		autoGroups := GetUserAutoGroup(userGroup)
-		startIndex := 0
-		priorityRetry := retry
-		crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry)
-		if crossGroupRetry && retry > 0 {
-			logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry)
-			if lastIndex, exists := common.GetContextKey(c, constant.ContextKeyAutoGroupIndex); exists {
-				if idx, ok := lastIndex.(int); ok {
-					startIndex = idx + 1
-					priorityRetry = 0
-				}
+
+		// startGroupIndex: the group index to start searching from
+		// startGroupIndex: 开始搜索的分组索引
+		startGroupIndex := 0
+		crossGroupRetry := common.GetContextKeyBool(param.Ctx, constant.ContextKeyTokenCrossGroupRetry)
+
+		if lastGroupIndex, exists := common.GetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex); exists {
+			if idx, ok := lastGroupIndex.(int); ok {
+				startGroupIndex = idx
 			}
-			logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex)
 		}
 
-		for i := startIndex; i < len(autoGroups); i++ {
+		for i := startGroupIndex; i < len(autoGroups); i++ {
 			autoGroup := autoGroups[i]
-			logger.LogDebug(c, "Auto selecting group: %s", autoGroup)
-			channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, priorityRetry)
-			if channel == nil {
+			// Calculate priorityRetry for current group
+			// 计算当前分组的 priorityRetry
+			priorityRetry := param.GetRetry()
+			// If we have moved past the starting group, reset priorityRetry so the new
+			// group is searched from its highest-priority channels.
+			// 如果已经越过起始分组,重置 priorityRetry,使新分组从最高优先级的渠道开始搜索。
+			if i > startGroupIndex {
 				priorityRetry = 0
+			}
+			logger.LogDebug(param.Ctx, "Auto selecting group: %s, priorityRetry: %d", autoGroup, priorityRetry)
+
+			channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry)
+			if channel == nil {
+				// Current group has no available channel for this model, try next group
+				// 当前分组没有该模型的可用渠道,尝试下一个分组
+				logger.LogDebug(param.Ctx, "No available channel in group %s for model %s at priorityRetry %d, trying next group", autoGroup, param.ModelName, priorityRetry)
+				// Reset state to try the next group
+				// 重置状态以尝试下一个分组
+				common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1)
+				common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupRetryIndex, 0)
+				// Reset retry counter so outer loop can continue for next group
+				// 重置重试计数器,以便外层循环可以为下一个分组继续
+				param.SetRetry(0)
 				continue
+			}
+			common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroup, autoGroup)
+			selectGroup = autoGroup
+			logger.LogDebug(param.Ctx, "Auto selected group: %s", autoGroup)
+
+			// Prepare state for next retry
+			// 为下一次重试准备状态
+			if crossGroupRetry && priorityRetry >= common.RetryTimes {
+				// Current group has exhausted all retries, prepare to switch to next group
+				// This request still uses current group, but next retry will use next group
+				// 当前分组已用完所有重试次数,准备切换到下一个分组
+				// 本次请求仍使用当前分组,但下次重试将使用下一个分组
+				logger.LogDebug(param.Ctx, "Current group %s retries exhausted (priorityRetry=%d >= RetryTimes=%d), preparing switch to next group for next retry", autoGroup, priorityRetry, common.RetryTimes)
+				common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1)
+				// Reset retry counter so outer loop can continue for next group
+				// 重置重试计数器,以便外层循环可以为下一个分组继续
+				param.SetRetry(-1)
 			} else {
-				c.Set("auto_group", autoGroup)
-				common.SetContextKey(c, constant.ContextKeyAutoGroupIndex, i)
-				selectGroup = autoGroup
-				logger.LogDebug(c, "Auto selected group: %s", autoGroup)
-				break
+				// Stay in current group, save current state
+				// 保持在当前分组,保存当前状态
+				common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i)
 			}
+			break
 		}
 	} else {
-		channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry)
+		channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry())
 		if err != nil {
-			return nil, tokenGroup, err
+			return nil, param.TokenGroup, err
 		}
 	}
 	return channel, selectGroup, nil

+ 1 - 1
service/quota.go

@@ -108,7 +108,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 	groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
 	modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)
 
-	autoGroup, exists := ctx.Get("auto_group")
+	autoGroup, exists := common.GetContextKey(ctx, constant.ContextKeyAutoGroup)
 	if exists {
 		groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
 		log.Printf("final group ratio: %f", groupRatio)