Przeglądaj źródła

feat: automatically disable channel when error occurred (#59)

JustSong 2 lat temu
rodzic
commit
a1f61384c5

+ 5 - 0
common/constants.go

@@ -52,6 +52,11 @@ var TurnstileSecretKey = ""
 
 
 var QuotaForNewUser = 100
 var QuotaForNewUser = 100
 
 
+var ChannelDisableThreshold = 5.0
+var AutomaticDisableChannelEnabled = false
+
+var RootUserEmail = ""
+
 const (
 const (
 	RoleGuestUser  = 0
 	RoleGuestUser  = 0
 	RoleCommonUser = 1
 	RoleCommonUser = 1

+ 20 - 11
controller/channel.go

@@ -263,6 +263,20 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 var testAllChannelsRunning bool = false
 
 
+// disable & notify
+func disableChannel(channelId int, channelName string, err error) {
+	if common.RootUserEmail == "" {
+		common.RootUserEmail = model.GetRootUserEmail()
+	}
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
+	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error())
+	err = common.SendEmail(subject, common.RootUserEmail, content)
+	if err != nil {
+		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
+	}
+}
+
 func testAllChannels(c *gin.Context) error {
 func testAllChannels(c *gin.Context) error {
 	testAllChannelsLock.Lock()
 	testAllChannelsLock.Lock()
 	if testAllChannelsRunning {
 	if testAllChannelsRunning {
@@ -280,8 +294,10 @@ func testAllChannels(c *gin.Context) error {
 		return err
 		return err
 	}
 	}
 	testRequest := buildTestRequest(c)
 	testRequest := buildTestRequest(c)
-	var disableThreshold int64 = 5000 // TODO: make it configurable
-	email := model.GetRootUserEmail()
+	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
+	if disableThreshold == 0 {
+		disableThreshold = 10000000 // a impossible value
+	}
 	go func() {
 	go func() {
 		for _, channel := range channels {
 		for _, channel := range channels {
 			if channel.Status != common.ChannelStatusEnabled {
 			if channel.Status != common.ChannelStatusEnabled {
@@ -295,18 +311,11 @@ func testAllChannels(c *gin.Context) error {
 				if milliseconds > disableThreshold {
 				if milliseconds > disableThreshold {
 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 				}
 				}
-				// disable & notify
-				channel.UpdateStatus(common.ChannelStatusDisabled)
-				subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id)
-				content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error())
-				err = common.SendEmail(subject, email, content)
-				if err != nil {
-					common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
-				}
+				disableChannel(channel.Id, channel.Name, err)
 			}
 			}
 			channel.UpdateResponseTime(milliseconds)
 			channel.UpdateResponseTime(milliseconds)
 		}
 		}
-		err := common.SendEmail("通道测试完成", email, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
+		err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
 		if err != nil {
 		if err != nil {
 			common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 			common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 		}
 		}

+ 10 - 0
controller/relay.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bufio"
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
 	"github.com/pkoukk/tiktoken-go"
@@ -74,6 +75,11 @@ func Relay(c *gin.Context) {
 				"type":    "one_api_error",
 				"type":    "one_api_error",
 			},
 			},
 		})
 		})
+		if common.AutomaticDisableChannelEnabled {
+			channelId := c.GetInt("channel_id")
+			channelName := c.GetString("channel_name")
+			disableChannel(channelId, channelName, err)
+		}
 	}
 	}
 }
 }
 
 
@@ -256,6 +262,10 @@ func relayHelper(c *gin.Context) error {
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
+			if textResponse.Error.Type != "" {
+				return errors.New(fmt.Sprintf("type %s, code %s, message %s",
+					textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
+			}
 			// Reset response body
 			// Reset response body
 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 		}
 		}

+ 2 - 0
middleware/distributor.go

@@ -62,6 +62,8 @@ func Distribute() func(c *gin.Context) {
 			}
 			}
 		}
 		}
 		c.Set("channel", channel.Type)
 		c.Set("channel", channel.Type)
+		c.Set("channel_id", channel.Id)
+		c.Set("channel_name", channel.Name)
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
 		if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
 			c.Set("base_url", channel.BaseURL)
 			c.Set("base_url", channel.BaseURL)

+ 7 - 7
model/channel.go

@@ -86,15 +86,15 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
 	}
 	}
 }
 }
 
 
-func (channel *Channel) UpdateStatus(status int) {
-	err := DB.Model(channel).Update("status", status).Error
-	if err != nil {
-		common.SysError("failed to update response time: " + err.Error())
-	}
-}
-
 func (channel *Channel) Delete() error {
 func (channel *Channel) Delete() error {
 	var err error
 	var err error
 	err = DB.Delete(channel).Error
 	err = DB.Delete(channel).Error
 	return err
 	return err
 }
 }
+
+func UpdateChannelStatusById(id int, status int) {
+	err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
+	if err != nil {
+		common.SysError("failed to update channel status: " + err.Error())
+	}
+}

+ 6 - 0
model/option.go

@@ -32,6 +32,8 @@ func InitOptionMap() {
 	common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
 	common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
+	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
+	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
 	common.OptionMap["SMTPServer"] = ""
 	common.OptionMap["SMTPServer"] = ""
 	common.OptionMap["SMTPFrom"] = ""
 	common.OptionMap["SMTPFrom"] = ""
 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
@@ -114,6 +116,8 @@ func updateOptionMap(key string, value string) (err error) {
 			common.TurnstileCheckEnabled = boolValue
 			common.TurnstileCheckEnabled = boolValue
 		case "RegisterEnabled":
 		case "RegisterEnabled":
 			common.RegisterEnabled = boolValue
 			common.RegisterEnabled = boolValue
+		case "AutomaticDisableChannelEnabled":
+			common.AutomaticDisableChannelEnabled = boolValue
 		}
 		}
 	}
 	}
 	switch key {
 	switch key {
@@ -156,6 +160,8 @@ func updateOptionMap(key string, value string) (err error) {
 		err = common.UpdateModelRatioByJSONString(value)
 		err = common.UpdateModelRatioByJSONString(value)
 	case "TopUpLink":
 	case "TopUpLink":
 		common.TopUpLink = value
 		common.TopUpLink = value
+	case "ChannelDisableThreshold":
+		common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
 	}
 	}
 	return err
 	return err
 }
 }

+ 28 - 1
web/src/components/SystemSetting.js

@@ -28,7 +28,9 @@ const SystemSetting = () => {
     RegisterEnabled: '',
     RegisterEnabled: '',
     QuotaForNewUser: 0,
     QuotaForNewUser: 0,
     ModelRatio: '',
     ModelRatio: '',
-    TopUpLink: ''
+    TopUpLink: '',
+    AutomaticDisableChannelEnabled: '',
+    ChannelDisableThreshold: 0,
   });
   });
   let originInputs = {};
   let originInputs = {};
   let [loading, setLoading] = useState(false);
   let [loading, setLoading] = useState(false);
@@ -62,6 +64,7 @@ const SystemSetting = () => {
       case 'WeChatAuthEnabled':
       case 'WeChatAuthEnabled':
       case 'TurnstileCheckEnabled':
       case 'TurnstileCheckEnabled':
       case 'RegisterEnabled':
       case 'RegisterEnabled':
+      case 'AutomaticDisableChannelEnabled':
         value = inputs[key] === 'true' ? 'false' : 'true';
         value = inputs[key] === 'true' ? 'false' : 'true';
         break;
         break;
       default:
       default:
@@ -298,6 +301,30 @@ const SystemSetting = () => {
           </Form.Group>
           </Form.Group>
           <Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
           <Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
           <Divider />
           <Divider />
+          <Header as='h3'>
+            监控设置
+          </Header>
+          <Form.Group widths={3}>
+            <Form.Input
+              label='最长回应时间'
+              name='ChannelDisableThreshold'
+              onChange={handleInputChange}
+              autoComplete='new-password'
+              value={inputs.ChannelDisableThreshold}
+              type='number'
+              min='0'
+              placeholder='单位秒,当运行通道全部测试时,超过此时间将自动禁用通道'
+            />
+          </Form.Group>
+          <Form.Group inline>
+            <Form.Checkbox
+              checked={inputs.AutomaticDisableChannelEnabled === 'true'}
+              label='失败时自动禁用通道'
+              name='AutomaticDisableChannelEnabled'
+              onChange={handleInputChange}
+            />
+          </Form.Group>
+          <Divider />
           <Header as='h3'>
           <Header as='h3'>
             配置 SMTP
             配置 SMTP
             <Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
             <Header.Subheader>用以支持系统的邮件发送</Header.Subheader>