Browse Source

feat: add doubao video use quota by total token

feitianbubu 5 months ago
parent
commit
b244a06ca1
3 changed files with 95 additions and 6 deletions
  1. 84 0
      controller/task_video.go
  2. 3 0
      relay/channel/task/doubao/adaptor.go
  3. 8 6
      relay/common/relay_info.go

+ 84 - 0
controller/task_video.go

@@ -13,6 +13,7 @@ import (
 	"one-api/relay"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/setting/ratio_setting"
 	"time"
 )
 
@@ -120,6 +121,89 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 		if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
 			task.FailReason = taskResult.Url
 		}
+
+		// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
+		if taskResult.TotalTokens > 0 {
+			// 获取模型名称
+			var taskData map[string]interface{}
+			if err := json.Unmarshal(task.Data, &taskData); err == nil {
+				if modelName, ok := taskData["model"].(string); ok && modelName != "" {
+					// 获取模型价格和倍率
+					modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
+
+					// 只有配置了倍率(非固定价格)时才按 token 重新计费
+					if hasRatioSetting && modelRatio > 0 {
+						// 获取用户和组的倍率信息
+						user, err := model.GetUserById(task.UserId, false)
+						if err == nil {
+							groupRatio := ratio_setting.GetGroupRatio(user.Group)
+							userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(user.Group, user.Group)
+
+							var finalGroupRatio float64
+							if hasUserGroupRatio {
+								finalGroupRatio = userGroupRatio
+							} else {
+								finalGroupRatio = groupRatio
+							}
+
+							// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
+							actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
+
+							// 计算差额
+							preConsumedQuota := task.Quota
+							quotaDelta := actualQuota - preConsumedQuota
+
+							if quotaDelta > 0 {
+								// 需要补扣费
+								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
+									task.TaskID,
+									logger.LogQuota(quotaDelta),
+									logger.LogQuota(actualQuota),
+									logger.LogQuota(preConsumedQuota),
+									taskResult.TotalTokens,
+								))
+								if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
+									logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
+								} else {
+									model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
+									model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
+									task.Quota = actualQuota // 更新任务记录的实际扣费额度
+
+									// 记录消费日志
+									logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d",
+										modelRatio, finalGroupRatio, taskResult.TotalTokens)
+									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+								}
+							} else if quotaDelta < 0 {
+								// 需要退还多扣的费用
+								refundQuota := -quotaDelta
+								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
+									task.TaskID,
+									logger.LogQuota(refundQuota),
+									logger.LogQuota(actualQuota),
+									logger.LogQuota(preConsumedQuota),
+									taskResult.TotalTokens,
+								))
+								if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
+									logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
+								} else {
+									task.Quota = actualQuota // 更新任务记录的实际扣费额度
+
+									// 记录退款日志
+									logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,退还 %s",
+										modelRatio, finalGroupRatio, taskResult.TotalTokens, logger.LogQuota(refundQuota))
+									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+								}
+							} else {
+								// quotaDelta == 0, 预扣费刚好准确
+								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
+									task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
+							}
+						}
+					}
+				}
+			}
+		}
 	case model.TaskStatusFailure:
 		task.Status = model.TaskStatusFailure
 		task.Progress = "100%"

+ 3 - 0
relay/channel/task/doubao/adaptor.go

@@ -231,6 +231,9 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 		taskResult.Status = model.TaskStatusSuccess
 		taskResult.Progress = "100%"
 		taskResult.Url = resTask.Content.VideoURL
+		// 解析 usage 信息用于按倍率计费
+		taskResult.CompletionTokens = resTask.Usage.CompletionTokens
+		taskResult.TotalTokens = resTask.Usage.TotalTokens
 	case "failed":
 		taskResult.Status = model.TaskStatusFailure
 		taskResult.Progress = "100%"

+ 8 - 6
relay/common/relay_info.go

@@ -500,10 +500,12 @@ func (t TaskSubmitReq) HasImage() bool {
 }
 
 type TaskInfo struct {
-	Code     int    `json:"code"`
-	TaskID   string `json:"task_id"`
-	Status   string `json:"status"`
-	Reason   string `json:"reason,omitempty"`
-	Url      string `json:"url,omitempty"`
-	Progress string `json:"progress,omitempty"`
+	Code             int    `json:"code"`
+	TaskID           string `json:"task_id"`
+	Status           string `json:"status"`
+	Reason           string `json:"reason,omitempty"`
+	Url              string `json:"url,omitempty"`
+	Progress         string `json:"progress,omitempty"`
+	CompletionTokens int    `json:"completion_tokens,omitempty"` // 用于按倍率计费
+	TotalTokens      int    `json:"total_tokens,omitempty"`      // 用于按倍率计费
 }