task_billing.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package service
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
	"github.com/gin-gonic/gin"
)
  14. // LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
  15. // 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
  16. func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
  17. tokenName := c.GetString("token_name")
  18. logContent := fmt.Sprintf("操作 %s", info.Action)
  19. // 支持任务仅按次计费
  20. if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
  21. logContent = fmt.Sprintf("%s,按次计费", logContent)
  22. } else {
  23. if len(info.PriceData.OtherRatios) > 0 {
  24. var contents []string
  25. for key, ra := range info.PriceData.OtherRatios {
  26. if 1.0 != ra {
  27. contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
  28. }
  29. }
  30. if len(contents) > 0 {
  31. logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
  32. }
  33. }
  34. }
  35. other := make(map[string]interface{})
  36. other["request_path"] = c.Request.URL.Path
  37. other["model_price"] = info.PriceData.ModelPrice
  38. other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
  39. if info.PriceData.GroupRatioInfo.HasSpecialRatio {
  40. other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
  41. }
  42. if info.IsModelMapped {
  43. other["is_model_mapped"] = true
  44. other["upstream_model_name"] = info.UpstreamModelName
  45. }
  46. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  47. ChannelId: info.ChannelId,
  48. ModelName: info.OriginModelName,
  49. TokenName: tokenName,
  50. Quota: info.PriceData.Quota,
  51. Content: logContent,
  52. TokenId: info.TokenId,
  53. Group: info.UsingGroup,
  54. Other: other,
  55. })
  56. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
  57. model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
  58. }
  59. // ---------------------------------------------------------------------------
  60. // 异步任务计费辅助函数
  61. // ---------------------------------------------------------------------------
  62. // resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
  63. // 如果令牌已被删除或查询失败,返回空字符串。
  64. func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
  65. token, err := model.GetTokenById(tokenId)
  66. if err != nil {
  67. logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
  68. return ""
  69. }
  70. return token.Key
  71. }
  72. // taskIsSubscription 判断任务是否通过订阅计费。
  73. func taskIsSubscription(task *model.Task) bool {
  74. return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
  75. }
  76. // taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
  77. func taskAdjustFunding(task *model.Task, delta int) error {
  78. if taskIsSubscription(task) {
  79. return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
  80. }
  81. if delta > 0 {
  82. return model.DecreaseUserQuota(task.UserId, delta)
  83. }
  84. return model.IncreaseUserQuota(task.UserId, -delta, false)
  85. }
  86. // taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
  87. // 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
  88. func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
  89. if task.PrivateData.TokenId <= 0 || delta == 0 {
  90. return
  91. }
  92. tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
  93. if tokenKey == "" {
  94. return
  95. }
  96. var err error
  97. if delta > 0 {
  98. err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
  99. } else {
  100. err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
  101. }
  102. if err != nil {
  103. logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
  104. }
  105. }
  106. // taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。
  107. func taskBillingOther(task *model.Task) map[string]interface{} {
  108. other := make(map[string]interface{})
  109. if bc := task.PrivateData.BillingContext; bc != nil {
  110. other["model_price"] = bc.ModelPrice
  111. other["group_ratio"] = bc.GroupRatio
  112. if len(bc.OtherRatios) > 0 {
  113. for k, v := range bc.OtherRatios {
  114. other[k] = v
  115. }
  116. }
  117. }
  118. props := task.Properties
  119. if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
  120. other["is_model_mapped"] = true
  121. other["upstream_model_name"] = props.UpstreamModelName
  122. }
  123. return other
  124. }
  125. // taskModelName 从 BillingContext 或 Properties 中获取模型名称。
  126. func taskModelName(task *model.Task) string {
  127. if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
  128. return bc.OriginModelName
  129. }
  130. return task.Properties.OriginModelName
  131. }
  132. // RefundTaskQuota 统一的任务失败退款逻辑。
  133. // 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
  134. func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
  135. quota := task.Quota
  136. if quota == 0 {
  137. return
  138. }
  139. // 1. 退还资金来源(钱包或订阅)
  140. if err := taskAdjustFunding(task, -quota); err != nil {
  141. logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
  142. return
  143. }
  144. // 2. 退还令牌额度
  145. taskAdjustTokenQuota(ctx, task, -quota)
  146. // 3. 记录日志
  147. other := taskBillingOther(task)
  148. other["task_id"] = task.TaskID
  149. other["reason"] = reason
  150. model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
  151. UserId: task.UserId,
  152. LogType: model.LogTypeRefund,
  153. Content: "",
  154. ChannelId: task.ChannelId,
  155. ModelName: taskModelName(task),
  156. Quota: quota,
  157. TokenId: task.PrivateData.TokenId,
  158. Group: task.Group,
  159. Other: other,
  160. })
  161. }
  162. // RecalculateTaskQuota 通用的异步差额结算。
  163. // actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
  164. // reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
  165. func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
  166. if actualQuota <= 0 {
  167. return
  168. }
  169. preConsumedQuota := task.Quota
  170. quotaDelta := actualQuota - preConsumedQuota
  171. if quotaDelta == 0 {
  172. logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
  173. task.TaskID, logger.LogQuota(actualQuota), reason))
  174. return
  175. }
  176. logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
  177. task.TaskID,
  178. logger.LogQuota(quotaDelta),
  179. logger.LogQuota(actualQuota),
  180. logger.LogQuota(preConsumedQuota),
  181. reason,
  182. ))
  183. // 调整资金来源
  184. if err := taskAdjustFunding(task, quotaDelta); err != nil {
  185. logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
  186. return
  187. }
  188. // 调整令牌额度
  189. taskAdjustTokenQuota(ctx, task, quotaDelta)
  190. task.Quota = actualQuota
  191. var logType int
  192. var logQuota int
  193. if quotaDelta > 0 {
  194. logType = model.LogTypeConsume
  195. logQuota = quotaDelta
  196. model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
  197. model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
  198. } else {
  199. logType = model.LogTypeRefund
  200. logQuota = -quotaDelta
  201. }
  202. other := taskBillingOther(task)
  203. other["task_id"] = task.TaskID
  204. other["reason"] = reason
  205. other["pre_consumed_quota"] = preConsumedQuota
  206. other["actual_quota"] = actualQuota
  207. model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
  208. UserId: task.UserId,
  209. LogType: logType,
  210. Content: "",
  211. ChannelId: task.ChannelId,
  212. ModelName: taskModelName(task),
  213. Quota: logQuota,
  214. TokenId: task.PrivateData.TokenId,
  215. Group: task.Group,
  216. Other: other,
  217. })
  218. }
  219. // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
  220. // 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
  221. // 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
  222. func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
  223. if totalTokens <= 0 {
  224. return
  225. }
  226. modelName := taskModelName(task)
  227. // 获取模型价格和倍率
  228. modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
  229. // 只有配置了倍率(非固定价格)时才按 token 重新计费
  230. if !hasRatioSetting || modelRatio <= 0 {
  231. return
  232. }
  233. // 获取用户和组的倍率信息
  234. group := task.Group
  235. if group == "" {
  236. user, err := model.GetUserById(task.UserId, false)
  237. if err == nil {
  238. group = user.Group
  239. }
  240. }
  241. if group == "" {
  242. return
  243. }
  244. groupRatio := ratio_setting.GetGroupRatio(group)
  245. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
  246. var finalGroupRatio float64
  247. if hasUserGroupRatio {
  248. finalGroupRatio = userGroupRatio
  249. } else {
  250. finalGroupRatio = groupRatio
  251. }
  252. // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
  253. actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
  254. reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
  255. RecalculateTaskQuota(ctx, task, actualQuota, reason)
  256. }