Просмотр исходного кода

删除relay-text中的consumeQuota变量

CaIon 2 лет назад
Родитель
Сommit
53ce243736
2 измененных файлов с 55 добавлено и 60 удалено
  1. 21 23
      controller/relay-openai.go
  2. 34 37
      controller/relay-text.go

+ 21 - 23
controller/relay-openai.go

@@ -88,30 +88,28 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	return nil, responseText
 }
 
-func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 	var textResponse TextResponse
-	if consumeQuota {
-		responseBody, err := io.ReadAll(resp.Body)
-		if err != nil {
-			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-		}
-		err = resp.Body.Close()
-		if err != nil {
-			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-		}
-		err = json.Unmarshal(responseBody, &textResponse)
-		if err != nil {
-			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-		}
-		if textResponse.Error.Type != "" {
-			return &OpenAIErrorWithStatusCode{
-				OpenAIError: textResponse.Error,
-				StatusCode:  resp.StatusCode,
-			}, nil
-		}
-		// Reset response body
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &textResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if textResponse.Error.Type != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: textResponse.Error,
+			StatusCode:  resp.StatusCode,
+		}, nil
 	}
+	// Reset response body
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 	// And then we will have to send an error response, but in this case, the header has already been set.
 	// So the httpClient will be confused by the response.
@@ -120,7 +118,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
 		c.Writer.Header().Set(k, v[0])
 	}
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err := io.Copy(c.Writer, resp.Body)
+	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}

+ 34 - 37
controller/relay-text.go

@@ -50,10 +50,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelId := c.GetInt("channel_id")
 	tokenId := c.GetInt("token_id")
 	userId := c.GetInt("id")
-	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
 	var textRequest GeneralOpenAIRequest
-	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
+	if channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 		err := common.UnmarshalBodyReusable(c, &textRequest)
 		if err != nil {
 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
@@ -236,7 +235,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		preConsumedQuota = 0
 		//common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 	}
-	if consumeQuota && preConsumedQuota > 0 {
+	if preConsumedQuota > 0 {
 		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
@@ -420,41 +419,39 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	defer func(ctx context.Context) {
 		// c.Writer.Flush()
 		go func() {
-			if consumeQuota {
-				quota := 0
-				completionRatio := common.GetCompletionRatio(textRequest.Model)
-				promptTokens = textResponse.Usage.PromptTokens
-				completionTokens = textResponse.Usage.CompletionTokens
+			quota := 0
+			completionRatio := common.GetCompletionRatio(textRequest.Model)
+			promptTokens = textResponse.Usage.PromptTokens
+			completionTokens = textResponse.Usage.CompletionTokens
 
-				quota = promptTokens + int(float64(completionTokens)*completionRatio)
-				quota = int(float64(quota) * ratio)
-				if ratio != 0 && quota <= 0 {
-					quota = 1
-				}
-				totalTokens := promptTokens + completionTokens
-				if totalTokens == 0 {
-					// in this case, must be some error happened
-					// we cannot just return, because we may have to return the pre-consumed quota
-					quota = 0
-				}
-				quotaDelta := quota - preConsumedQuota
-				err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
-				if err != nil {
-					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-				}
-				err = model.CacheUpdateUserQuota(userId)
-				if err != nil {
-					common.LogError(ctx, "error update user quota cache: "+err.Error())
-				}
-				// record all the consume log even if quota is 0
-				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
-				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-				model.UpdateChannelUsedQuota(channelId, quota)
-				//if quota != 0 {
-				//
-				//}
+			quota = promptTokens + int(float64(completionTokens)*completionRatio)
+			quota = int(float64(quota) * ratio)
+			if ratio != 0 && quota <= 0 {
+				quota = 1
+			}
+			totalTokens := promptTokens + completionTokens
+			if totalTokens == 0 {
+				// in this case, must be some error happened
+				// we cannot just return, because we may have to return the pre-consumed quota
+				quota = 0
+			}
+			quotaDelta := quota - preConsumedQuota
+			err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
+			if err != nil {
+				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.LogError(ctx, "error update user quota cache: "+err.Error())
+			}
+			// record all the consume log even if quota is 0
+			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+			model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
+			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+			model.UpdateChannelUsedQuota(channelId, quota)
+			//if quota != 0 {
+			//
+			//}
 		}()
 	}(c.Request.Context())
 	switch apiType {
@@ -468,7 +465,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 			return nil
 		} else {
-			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
+			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
 			if err != nil {
 				return err
 			}