|
@@ -13,6 +13,7 @@ import (
|
|
|
"one-api/relay"
|
|
"one-api/relay"
|
|
|
"one-api/relay/channel"
|
|
"one-api/relay/channel"
|
|
|
relaycommon "one-api/relay/common"
|
|
relaycommon "one-api/relay/common"
|
|
|
|
|
+ "one-api/setting/ratio_setting"
|
|
|
"time"
|
|
"time"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -120,6 +121,89 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|
|
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
|
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
|
|
task.FailReason = taskResult.Url
|
|
task.FailReason = taskResult.Url
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
|
|
|
|
|
+ if taskResult.TotalTokens > 0 {
|
|
|
|
|
+ // 获取模型名称
|
|
|
|
|
+ var taskData map[string]interface{}
|
|
|
|
|
+ if err := json.Unmarshal(task.Data, &taskData); err == nil {
|
|
|
|
|
+ if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
|
|
|
|
+ // 获取模型价格和倍率
|
|
|
|
|
+ modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
|
|
|
|
+
|
|
|
|
|
+ // 只有配置了倍率(非固定价格)时才按 token 重新计费
|
|
|
|
|
+ if hasRatioSetting && modelRatio > 0 {
|
|
|
|
|
+ // 获取用户和组的倍率信息
|
|
|
|
|
+ user, err := model.GetUserById(task.UserId, false)
|
|
|
|
|
+ if err == nil {
|
|
|
|
|
+ groupRatio := ratio_setting.GetGroupRatio(user.Group)
|
|
|
|
|
+ userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(user.Group, user.Group)
|
|
|
|
|
+
|
|
|
|
|
+ var finalGroupRatio float64
|
|
|
|
|
+ if hasUserGroupRatio {
|
|
|
|
|
+ finalGroupRatio = userGroupRatio
|
|
|
|
|
+ } else {
|
|
|
|
|
+ finalGroupRatio = groupRatio
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
|
|
|
|
+ actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
|
|
|
|
|
+
|
|
|
|
|
+ // 计算差额
|
|
|
|
|
+ preConsumedQuota := task.Quota
|
|
|
|
|
+ quotaDelta := actualQuota - preConsumedQuota
|
|
|
|
|
+
|
|
|
|
|
+ if quotaDelta > 0 {
|
|
|
|
|
+ // 需要补扣费
|
|
|
|
|
+ logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
|
|
|
|
+ task.TaskID,
|
|
|
|
|
+ logger.LogQuota(quotaDelta),
|
|
|
|
|
+ logger.LogQuota(actualQuota),
|
|
|
|
|
+ logger.LogQuota(preConsumedQuota),
|
|
|
|
|
+ taskResult.TotalTokens,
|
|
|
|
|
+ ))
|
|
|
|
|
+ if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
|
|
|
|
|
+ logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
|
|
|
|
|
+ } else {
|
|
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
|
|
|
|
+ model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
|
|
|
|
+ task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
|
|
|
|
+
|
|
|
|
|
+ // 记录消费日志
|
|
|
|
|
+ logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d",
|
|
|
|
|
+ modelRatio, finalGroupRatio, taskResult.TotalTokens)
|
|
|
|
|
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
|
|
|
+ }
|
|
|
|
|
+ } else if quotaDelta < 0 {
|
|
|
|
|
+ // 需要退还多扣的费用
|
|
|
|
|
+ refundQuota := -quotaDelta
|
|
|
|
|
+ logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
|
|
|
|
+ task.TaskID,
|
|
|
|
|
+ logger.LogQuota(refundQuota),
|
|
|
|
|
+ logger.LogQuota(actualQuota),
|
|
|
|
|
+ logger.LogQuota(preConsumedQuota),
|
|
|
|
|
+ taskResult.TotalTokens,
|
|
|
|
|
+ ))
|
|
|
|
|
+ if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
|
|
|
|
|
+ logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
|
|
|
|
|
+ } else {
|
|
|
|
|
+ task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
|
|
|
|
+
|
|
|
|
|
+ // 记录退款日志
|
|
|
|
|
+ logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,退还 %s",
|
|
|
|
|
+ modelRatio, finalGroupRatio, taskResult.TotalTokens, logger.LogQuota(refundQuota))
|
|
|
|
|
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // quotaDelta == 0, 预扣费刚好准确
|
|
|
|
|
+ logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
|
|
|
|
+ task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
case model.TaskStatusFailure:
|
|
case model.TaskStatusFailure:
|
|
|
task.Status = model.TaskStatusFailure
|
|
task.Status = model.TaskStatusFailure
|
|
|
task.Progress = "100%"
|
|
task.Progress = "100%"
|