Преглед изворни кода

fix: fix quota not consuming

JustSong пре 2 година
родитељ
комит
8afdc56b11
3 измењених фајлова са 57 додато и 8 уклоњено
  1. 9 1
      controller/relay.go
  2. 1 1
      middleware/auth.go
  3. 47 6
      model/token.go

+ 9 - 1
controller/relay.go

@@ -128,6 +128,13 @@ func relayHelper(c *gin.Context) error {
 		model_ = strings.TrimSuffix(model_, "-0314")
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 	}
 	}
+	preConsumedQuota := 500 // TODO: make this configurable, take ratio into account
+	if consumeQuota {
+		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+		if err != nil {
+			return err
+		}
+	}
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -179,7 +186,8 @@ func relayHelper(c *gin.Context) error {
 			}
 			}
 			ratio := common.GetModelRatio(textRequest.Model)
 			ratio := common.GetModelRatio(textRequest.Model)
 			quota = int(float64(quota) * ratio)
 			quota = int(float64(quota) * ratio)
-			err := model.DecreaseTokenQuota(tokenId, quota)
+			quotaDelta := quota - preConsumedQuota
+			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 			if err != nil {
 			if err != nil {
 				common.SysError("Error consuming token remain quota: " + err.Error())
 				common.SysError("Error consuming token remain quota: " + err.Error())
 			}
 			}

+ 1 - 1
middleware/auth.go

@@ -111,7 +111,7 @@ func TokenAuth() func(c *gin.Context) {
 		c.Set("id", token.UserId)
 		c.Set("id", token.UserId)
 		c.Set("token_id", token.Id)
 		c.Set("token_id", token.Id)
 		requestURL := c.Request.URL.String()
 		requestURL := c.Request.URL.String()
-		consumeQuota := !token.UnlimitedQuota
+		consumeQuota := true
 		if strings.HasPrefix(requestURL, "/v1/models") {
 		if strings.HasPrefix(requestURL, "/v1/models") {
 			consumeQuota = false
 			consumeQuota = false
 		}
 		}

+ 47 - 6
model/token.go

@@ -130,7 +130,23 @@ func DeleteTokenById(id int, userId int) (err error) {
 	return token.Delete()
 	return token.Delete()
 }
 }
 
 
-func DecreaseTokenQuota(tokenId int, quota int) (err error) {
+func IncreaseTokenQuota(id int, quota int) (err error) {
+	if quota < 0 {
+		return errors.New("quota 不能为负数!")
+	}
+	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
+	return err
+}
+
+func DecreaseTokenQuota(id int, quota int) (err error) {
+	if quota < 0 {
+		return errors.New("quota 不能为负数!")
+	}
+	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
+	return err
+}
+
+func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
 	if quota < 0 {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 		return errors.New("quota 不能为负数!")
 	}
 	}
@@ -138,7 +154,7 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	if token.RemainQuota < quota {
+	if !token.UnlimitedQuota && token.RemainQuota < quota {
 		return errors.New("令牌额度不足")
 		return errors.New("令牌额度不足")
 	}
 	}
 	userQuota, err := GetUserQuota(token.UserId)
 	userQuota, err := GetUserQuota(token.UserId)
@@ -163,17 +179,42 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 			if email != "" {
 			if email != "" {
 				topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
 				topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
 				err = common.SendEmail(prompt, email,
 				err = common.SendEmail(prompt, email,
-					fmt.Sprintf("%s,剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota-quota, topUpLink, topUpLink))
+					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
 				if err != nil {
 				if err != nil {
 					common.SysError("发送邮件失败:" + err.Error())
 					common.SysError("发送邮件失败:" + err.Error())
 				}
 				}
 			}
 			}
 		}()
 		}()
 	}
 	}
-	err = DB.Model(&Token{}).Where("id = ?", tokenId).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
-	if err != nil {
-		return err
+	if !token.UnlimitedQuota {
+		err = DecreaseTokenQuota(tokenId, quota)
+		if err != nil {
+			return err
+		}
 	}
 	}
 	err = DecreaseUserQuota(token.UserId, quota)
 	err = DecreaseUserQuota(token.UserId, quota)
 	return err
 	return err
 }
 }
+
+func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
+	token, err := GetTokenById(tokenId)
+	if quota > 0 {
+		err = DecreaseUserQuota(token.UserId, quota)
+	} else {
+		err = IncreaseUserQuota(token.UserId, -quota)
+	}
+	if err != nil {
+		return err
+	}
+	if !token.UnlimitedQuota {
+		if quota > 0 {
+			err = DecreaseTokenQuota(tokenId, quota)
+		} else {
+			err = IncreaseTokenQuota(tokenId, -quota)
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}