task_billing.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. package service
  2. import (
  3. "context"
  4. "fmt"
  5. "strings"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/constant"
  8. "github.com/QuantumNous/new-api/logger"
  9. "github.com/QuantumNous/new-api/model"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/QuantumNous/new-api/setting/ratio_setting"
  12. "github.com/gin-gonic/gin"
  13. )
  14. // LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
  15. // 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
  16. func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
  17. tokenName := c.GetString("token_name")
  18. logContent := fmt.Sprintf("操作 %s", info.Action)
  19. // 支持任务仅按次计费
  20. if common.StringsContains(constant.TaskPricePatches, modelName) {
  21. logContent = fmt.Sprintf("%s,按次计费", logContent)
  22. } else {
  23. if len(info.PriceData.OtherRatios) > 0 {
  24. var contents []string
  25. for key, ra := range info.PriceData.OtherRatios {
  26. if 1.0 != ra {
  27. contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
  28. }
  29. }
  30. if len(contents) > 0 {
  31. logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
  32. }
  33. }
  34. }
  35. other := make(map[string]interface{})
  36. other["request_path"] = c.Request.URL.Path
  37. other["model_price"] = info.PriceData.ModelPrice
  38. other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
  39. if info.PriceData.GroupRatioInfo.HasSpecialRatio {
  40. other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
  41. }
  42. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  43. ChannelId: info.ChannelId,
  44. ModelName: modelName,
  45. TokenName: tokenName,
  46. Quota: info.PriceData.Quota,
  47. Content: logContent,
  48. TokenId: info.TokenId,
  49. Group: info.UsingGroup,
  50. Other: other,
  51. })
  52. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
  53. model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
  54. }
  55. // ---------------------------------------------------------------------------
  56. // 异步任务计费辅助函数
  57. // ---------------------------------------------------------------------------
  58. // resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
  59. // 如果令牌已被删除或查询失败,返回空字符串。
  60. func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
  61. token, err := model.GetTokenById(tokenId)
  62. if err != nil {
  63. logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
  64. return ""
  65. }
  66. return token.Key
  67. }
  68. // taskIsSubscription 判断任务是否通过订阅计费。
  69. func taskIsSubscription(task *model.Task) bool {
  70. return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
  71. }
  72. // taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
  73. func taskAdjustFunding(task *model.Task, delta int) error {
  74. if taskIsSubscription(task) {
  75. return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
  76. }
  77. if delta > 0 {
  78. return model.DecreaseUserQuota(task.UserId, delta)
  79. }
  80. return model.IncreaseUserQuota(task.UserId, -delta, false)
  81. }
  82. // taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
  83. // 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
  84. func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
  85. if task.PrivateData.TokenId <= 0 || delta == 0 {
  86. return
  87. }
  88. tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
  89. if tokenKey == "" {
  90. return
  91. }
  92. var err error
  93. if delta > 0 {
  94. err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
  95. } else {
  96. err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
  97. }
  98. if err != nil {
  99. logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
  100. }
  101. }
  102. // RefundTaskQuota 统一的任务失败退款逻辑。
  103. // 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
  104. func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
  105. quota := task.Quota
  106. if quota == 0 {
  107. return
  108. }
  109. // 1. 退还资金来源(钱包或订阅)
  110. if err := taskAdjustFunding(task, -quota); err != nil {
  111. logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
  112. return
  113. }
  114. // 2. 退还令牌额度
  115. taskAdjustTokenQuota(ctx, task, -quota)
  116. // 3. 记录日志
  117. logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason)
  118. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  119. }
  120. // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
  121. // 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
  122. // 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
  123. func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
  124. if totalTokens <= 0 {
  125. return
  126. }
  127. // 获取模型名称
  128. var taskData map[string]interface{}
  129. if err := common.Unmarshal(task.Data, &taskData); err != nil {
  130. return
  131. }
  132. modelName, ok := taskData["model"].(string)
  133. if !ok || modelName == "" {
  134. return
  135. }
  136. // 获取模型价格和倍率
  137. modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
  138. // 只有配置了倍率(非固定价格)时才按 token 重新计费
  139. if !hasRatioSetting || modelRatio <= 0 {
  140. return
  141. }
  142. // 获取用户和组的倍率信息
  143. group := task.Group
  144. if group == "" {
  145. user, err := model.GetUserById(task.UserId, false)
  146. if err == nil {
  147. group = user.Group
  148. }
  149. }
  150. if group == "" {
  151. return
  152. }
  153. groupRatio := ratio_setting.GetGroupRatio(group)
  154. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
  155. var finalGroupRatio float64
  156. if hasUserGroupRatio {
  157. finalGroupRatio = userGroupRatio
  158. } else {
  159. finalGroupRatio = groupRatio
  160. }
  161. // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
  162. actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
  163. // 计算差额(正数=需要补扣,负数=需要退还)
  164. preConsumedQuota := task.Quota
  165. quotaDelta := actualQuota - preConsumedQuota
  166. if quotaDelta == 0 {
  167. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
  168. task.TaskID, logger.LogQuota(actualQuota), totalTokens))
  169. return
  170. }
  171. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)",
  172. task.TaskID,
  173. logger.LogQuota(quotaDelta),
  174. logger.LogQuota(actualQuota),
  175. logger.LogQuota(preConsumedQuota),
  176. totalTokens,
  177. ))
  178. // 调整资金来源
  179. if err := taskAdjustFunding(task, quotaDelta); err != nil {
  180. logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
  181. return
  182. }
  183. // 调整令牌额度
  184. taskAdjustTokenQuota(ctx, task, quotaDelta)
  185. // 更新统计(仅补扣时更新,退还不影响已用统计)
  186. if quotaDelta > 0 {
  187. model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
  188. model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
  189. }
  190. task.Quota = actualQuota
  191. var action string
  192. if quotaDelta > 0 {
  193. action = "补扣费"
  194. } else {
  195. action = "退还"
  196. }
  197. logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s",
  198. action, modelRatio, finalGroupRatio, totalTokens,
  199. logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota))
  200. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  201. }