package service

import (
	"fmt"
	"net/http"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/types"
	"github.com/bytedance/gopkg/util/gopool"
	"github.com/gin-gonic/gin"
)

// ReturnPreConsumedQuota gives back whatever was pre-consumed for a request
// that failed. The refund itself runs asynchronously on the gopool worker
// pool; failures are logged, never returned.
//
// Two refund paths exist:
//   - subscription billing: the subscription pre-consumption is refunded (with
//     retries) and, separately, the token quota is restored unless the request
//     came from the playground;
//   - any other billing source: PostConsumeQuota is called with the negated
//     pre-consumed amount.
//
// Nothing is done when there is neither a subscription pre-consumption nor a
// non-zero FinalPreConsumedQuota.
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
	// The subscription pre-consumption can be non-zero even when
	// FinalPreConsumedQuota is 0, so the two are checked independently.
	needRefundSub := relayInfo.BillingSource == BillingSourceSubscription &&
		relayInfo.SubscriptionId != 0 &&
		relayInfo.SubscriptionPreConsumed > 0
	needRefundToken := relayInfo.FinalPreConsumedQuota != 0
	if !needRefundSub && !needRefundToken {
		return
	}

	logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, subscription=%d)",
		relayInfo.UserId,
		logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
		relayInfo.SubscriptionPreConsumed,
	))

	// Snapshot the relay info before handing off to the worker pool so the
	// refund operates on the same values the flags above were derived from,
	// regardless of what the caller does with relayInfo afterwards.
	relayInfoCopy := *relayInfo

	gopool.Go(func() {
		if relayInfoCopy.BillingSource == BillingSourceSubscription {
			if needRefundSub {
				err := refundWithRetry(func() error {
					return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId)
				})
				if err != nil {
					common.SysLog("error refund subscription pre-consume: " + err.Error())
				}
			}
			// Subscription billing only needs the token quota restored here.
			if needRefundToken && !relayInfoCopy.IsPlayground {
				err := model.IncreaseTokenQuota(relayInfoCopy.TokenId, relayInfoCopy.TokenKey, relayInfoCopy.FinalPreConsumedQuota)
				if err != nil {
					// Previously discarded; surface it so a lost refund is traceable.
					common.SysLog("error refund token quota: " + err.Error())
				}
			}
			return
		}

		// Wallet refund goes through the existing post-consume path with a
		// negative amount.
		if needRefundToken {
			err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
			if err != nil {
				common.SysLog("error return pre-consumed quota: " + err.Error())
			}
		}
	})
}

// refundWithRetry calls fn up to three times and returns nil as soon as one
// attempt succeeds. Between failed attempts it sleeps 200ms, then 400ms; there
// is no sleep after the final attempt. The error of the last attempt is
// returned when all attempts fail. A nil fn is treated as a no-op.
func refundWithRetry(fn func() error) error {
	if fn == nil {
		return nil
	}

	const maxAttempts = 3
	var lastErr error
	for i := 0; i < maxAttempts; i++ {
		lastErr = fn()
		if lastErr == nil {
			return nil
		}
		if i < maxAttempts-1 {
			// Linear back-off: 200ms after the first failure, 400ms after the second.
			time.Sleep(time.Duration(200*(i+1)) * time.Millisecond)
		}
	}
	return lastErr
}

// PreConsumeQuota checks that the user can afford preConsumedQuota and, unless
// the user/token pair is trusted, deducts it up front from both the token and
// the user.
//
// The amount actually pre-consumed (0 when trusted) is stored in
// relayInfo.FinalPreConsumedQuota, and the user's quota as read at the start of
// the call is stored in relayInfo.UserQuota. The function returns nil on
// success or a *types.NewAPIError describing why pre-consumption was refused or
// failed.
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
	if err != nil {
		return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
	}
	if userQuota <= 0 {
		return types.NewErrorWithStatusCode(
			fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)),
			types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
			types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
	}
	if userQuota-preConsumedQuota < 0 {
		return types.NewErrorWithStatusCode(
			fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s",
				logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)),
			types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
			types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
	}

	trustQuota := common.GetTrustQuota()
	relayInfo.UserQuota = userQuota
	if userQuota > trustQuota {
		// The user has plenty of quota; whether to skip pre-consumption now
		// depends on the token.
		if relayInfo.TokenUnlimited {
			// Unlimited token and a well-funded user: trust, do not pre-consume.
			preConsumedQuota = 0
			logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
		} else {
			// Limited token: only trust it when the token itself is also well funded.
			tokenQuota := c.GetInt("token_quota")
			if tokenQuota > trustQuota {
				preConsumedQuota = 0
				logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费",
					relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
			}
		}
	}

	if preConsumedQuota > 0 {
		if err := PreConsumeTokenQuota(relayInfo, preConsumedQuota); err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden,
				types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
		}
		if err := model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota); err != nil {
			// The token was already charged above but the user deduction
			// failed. FinalPreConsumedQuota is not set on this path, so
			// ReturnPreConsumedQuota would never refund it; undo the token
			// deduction here. The playground guard mirrors the token refund in
			// ReturnPreConsumedQuota.
			// NOTE(review): assumes PreConsumeTokenQuota deducts exactly
			// preConsumedQuota from this token for non-playground requests — confirm.
			if !relayInfo.IsPlayground {
				if rbErr := model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, preConsumedQuota); rbErr != nil {
					common.SysLog("error rollback pre-consumed token quota: " + rbErr.Error())
				}
			}
			return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
		}
		logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s",
			relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
	}

	relayInfo.FinalPreConsumedQuota = preConsumedQuota
	return nil
}