| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- package service
- import (
- "fmt"
- "net/http"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/model"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/types"
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
- )
- func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
- // Always refund subscription pre-consumed (can be non-zero even when FinalPreConsumedQuota is 0)
- needRefundSub := relayInfo.BillingSource == BillingSourceSubscription && relayInfo.SubscriptionId != 0 && relayInfo.SubscriptionPreConsumed > 0
- needRefundToken := relayInfo.FinalPreConsumedQuota != 0
- if !needRefundSub && !needRefundToken {
- return
- }
- logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, subscription=%d)",
- relayInfo.UserId,
- logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
- relayInfo.SubscriptionPreConsumed,
- ))
- gopool.Go(func() {
- relayInfoCopy := *relayInfo
- if relayInfoCopy.BillingSource == BillingSourceSubscription {
- if needRefundSub {
- if err := refundWithRetry(func() error {
- return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId)
- }); err != nil {
- common.SysLog("error refund subscription pre-consume: " + err.Error())
- }
- }
- // refund token quota only
- if needRefundToken && !relayInfoCopy.IsPlayground {
- _ = model.IncreaseTokenQuota(relayInfoCopy.TokenId, relayInfoCopy.TokenKey, relayInfoCopy.FinalPreConsumedQuota)
- }
- return
- }
- // wallet refund uses existing path (user quota + token quota)
- if needRefundToken {
- err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
- if err != nil {
- common.SysLog("error return pre-consumed quota: " + err.Error())
- }
- }
- })
- }
// refundWithRetry calls fn up to three times and stops at the first call that
// returns nil. After a failed attempt that is not the last one, it sleeps for
// a linearly growing delay: 200ms after the first failure, 400ms after the
// second. A nil fn counts as an immediate success. When every attempt fails,
// the error from the final attempt is returned.
func refundWithRetry(fn func() error) error {
	if fn == nil {
		return nil
	}

	const maxAttempts = 3
	var err error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		err = fn()
		if err == nil {
			return nil
		}
		if attempt == maxAttempts {
			break // no back-off after the final attempt
		}
		time.Sleep(time.Duration(200*attempt) * time.Millisecond)
	}
	return err
}
- // PreConsumeQuota checks if the user has enough quota to pre-consume.
- // It returns the pre-consumed quota if successful, or an error if not.
- func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
- }
- if userQuota <= 0 {
- return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- if userQuota-preConsumedQuota < 0 {
- return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- trustQuota := common.GetTrustQuota()
- relayInfo.UserQuota = userQuota
- if userQuota > trustQuota {
- // 用户额度充足,判断令牌额度是否充足
- if !relayInfo.TokenUnlimited {
- // 非无限令牌,判断令牌额度是否充足
- tokenQuota := c.GetInt("token_quota")
- if tokenQuota > trustQuota {
- // 令牌额度充足,信任令牌
- preConsumedQuota = 0
- logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
- }
- } else {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
- }
- }
- if preConsumedQuota > 0 {
- err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
- if err != nil {
- return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
- if err != nil {
- return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
- }
- logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
- }
- relayInfo.FinalPreConsumedQuota = preConsumedQuota
- return nil
- }
|