Przeglądaj źródła

fix: using whitelist when disabling channels (close #292)

JustSong 2 lat temu
rodzic
commit
0495b9a0d7
3 zmienionych plików z 28 dodań i 14 usunięć
  1. 14 13
      controller/channel-test.go
  2. 13 0
      controller/relay-utils.go
  3. 1 1
      controller/relay.go

+ 14 - 13
controller/channel-test.go

@@ -14,7 +14,7 @@ import (
 	"time"
 )
 
-func testChannel(channel *model.Channel, request ChatRequest) error {
+func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 	switch channel.Type {
 	case common.ChannelTypeAzure:
 		request.Model = "gpt-35-turbo"
@@ -33,11 +33,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
 
 	jsonData, err := json.Marshal(request)
 	if err != nil {
-		return err
+		return err, nil
 	}
 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
 	if err != nil {
-		return err
+		return err, nil
 	}
 	if channel.Type == common.ChannelTypeAzure {
 		req.Header.Set("api-key", channel.Key)
@@ -48,18 +48,18 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
 	client := &http.Client{}
 	resp, err := client.Do(req)
 	if err != nil {
-		return err
+		return err, nil
 	}
 	defer resp.Body.Close()
 	var response TextResponse
 	err = json.NewDecoder(resp.Body).Decode(&response)
 	if err != nil {
-		return err
+		return err, nil
 	}
 	if response.Usage.CompletionTokens == 0 {
-		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
+		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
 	}
-	return nil
+	return nil, nil
 }
 
 func buildTestRequest() *ChatRequest {
@@ -94,7 +94,7 @@ func TestChannel(c *gin.Context) {
 	}
 	testRequest := buildTestRequest()
 	tik := time.Now()
-	err = testChannel(channel, *testRequest)
+	err, _ = testChannel(channel, *testRequest)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	go channel.UpdateResponseTime(milliseconds)
@@ -158,13 +158,14 @@ func testAllChannels(notify bool) error {
 				continue
 			}
 			tik := time.Now()
-			err := testChannel(channel, *testRequest)
+			err, openaiErr := testChannel(channel, *testRequest)
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
-			if err != nil || milliseconds > disableThreshold {
-				if milliseconds > disableThreshold {
-					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-				}
+			if milliseconds > disableThreshold {
+				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+				disableChannel(channel.Id, channel.Name, err.Error())
+			}
+			if shouldDisableChannel(openaiErr) {
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
 			channel.UpdateResponseTime(milliseconds)

+ 13 - 0
controller/relay-utils.go

@@ -91,3 +91,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
 		StatusCode:  statusCode,
 	}
 }
+
+func shouldDisableChannel(err *OpenAIError) bool {
+	if !common.AutomaticDisableChannelEnabled {
+		return false
+	}
+	if err == nil {
+		return false
+	}
+	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
+		return true
+	}
+	return false
+}

+ 1 - 1
controller/relay.go

@@ -171,7 +171,7 @@ func Relay(c *gin.Context) {
 		channelId := c.GetInt("channel_id")
 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") {
+		if shouldDisableChannel(&err.OpenAIError) {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
 			disableChannel(channelId, channelName, err.Message)