|
|
@@ -451,8 +451,6 @@ func RelayNotFound(c *gin.Context) {
|
|
|
}
|
|
|
|
|
|
func RelayTask(c *gin.Context) {
|
|
|
- channelId := c.GetInt("channel_id")
|
|
|
- c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
|
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
|
|
if err != nil {
|
|
|
c.JSON(http.StatusInternalServerError, &dto.TaskError{
|
|
|
@@ -463,8 +461,7 @@ func RelayTask(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试
|
|
|
- // TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山
|
|
|
+ // Fetch 路径:纯 DB 查询,不依赖上下文 channel,无需重试
|
|
|
switch relayInfo.RelayMode {
|
|
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
|
|
if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil {
|
|
|
@@ -475,13 +472,11 @@ func RelayTask(c *gin.Context) {
|
|
|
|
|
|
// ── Submit 路径 ─────────────────────────────────────────────────
|
|
|
|
|
|
- // 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试
|
|
|
if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
|
|
|
respondTaskError(c, taskErr)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // 2. defer Refund(全部失败时回滚预扣费)
|
|
|
var result *relay.TaskSubmitResult
|
|
|
var taskErr *dto.TaskError
|
|
|
defer func() {
|
|
|
@@ -490,14 +485,57 @@ func RelayTask(c *gin.Context) {
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- // 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费)
|
|
|
- taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError {
|
|
|
- var te *dto.TaskError
|
|
|
- result, te = relay.RelayTaskSubmit(c, relayInfo)
|
|
|
- return te
|
|
|
- })
|
|
|
+ retryParam := &service.RetryParam{
|
|
|
+ Ctx: c,
|
|
|
+ TokenGroup: relayInfo.TokenGroup,
|
|
|
+ ModelName: relayInfo.OriginModelName,
|
|
|
+ Retry: common.GetPointer(0),
|
|
|
+ }
|
|
|
+
|
|
|
+ for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
|
|
+ channel, channelErr := getChannel(c, relayInfo, retryParam)
|
|
|
+ if channelErr != nil {
|
|
|
+ logger.LogError(c, channelErr.Error())
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ addUsedChannel(c, channel.Id)
|
|
|
+ requestBody, bodyErr := common.GetRequestBody(c)
|
|
|
+ if bodyErr != nil {
|
|
|
+ if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
|
|
|
+ } else {
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest)
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
|
+
|
|
|
+ result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
|
|
|
+ if taskErr == nil {
|
|
|
+ break
|
|
|
+ }
|
|
|
|
|
|
- // 4. 成功:结算 + 日志 + 插入任务
|
|
|
+ if !taskErr.LocalError {
|
|
|
+ processChannelError(c,
|
|
|
+ *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
|
|
|
+ common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
|
|
|
+ types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
|
|
|
+ }
|
|
|
+
|
|
|
+ if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ useChannel := c.GetStringSlice("use_channel")
|
|
|
+ if len(useChannel) > 1 {
|
|
|
+ retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
|
+ logger.LogInfo(c, retryLogStr)
|
|
|
+ }
|
|
|
+
|
|
|
+ // ── 成功:结算 + 日志 + 插入任务 ──
|
|
|
if taskErr == nil {
|
|
|
if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
|
|
|
common.SysError("settle task billing error: " + settleErr.Error())
|
|
|
@@ -520,7 +558,6 @@ func RelayTask(c *gin.Context) {
|
|
|
task.Data = result.TaskData
|
|
|
task.Action = relayInfo.Action
|
|
|
if insertErr := task.Insert(); insertErr != nil {
|
|
|
- //taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError)
|
|
|
common.SysError("insert task error: " + insertErr.Error())
|
|
|
}
|
|
|
}
|
|
|
@@ -538,69 +575,6 @@ func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
|
|
|
c.JSON(taskErr.StatusCode, taskErr)
|
|
|
}
|
|
|
|
|
|
-// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。
|
|
|
-// attempt 闭包负责实际的上游请求,不涉及计费。
|
|
|
-func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|
|
- channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError {
|
|
|
-
|
|
|
- taskErr := attempt()
|
|
|
- if taskErr == nil {
|
|
|
- return nil
|
|
|
- }
|
|
|
- if !taskErr.LocalError {
|
|
|
- processChannelError(c,
|
|
|
- *types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
|
|
- common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)),
|
|
|
- types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
|
|
|
- }
|
|
|
-
|
|
|
- retryParam := &service.RetryParam{
|
|
|
- Ctx: c,
|
|
|
- TokenGroup: relayInfo.TokenGroup,
|
|
|
- ModelName: relayInfo.OriginModelName,
|
|
|
- Retry: common.GetPointer(0),
|
|
|
- }
|
|
|
- for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
|
|
|
- channel, newAPIError := getChannel(c, relayInfo, retryParam)
|
|
|
- if newAPIError != nil {
|
|
|
- logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
|
|
- taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
|
|
- break
|
|
|
- }
|
|
|
- channelId = channel.Id
|
|
|
- useChannel := c.GetStringSlice("use_channel")
|
|
|
- useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
|
|
- c.Set("use_channel", useChannel)
|
|
|
- logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
|
|
|
- middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model"))
|
|
|
-
|
|
|
- bodyStorage, err := common.GetBodyStorage(c)
|
|
|
- if err != nil {
|
|
|
- if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
|
|
|
- taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
|
|
|
- } else {
|
|
|
- taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
|
|
|
- }
|
|
|
- break
|
|
|
- }
|
|
|
- c.Request.Body = io.NopCloser(bodyStorage)
|
|
|
- taskErr = attempt()
|
|
|
- if taskErr != nil && !taskErr.LocalError {
|
|
|
- processChannelError(c,
|
|
|
- *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
|
|
|
- common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
|
|
|
- types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- useChannel := c.GetStringSlice("use_channel")
|
|
|
- if len(useChannel) > 1 {
|
|
|
- retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
|
- logger.LogInfo(c, retryLogStr)
|
|
|
- }
|
|
|
- return taskErr
|
|
|
-}
|
|
|
-
|
|
|
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
|
|
if taskErr == nil {
|
|
|
return false
|