CaIon 2 лет назад
Родитель
Сommit
9b2e5c2978
1 измененных файлов с 32 добавлено и 39 удалено
  1. 32 39
      relay/relay-image.go

+ 32 - 39
relay/relay-image.go

@@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
 	userId := c.GetInt("id")
-	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
 	startTime := time.Now()
 
 	var imageRequest dto.ImageRequest
-	if consumeQuota {
-		err := common.UnmarshalBodyReusable(c, &imageRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-		}
+	err := common.UnmarshalBodyReusable(c, &imageRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 	}
 
 	if imageRequest.Model == "" {
@@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 
 	quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
 
-	if consumeQuota && userQuota-quota < 0 {
+	if userQuota-quota < 0 {
 		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 
@@ -176,47 +173,43 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 	var textResponse dto.ImageResponse
 	defer func(ctx context.Context) {
 		useTimeSeconds := time.Now().Unix() - startTime.Unix()
-		if consumeQuota {
-			if resp.StatusCode != http.StatusOK {
-				return
-			}
-			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
-			if err != nil {
-				common.SysError("error consuming token remain quota: " + err.Error())
-			}
-			err = model.CacheUpdateUserQuota(userId)
-			if err != nil {
-				common.SysError("error update user quota cache: " + err.Error())
-			}
-			if quota != 0 {
-				tokenName := c.GetString("token_name")
-				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
-				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-				channelId := c.GetInt("channel_id")
-				model.UpdateChannelUsedQuota(channelId, quota)
-			}
+		if resp.StatusCode != http.StatusOK {
+			return
 		}
-	}(c.Request.Context())
-
-	if consumeQuota {
-		responseBody, err := io.ReadAll(resp.Body)
-
+		err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+			common.SysError("error consuming token remain quota: " + err.Error())
 		}
-		err = resp.Body.Close()
+		err = model.CacheUpdateUserQuota(userId)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+			common.SysError("error update user quota cache: " + err.Error())
 		}
-		err = json.Unmarshal(responseBody, &textResponse)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+		if quota != 0 {
+			tokenName := c.GetString("token_name")
+			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
+			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+			channelId := c.GetInt("channel_id")
+			model.UpdateChannelUsedQuota(channelId, quota)
 		}
+	}(c.Request.Context())
+
+	responseBody, err := io.ReadAll(resp.Body)
 
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	err = json.Unmarshal(responseBody, &textResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
 
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])
 	}