Jelajahi Sumber

fix: handel error response from server correctly (close #90)

JustSong 2 tahun lalu
induk
melakukan
ceb289cb4d
2 mengubah file dengan 61 tambahan dan 37 penghapusan
  1. 4 4
      controller/channel.go
  2. 57 33
      controller/relay.go

+ 4 - 4
controller/channel.go

@@ -265,14 +265,14 @@ var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
 // disable & notify
-func disableChannel(channelId int, channelName string, err error) {
+func disableChannel(channelId int, channelName string, reason string) {
 	if common.RootUserEmail == "" {
 		common.RootUserEmail = model.GetRootUserEmail()
 	}
 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error())
-	err = common.SendEmail(subject, common.RootUserEmail, content)
+	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
+	err := common.SendEmail(subject, common.RootUserEmail, content)
 	if err != nil {
 		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 	}
@@ -312,7 +312,7 @@ func testAllChannels(c *gin.Context) error {
 				if milliseconds > disableThreshold {
 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 				}
-				disableChannel(channel.Id, channel.Name, err)
+				disableChannel(channel.Id, channel.Name, err.Error())
 			}
 			channel.UpdateResponseTime(milliseconds)
 		}

+ 57 - 33
controller/relay.go

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
@@ -47,6 +46,11 @@ type OpenAIError struct {
 	Code    string `json:"code"`
 }
 
+type OpenAIErrorWithStatusCode struct {
+	OpenAIError
+	StatusCode int `json:"status_code"`
+}
+
 type TextResponse struct {
 	Usage `json:"usage"`
 	Error OpenAIError `json:"error"`
@@ -71,23 +75,33 @@ func countToken(text string) int {
 func Relay(c *gin.Context) {
 	err := relayHelper(c)
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"error": gin.H{
-				"message": err.Error(),
-				"type":    "one_api_error",
-			},
+		c.JSON(err.StatusCode, gin.H{
+			"error": err.OpenAIError,
 		})
 		channelId := c.GetInt("channel_id")
-		common.SysError(fmt.Sprintf("Relay error: %s, channel id: %d", err.Error(), channelId))
-		if common.AutomaticDisableChannelEnabled {
+		common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
+		if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests &&
+			common.AutomaticDisableChannelEnabled {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
-			disableChannel(channelId, channelName, err)
+			disableChannel(channelId, channelName, err.Message)
 		}
 	}
 }
 
-func relayHelper(c *gin.Context) error {
+func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
+	openAIError := OpenAIError{
+		Message: err.Error(),
+		Type:    "one_api_error",
+		Code:    code,
+	}
+	return &OpenAIErrorWithStatusCode{
+		OpenAIError: openAIError,
+		StatusCode:  statusCode,
+	}
+}
+
+func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
@@ -95,15 +109,15 @@ func relayHelper(c *gin.Context) error {
 	if consumeQuota || channelType == common.ChannelTypeAzure {
 		requestBody, err := io.ReadAll(c.Request.Body)
 		if err != nil {
-			return err
+			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
 		}
 		err = c.Request.Body.Close()
 		if err != nil {
-			return err
+			return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
 		}
 		err = json.Unmarshal(requestBody, &textRequest)
 		if err != nil {
-			return err
+			return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
 		}
 		// Reset request body
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
@@ -146,12 +160,12 @@ func relayHelper(c *gin.Context) error {
 	if consumeQuota {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
-			return err
+			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
 		}
 	}
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 	if err != nil {
-		return err
+		return errorWrapper(err, "new_request_failed", http.StatusOK)
 	}
 	if channelType == common.ChannelTypeAzure {
 		key := c.Request.Header.Get("Authorization")
@@ -166,15 +180,15 @@ func relayHelper(c *gin.Context) error {
 	client := &http.Client{}
 	resp, err := client.Do(req)
 	if err != nil {
-		return err
+		return errorWrapper(err, "do_request_failed", http.StatusOK)
 	}
 	err = req.Body.Close()
 	if err != nil {
-		return err
+		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 	}
 	err = c.Request.Body.Close()
 	if err != nil {
-		return err
+		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 	}
 	var textResponse TextResponse
 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@@ -259,50 +273,60 @@ func relayHelper(c *gin.Context) error {
 		})
 		err = resp.Body.Close()
 		if err != nil {
-			return err
+			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 		}
 		return nil
 	} else {
-		for k, v := range resp.Header {
-			c.Writer.Header().Set(k, v[0])
-		}
 		if consumeQuota {
 			responseBody, err := io.ReadAll(resp.Body)
 			if err != nil {
-				return err
+				return errorWrapper(err, "read_response_body_failed", http.StatusOK)
 			}
 			err = resp.Body.Close()
 			if err != nil {
-				return err
+				return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 			}
 			err = json.Unmarshal(responseBody, &textResponse)
 			if err != nil {
-				return err
+				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
 			}
 			if textResponse.Error.Type != "" {
-				return errors.New(fmt.Sprintf("type %s, code %s, message %s",
-					textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
+				return &OpenAIErrorWithStatusCode{
+					OpenAIError: textResponse.Error,
+					StatusCode:  resp.StatusCode,
+				}
 			}
 			// Reset response body
 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 		}
+		// We shouldn't set the header before we parse the response body, because the parse part may fail.
+		// And then we will have to send an error response, but in this case, the header has already been set.
+		// So the client will be confused by the response.
+		// For example, Postman will report error, and we cannot check the response at all.
+		for k, v := range resp.Header {
+			c.Writer.Header().Set(k, v[0])
+		}
+		c.Writer.WriteHeader(resp.StatusCode)
 		_, err = io.Copy(c.Writer, resp.Body)
 		if err != nil {
-			return err
+			return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
 		}
 		err = resp.Body.Close()
 		if err != nil {
-			return err
+			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 		}
 		return nil
 	}
 }
 
 func RelayNotImplemented(c *gin.Context) {
+	err := OpenAIError{
+		Message: "API not implemented",
+		Type:    "one_api_error",
+		Param:   "",
+		Code:    "api_not_implemented",
+	}
 	c.JSON(http.StatusOK, gin.H{
-		"error": gin.H{
-			"message": "Not Implemented",
-			"type":    "one_api_error",
-		},
+		"error": err,
 	})
 }