Procházet zdrojové kódy

fix: 预扣额度使用 relay info 传递

Xyfacai před 5 měsíci
rodič
revize
b25ac0bfb6
2 změnil soubory, kde provedl 14 přidání a 14 odebrání
  1. 3 3
      controller/relay.go
  2. 11 11
      service/pre_consume_quota.go

+ 3 - 3
controller/relay.go

@@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 
 	// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
 
-	preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if newAPIError != nil {
 		return
 	}
 
 	defer func() {
 		// Only return quota if downstream failed and quota was actually pre-consumed
-		if newAPIError != nil && preConsumedQuota != 0 {
-			service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
+		if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
+			service.ReturnPreConsumedQuota(c, relayInfo)
 		}
 	}()
 

+ 11 - 11
service/pre_consume_quota.go

@@ -13,13 +13,13 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
-	if preConsumedQuota != 0 {
-		logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
+func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+	if relayInfo.FinalPreConsumedQuota != 0 {
+		logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
 		gopool.Go(func() {
 			relayInfoCopy := *relayInfo
 
-			err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
+			err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
 			if err != nil {
 				common.SysLog("error return pre-consumed quota: " + err.Error())
 			}
@@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
 
 // PreConsumeQuota checks if the user has enough quota to pre-consume.
 // It returns the pre-consumed quota if successful, or an error if not.
-func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
+func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
 	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
 	if err != nil {
-		return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
+		return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
 	}
 	if userQuota <= 0 {
-		return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
 	}
 
 	trustQuota := common.GetTrustQuota()
@@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	if preConsumedQuota > 0 {
 		err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
 		if err != nil {
-			return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+			return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
 		}
 		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 		if err != nil {
-			return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
+			return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
 		}
 		logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
 	}
 	relayInfo.FinalPreConsumedQuota = preConsumedQuota
-	return preConsumedQuota, nil
+	return nil
 }