pre_consume_quota.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package service
  2. import (
  3. "fmt"
  4. "net/http"
  5. "time"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/logger"
  8. "github.com/QuantumNous/new-api/model"
  9. relaycommon "github.com/QuantumNous/new-api/relay/common"
  10. "github.com/QuantumNous/new-api/types"
  11. "github.com/bytedance/gopkg/util/gopool"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
  15. // Always refund subscription pre-consumed (can be non-zero even when FinalPreConsumedQuota is 0)
  16. needRefundSub := relayInfo.BillingSource == BillingSourceSubscription && relayInfo.SubscriptionId != 0 && relayInfo.SubscriptionPreConsumed > 0
  17. needRefundToken := relayInfo.FinalPreConsumedQuota != 0
  18. if !needRefundSub && !needRefundToken {
  19. return
  20. }
  21. logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, subscription=%d)",
  22. relayInfo.UserId,
  23. logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
  24. relayInfo.SubscriptionPreConsumed,
  25. ))
  26. gopool.Go(func() {
  27. relayInfoCopy := *relayInfo
  28. if relayInfoCopy.BillingSource == BillingSourceSubscription {
  29. if needRefundSub {
  30. if err := refundWithRetry(func() error {
  31. return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId)
  32. }); err != nil {
  33. common.SysLog("error refund subscription pre-consume: " + err.Error())
  34. }
  35. }
  36. // refund token quota only
  37. if needRefundToken && !relayInfoCopy.IsPlayground {
  38. _ = model.IncreaseTokenQuota(relayInfoCopy.TokenId, relayInfoCopy.TokenKey, relayInfoCopy.FinalPreConsumedQuota)
  39. }
  40. return
  41. }
  42. // wallet refund uses existing path (user quota + token quota)
  43. if needRefundToken {
  44. err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
  45. if err != nil {
  46. common.SysLog("error return pre-consumed quota: " + err.Error())
  47. }
  48. }
  49. })
  50. }
  // refundWithRetry calls fn up to three times and stops at the first call that
// returns nil. After a failed attempt that is not the last, it sleeps for a
// linearly growing delay (200ms after the first failure, 400ms after the
// second). A nil fn counts as an immediate success. If every attempt fails,
// the error from the final attempt is returned.
func refundWithRetry(fn func() error) error {
	if fn == nil {
		return nil
	}
	const attempts = 3
	const baseDelay = 200 * time.Millisecond

	var err error
	for attempt := 1; attempt <= attempts; attempt++ {
		err = fn()
		if err == nil {
			return nil
		}
		if attempt == attempts {
			break
		}
		time.Sleep(time.Duration(attempt) * baseDelay)
	}
	return err
}
  69. // PreConsumeQuota checks if the user has enough quota to pre-consume.
  70. // It returns the pre-consumed quota if successful, or an error if not.
  71. func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
  72. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  73. if err != nil {
  74. return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
  75. }
  76. if userQuota <= 0 {
  77. return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  78. }
  79. if userQuota-preConsumedQuota < 0 {
  80. return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  81. }
  82. trustQuota := common.GetTrustQuota()
  83. relayInfo.UserQuota = userQuota
  84. if userQuota > trustQuota {
  85. // 用户额度充足,判断令牌额度是否充足
  86. if !relayInfo.TokenUnlimited {
  87. // 非无限令牌,判断令牌额度是否充足
  88. tokenQuota := c.GetInt("token_quota")
  89. if tokenQuota > trustQuota {
  90. // 令牌额度充足,信任令牌
  91. preConsumedQuota = 0
  92. logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
  93. }
  94. } else {
  95. // in this case, we do not pre-consume quota
  96. // because the user has enough quota
  97. preConsumedQuota = 0
  98. logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
  99. }
  100. }
  101. if preConsumedQuota > 0 {
  102. err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
  103. if err != nil {
  104. return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  105. }
  106. err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
  107. if err != nil {
  108. return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
  109. }
  110. logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
  111. }
  112. relayInfo.FinalPreConsumedQuota = preConsumedQuota
  113. return nil
  114. }