Просмотр исходного кода

fix: 部分情况缺少返回预扣

(cherry picked from commit 96373455521a38095706bd81c57f9a18557d9c2e)
Xyfacai 1 год назад
Родитель
Коммит
f0907bf60a
Изменено файлов: 5 (добавлено строк: 33, удалено строк: 18)
  1. 4 4
      relay/channel/openai/relay-openai.go
  2. 7 3
      relay/relay-audio.go
  3. 7 5
      relay/relay-text.go
  4. 8 4
      relay/relay_rerank.go
  5. 7 2
      relay/websocket.go

+ 4 - 4
relay/channel/openai/relay-openai.go

@@ -391,7 +391,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 	localUsage := &dto.RealtimeUsage{}
 	localUsage := &dto.RealtimeUsage{}
 	sumUsage := &dto.RealtimeUsage{}
 	sumUsage := &dto.RealtimeUsage{}
 
 
-	go func() {
+	gopool.Go(func() {
 		for {
 		for {
 			select {
 			select {
 			case <-c.Done():
 			case <-c.Done():
@@ -444,9 +444,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				}
 				}
 			}
 			}
 		}
 		}
-	}()
+	})
 
 
-	go func() {
+	gopool.Go(func() {
 		for {
 		for {
 			select {
 			select {
 			case <-c.Done():
 			case <-c.Done():
@@ -541,7 +541,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				}
 				}
 			}
 			}
 		}
 		}
-	}()
+	})
 
 
 	select {
 	select {
 	case <-clientClosed:
 	case <-clientClosed:

+ 7 - 3
relay/relay-audio.go

@@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 	return audioRequest, nil
 	return audioRequest, nil
 }
 }
 
 
-func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
+func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfo(c)
 	relayInfo := relaycommon.GenRelayInfo(c)
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 
 
@@ -92,6 +92,11 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 		}
 	}
 	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
 
 
 	// map model name
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	modelMapping := c.GetString("model_mapping")
@@ -128,8 +133,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-			openaiErr := service.RelayErrorHandler(httpResp)
+			openaiErr = service.RelayErrorHandler(httpResp)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			return openaiErr
 			return openaiErr

+ 7 - 5
relay/relay-text.go

@@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
 	return textRequest, nil
 	return textRequest, nil
 }
 }
 
 
-func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
+func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 
 	relayInfo := relaycommon.GenRelayInfo(c)
 	relayInfo := relaycommon.GenRelayInfo(c)
 
 
@@ -131,7 +131,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	if openaiErr != nil {
 	if openaiErr != nil {
 		return openaiErr
 		return openaiErr
 	}
 	}
-
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
 	includeUsage := false
 	includeUsage := false
 	// 判断用户是否需要返回使用情况
 	// 判断用户是否需要返回使用情况
 	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
 	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
@@ -190,8 +194,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-			openaiErr := service.RelayErrorHandler(httpResp)
+			openaiErr = service.RelayErrorHandler(httpResp)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			return openaiErr
 			return openaiErr
@@ -200,7 +203,6 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 
 	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	if openaiErr != nil {
 	if openaiErr != nil {
-		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr

+ 8 - 4
relay/relay_rerank.go

@@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
 	return token
 	return token
 }
 }
 
 
-func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfo(c)
 	relayInfo := relaycommon.GenRelayInfo(c)
 
 
 	var rerankRequest *dto.RerankRequest
 	var rerankRequest *dto.RerankRequest
@@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	if openaiErr != nil {
 	if openaiErr != nil {
 		return openaiErr
 		return openaiErr
 	}
 	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
@@ -104,8 +110,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-			openaiErr := service.RelayErrorHandler(httpResp)
+			openaiErr = service.RelayErrorHandler(httpResp)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			return openaiErr
 			return openaiErr
@@ -114,7 +119,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 
 
 	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	if openaiErr != nil {
 	if openaiErr != nil {
-		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr

+ 7 - 2
relay/websocket.go

@@ -30,7 +30,7 @@ import (
 //	return realtimeEvent, nil
 //	return realtimeEvent, nil
 //}
 //}
 
 
-func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCode {
+func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfoWs(c, ws)
 	relayInfo := relaycommon.GenRelayInfoWs(c, ws)
 
 
 	// get & validate textRequest 获取并验证文本请求
 	// get & validate textRequest 获取并验证文本请求
@@ -96,6 +96,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod
 		return openaiErr
 		return openaiErr
 	}
 	}
 
 
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
@@ -118,7 +124,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod
 
 
 	usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
 	usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
 	if openaiErr != nil {
 	if openaiErr != nil {
-		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr