Procházet zdrojové kódy

refactor(relay): improve channel locking and retry logic in RelayTask

- Enhanced the RelayTask function to utilize a locked channel when available, allowing for better reuse during retries.
- Updated error handling to ensure proper context setup for the selected channel.
- Clarified comments in ResolveOriginTask regarding channel locking and retry behavior.
- Introduced a new field in TaskRelayInfo to store the locked channel object, improving type safety and reducing import cycles.
CaIon před 1 týdnem
rodič
revize
cda540180b
3 změnil soubory, kde provedl 36 přidání a 18 odebrání
  1. 18 5
      controller/relay.go
  2. 5 0
      relay/common/relay_info.go
  3. 13 13
      relay/relay_task.go

+ 18 - 5
controller/relay.go

@@ -497,11 +497,24 @@ func RelayTask(c *gin.Context) {
 	}
 
 	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
-		channel, channelErr := getChannel(c, relayInfo, retryParam)
-		if channelErr != nil {
-			logger.LogError(c, channelErr.Error())
-			taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
-			break
+		var channel *model.Channel
+
+		if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil {
+			channel = lockedCh
+			if retryParam.GetRetry() > 0 {
+				if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil {
+					taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError)
+					break
+				}
+			}
+		} else {
+			var channelErr *types.NewAPIError
+			channel, channelErr = getChannel(c, relayInfo, retryParam)
+			if channelErr != nil {
+				logger.LogError(c, channelErr.Error())
+				taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
+				break
+			}
 		}
 
 		addUsedChannel(c, channel.Id)

+ 5 - 0
relay/common/relay_info.go

@@ -619,6 +619,11 @@ type TaskRelayInfo struct {
 	PublicTaskID string
 
 	ConsumeQuota bool
+
+	// LockedChannel holds the full channel object when the request is bound to
+	// a specific channel (e.g., remix on origin task's channel). Stored as any
+	// to avoid an import cycle with model; callers type-assert to *model.Channel.
+	LockedChannel any
 }
 
 type TaskSubmitReq struct {

+ 13 - 13
relay/relay_task.go

@@ -32,8 +32,9 @@ type TaskSubmitResult struct {
 }
 
 // ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
-// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过
-// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。
+// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
+// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
+// 以及提取 OtherRatios(时长、分辨率)。
 // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
 func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
 	// 检测 remix action
@@ -77,15 +78,17 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr
 		}
 	}
 
-	// 锁定到原始任务的渠道(如果与当前选中的不同)
+	// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
+	ch, err := model.GetChannelById(originTask.ChannelId, true)
+	if err != nil {
+		return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+	}
+	if ch.Status != common.ChannelStatusEnabled {
+		return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
+	}
+	info.LockedChannel = ch
+
 	if originTask.ChannelId != info.ChannelId {
-		ch, err := model.GetChannelById(originTask.ChannelId, true)
-		if err != nil {
-			return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
-		}
-		if ch.Status != common.ChannelStatusEnabled {
-			return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
-		}
 		key, _, newAPIError := ch.GetNextEnabledKey()
 		if newAPIError != nil {
 			return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
@@ -101,9 +104,6 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr
 		info.ApiKey = key
 	}
 
-	// 渠道已锁定到原始任务 → 禁止重试切换到其他渠道
-	c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId))
-
 	// 提取 remix 参数(时长、分辨率 → OtherRatios)
 	if info.Action == constant.TaskActionRemix {
 		if originTask.PrivateData.BillingContext != nil {