Sfoglia il codice sorgente

Merge pull request #2647 from seefs001/feature/status-code-auto-disable

feat: status code auto-disable configuration
Seefs 1 mese fa
parent
commit
41da848c56

+ 10 - 0
controller/option.go

@@ -10,6 +10,7 @@ import (
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/setting"
 	"github.com/QuantumNous/new-api/setting/console_setting"
+	"github.com/QuantumNous/new-api/setting/operation_setting"
 	"github.com/QuantumNous/new-api/setting/ratio_setting"
 	"github.com/QuantumNous/new-api/setting/system_setting"
 
@@ -177,6 +178,15 @@ func UpdateOption(c *gin.Context) {
 			})
 			return
 		}
+	case "AutomaticDisableStatusCodes":
+		_, err = operation_setting.ParseHTTPStatusCodeRanges(option.Value.(string))
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
 	case "console_setting.api_info":
 		err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
 		if err != nil {

+ 2 - 2
controller/relay.go

@@ -348,7 +348,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
 	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
 	if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
 		gopool.Go(func() {
-			service.DisableChannel(channelError, err.Error())
+			service.DisableChannel(channelError, err.ErrorWithStatusCode())
 		})
 	}
 
@@ -378,7 +378,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
 			adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
 		}
 		other["admin_info"] = adminInfo
-		model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
+		model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, 0, false, userGroup, other)
 	}
 
 }

+ 3 - 0
model/option.go

@@ -143,6 +143,7 @@ func InitOptionMap() {
 	common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
 	common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
+	common.OptionMap["AutomaticDisableStatusCodes"] = operation_setting.AutomaticDisableStatusCodesToString()
 	common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
 
 	// 自动添加所有注册的模型配置
@@ -444,6 +445,8 @@ func updateOptionMap(key string, value string) (err error) {
 		setting.SensitiveWordsFromString(value)
 	case "AutomaticDisableKeywords":
 		operation_setting.AutomaticDisableKeywordsFromString(value)
+	case "AutomaticDisableStatusCodes":
+		err = operation_setting.AutomaticDisableStatusCodesFromString(value)
 	case "StreamCacheQueueLength":
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
 	case "PayMethods":

+ 4 - 1
service/channel.go

@@ -57,9 +57,12 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
 	if types.IsSkipRetryError(err) {
 		return false
 	}
-	if err.StatusCode == http.StatusUnauthorized {
+	if operation_setting.ShouldDisableByStatusCode(err.StatusCode) {
 		return true
 	}
+	//if err.StatusCode == http.StatusUnauthorized {
+	//	return true
+	//}
 	if err.StatusCode == http.StatusForbidden {
 		switch channelType {
 		case constant.ChannelTypeGemini:

+ 147 - 0
setting/operation_setting/status_code_ranges.go

@@ -0,0 +1,147 @@
+package operation_setting
+
+import (
+	"fmt"
+	"sort"
+	"strconv"
+	"strings"
+)
+
+type StatusCodeRange struct {
+	Start int
+	End   int
+}
+
+var AutomaticDisableStatusCodeRanges = []StatusCodeRange{{Start: 401, End: 401}}
+
+func AutomaticDisableStatusCodesToString() string {
+	if len(AutomaticDisableStatusCodeRanges) == 0 {
+		return ""
+	}
+	parts := make([]string, 0, len(AutomaticDisableStatusCodeRanges))
+	for _, r := range AutomaticDisableStatusCodeRanges {
+		if r.Start == r.End {
+			parts = append(parts, strconv.Itoa(r.Start))
+			continue
+		}
+		parts = append(parts, fmt.Sprintf("%d-%d", r.Start, r.End))
+	}
+	return strings.Join(parts, ",")
+}
+
+func AutomaticDisableStatusCodesFromString(s string) error {
+	ranges, err := ParseHTTPStatusCodeRanges(s)
+	if err != nil {
+		return err
+	}
+	AutomaticDisableStatusCodeRanges = ranges
+	return nil
+}
+
+func ShouldDisableByStatusCode(code int) bool {
+	if code < 100 || code > 599 {
+		return false
+	}
+	for _, r := range AutomaticDisableStatusCodeRanges {
+		if code < r.Start {
+			return false
+		}
+		if code <= r.End {
+			return true
+		}
+	}
+	return false
+}
+
+func ParseHTTPStatusCodeRanges(input string) ([]StatusCodeRange, error) {
+	input = strings.TrimSpace(input)
+	if input == "" {
+		return nil, nil
+	}
+
+	input = strings.NewReplacer(",", ",").Replace(input)
+	segments := strings.Split(input, ",")
+
+	var ranges []StatusCodeRange
+	var invalid []string
+
+	for _, seg := range segments {
+		seg = strings.TrimSpace(seg)
+		if seg == "" {
+			continue
+		}
+		r, err := parseHTTPStatusCodeToken(seg)
+		if err != nil {
+			invalid = append(invalid, seg)
+			continue
+		}
+		ranges = append(ranges, r)
+	}
+
+	if len(invalid) > 0 {
+		return nil, fmt.Errorf("invalid http status code rules: %s", strings.Join(invalid, ", "))
+	}
+	if len(ranges) == 0 {
+		return nil, nil
+	}
+
+	sort.Slice(ranges, func(i, j int) bool {
+		if ranges[i].Start == ranges[j].Start {
+			return ranges[i].End < ranges[j].End
+		}
+		return ranges[i].Start < ranges[j].Start
+	})
+
+	merged := []StatusCodeRange{ranges[0]}
+	for _, r := range ranges[1:] {
+		last := &merged[len(merged)-1]
+		if r.Start <= last.End+1 {
+			if r.End > last.End {
+				last.End = r.End
+			}
+			continue
+		}
+		merged = append(merged, r)
+	}
+
+	return merged, nil
+}
+
+func parseHTTPStatusCodeToken(token string) (StatusCodeRange, error) {
+	token = strings.TrimSpace(token)
+	token = strings.ReplaceAll(token, " ", "")
+	if token == "" {
+		return StatusCodeRange{}, fmt.Errorf("empty token")
+	}
+
+	if strings.Contains(token, "-") {
+		parts := strings.Split(token, "-")
+		if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+			return StatusCodeRange{}, fmt.Errorf("invalid range token: %s", token)
+		}
+		start, err := strconv.Atoi(parts[0])
+		if err != nil {
+			return StatusCodeRange{}, fmt.Errorf("invalid range start: %s", token)
+		}
+		end, err := strconv.Atoi(parts[1])
+		if err != nil {
+			return StatusCodeRange{}, fmt.Errorf("invalid range end: %s", token)
+		}
+		if start > end {
+			return StatusCodeRange{}, fmt.Errorf("range start > end: %s", token)
+		}
+		if start < 100 || end > 599 {
+			return StatusCodeRange{}, fmt.Errorf("range out of bounds: %s", token)
+		}
+		return StatusCodeRange{Start: start, End: end}, nil
+	}
+
+	code, err := strconv.Atoi(token)
+	if err != nil {
+		return StatusCodeRange{}, fmt.Errorf("invalid status code: %s", token)
+	}
+	if code < 100 || code > 599 {
+		return StatusCodeRange{}, fmt.Errorf("status code out of bounds: %s", token)
+	}
+	return StatusCodeRange{Start: code, End: code}, nil
+}

+ 52 - 0
setting/operation_setting/status_code_ranges_test.go

@@ -0,0 +1,52 @@
+package operation_setting
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestParseHTTPStatusCodeRanges_CommaSeparated(t *testing.T) {
+	ranges, err := ParseHTTPStatusCodeRanges("401,403,500-599")
+	require.NoError(t, err)
+	require.Equal(t, []StatusCodeRange{
+		{Start: 401, End: 401},
+		{Start: 403, End: 403},
+		{Start: 500, End: 599},
+	}, ranges)
+}
+
+func TestParseHTTPStatusCodeRanges_MergeAndNormalize(t *testing.T) {
+	ranges, err := ParseHTTPStatusCodeRanges("500-505,504,401,403,402")
+	require.NoError(t, err)
+	require.Equal(t, []StatusCodeRange{
+		{Start: 401, End: 403},
+		{Start: 500, End: 505},
+	}, ranges)
+}
+
+func TestParseHTTPStatusCodeRanges_Invalid(t *testing.T) {
+	_, err := ParseHTTPStatusCodeRanges("99,600,foo,500-400,500-")
+	require.Error(t, err)
+}
+
+func TestParseHTTPStatusCodeRanges_NoComma_IsInvalid(t *testing.T) {
+	_, err := ParseHTTPStatusCodeRanges("401 403")
+	require.Error(t, err)
+}
+
+func TestShouldDisableByStatusCode(t *testing.T) {
+	orig := AutomaticDisableStatusCodeRanges
+	t.Cleanup(func() { AutomaticDisableStatusCodeRanges = orig })
+
+	AutomaticDisableStatusCodeRanges = []StatusCodeRange{
+		{Start: 401, End: 403},
+		{Start: 500, End: 599},
+	}
+
+	require.True(t, ShouldDisableByStatusCode(401))
+	require.True(t, ShouldDisableByStatusCode(403))
+	require.False(t, ShouldDisableByStatusCode(404))
+	require.True(t, ShouldDisableByStatusCode(500))
+	require.False(t, ShouldDisableByStatusCode(200))
+}

+ 28 - 0
types/error.go

@@ -130,6 +130,20 @@ func (e *NewAPIError) Error() string {
 	return e.Err.Error()
 }
 
+func (e *NewAPIError) ErrorWithStatusCode() string {
+	if e == nil {
+		return ""
+	}
+	msg := e.Error()
+	if e.StatusCode == 0 {
+		return msg
+	}
+	if msg == "" {
+		return fmt.Sprintf("status_code=%d", e.StatusCode)
+	}
+	return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
+}
+
 func (e *NewAPIError) MaskSensitiveError() string {
 	if e == nil {
 		return ""
@@ -144,6 +158,20 @@ func (e *NewAPIError) MaskSensitiveError() string {
 	return common.MaskSensitiveInfo(errStr)
 }
 
+func (e *NewAPIError) MaskSensitiveErrorWithStatusCode() string {
+	if e == nil {
+		return ""
+	}
+	msg := e.MaskSensitiveError()
+	if e.StatusCode == 0 {
+		return msg
+	}
+	if msg == "" {
+		return fmt.Sprintf("status_code=%d", e.StatusCode)
+	}
+	return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
+}
+
 func (e *NewAPIError) SetMessage(message string) {
 	e.Err = errors.New(message)
 }

+ 1 - 0
web/src/components/settings/OperationSetting.jsx

@@ -70,6 +70,7 @@ const OperationSetting = () => {
     AutomaticDisableChannelEnabled: false,
     AutomaticEnableChannelEnabled: false,
     AutomaticDisableKeywords: '',
+    AutomaticDisableStatusCodes: '401',
     'monitor_setting.auto_test_channel_enabled': false,
     'monitor_setting.auto_test_channel_minutes': 10 /* 签到设置 */,
     'checkin_setting.enabled': false,

+ 1 - 0
web/src/helpers/index.js

@@ -29,3 +29,4 @@ export * from './token';
 export * from './boolean';
 export * from './dashboard';
 export * from './passkey';
+export * from './statusCodeRules';

+ 96 - 0
web/src/helpers/statusCodeRules.js

@@ -0,0 +1,96 @@
+export function parseHttpStatusCodeRules(input) {
+  const raw = (input ?? '').toString().trim();
+  if (raw.length === 0) {
+    return {
+      ok: true,
+      ranges: [],
+      tokens: [],
+      normalized: '',
+      invalidTokens: [],
+    };
+  }
+
+  const sanitized = raw.replace(/[,]/g, ',');
+  const segments = sanitized.split(/[,]/g);
+
+  const ranges = [];
+  const invalidTokens = [];
+
+  for (const segment of segments) {
+    const trimmed = segment.trim();
+    if (!trimmed) continue;
+    const parsed = parseToken(trimmed);
+    if (!parsed) invalidTokens.push(trimmed);
+    else ranges.push(parsed);
+  }
+
+  if (invalidTokens.length > 0) {
+    return {
+      ok: false,
+      ranges: [],
+      tokens: [],
+      normalized: raw,
+      invalidTokens,
+    };
+  }
+
+  const merged = mergeRanges(ranges);
+  const tokens = merged.map((r) => (r.start === r.end ? `${r.start}` : `${r.start}-${r.end}`));
+  const normalized = tokens.join(',');
+
+  return {
+    ok: true,
+    ranges: merged,
+    tokens,
+    normalized,
+    invalidTokens: [],
+  };
+}
+
+function parseToken(token) {
+  const cleaned = (token ?? '').toString().trim().replaceAll(' ', '');
+  if (!cleaned) return null;
+
+  if (cleaned.includes('-')) {
+    const parts = cleaned.split('-');
+    if (parts.length !== 2) return null;
+    const [a, b] = parts;
+    if (!isNumber(a) || !isNumber(b)) return null;
+    const start = Number.parseInt(a, 10);
+    const end = Number.parseInt(b, 10);
+    if (!Number.isFinite(start) || !Number.isFinite(end)) return null;
+    if (start > end) return null;
+    if (start < 100 || end > 599) return null;
+    return { start, end };
+  }
+
+  if (!isNumber(cleaned)) return null;
+  const code = Number.parseInt(cleaned, 10);
+  if (!Number.isFinite(code)) return null;
+  if (code < 100 || code > 599) return null;
+  return { start: code, end: code };
+}
+
+function isNumber(s) {
+  return typeof s === 'string' && /^\d+$/.test(s);
+}
+
+function mergeRanges(ranges) {
+  if (!Array.isArray(ranges) || ranges.length === 0) return [];
+
+  const sorted = [...ranges].sort((a, b) => (a.start !== b.start ? a.start - b.start : a.end - b.end));
+  const merged = [sorted[0]];
+
+  for (let i = 1; i < sorted.length; i += 1) {
+    const current = sorted[i];
+    const last = merged[merged.length - 1];
+
+    if (current.start <= last.end + 1) {
+      last.end = Math.max(last.end, current.end);
+      continue;
+    }
+    merged.push({ ...current });
+  }
+
+  return merged;
+}

+ 4 - 0
web/src/i18n/locales/en.json

@@ -1923,6 +1923,10 @@
     "自动测试所有通道间隔时间": "Auto test interval for all channels",
     "自动禁用": "Auto disabled",
     "自动禁用关键词": "Automatic disable keywords",
+    "自动禁用状态码": "Auto-disable status codes",
+    "自动禁用状态码格式不正确": "Invalid auto-disable status code format",
+    "支持填写单个状态码或范围(含首尾),使用逗号分隔": "Supports single status codes or inclusive ranges; separate with commas",
+    "例如:401, 403, 429, 500-599": "e.g. 401,403,429,500-599",
     "自动选择": "Auto Select",
     "自定义充值数量选项": "Custom Recharge Amount Options",
     "自定义充值数量选项不是合法的 JSON 数组": "Custom recharge amount options is not a valid JSON array",

+ 4 - 0
web/src/i18n/locales/zh.json

@@ -1909,6 +1909,10 @@
     "自动测试所有通道间隔时间": "自动测试所有通道间隔时间",
     "自动禁用": "自动禁用",
     "自动禁用关键词": "自动禁用关键词",
+    "自动禁用状态码": "自动禁用状态码",
+    "自动禁用状态码格式不正确": "自动禁用状态码格式不正确",
+    "支持填写单个状态码或范围(含首尾),使用逗号分隔": "支持填写单个状态码或范围(含首尾),使用逗号分隔",
+    "例如:401, 403, 429, 500-599": "例如:401,403,429,500-599",
     "自动选择": "自动选择",
     "自定义充值数量选项": "自定义充值数量选项",
     "自定义充值数量选项不是合法的 JSON 数组": "自定义充值数量选项不是合法的 JSON 数组",

+ 67 - 2
web/src/pages/Setting/Operation/SettingsMonitoring.jsx

@@ -18,19 +18,29 @@ For commercial licensing, please contact support@quantumnous.com
 */
 
 import React, { useEffect, useState, useRef } from 'react';
-import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
+import {
+  Button,
+  Col,
+  Form,
+  Row,
+  Spin,
+  Tag,
+  Typography,
+} from '@douyinfe/semi-ui';
 import {
   compareObjects,
   API,
   showError,
   showSuccess,
   showWarning,
+  parseHttpStatusCodeRules,
   verifyJSON,
 } from '../../../helpers';
 import { useTranslation } from 'react-i18next';
 
 export default function SettingsMonitoring(props) {
   const { t } = useTranslation();
+  const { Text } = Typography;
   const [loading, setLoading] = useState(false);
   const [inputs, setInputs] = useState({
     ChannelDisableThreshold: '',
@@ -38,21 +48,37 @@ export default function SettingsMonitoring(props) {
     AutomaticDisableChannelEnabled: false,
     AutomaticEnableChannelEnabled: false,
     AutomaticDisableKeywords: '',
+    AutomaticDisableStatusCodes: '401',
     'monitor_setting.auto_test_channel_enabled': false,
     'monitor_setting.auto_test_channel_minutes': 10,
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
+  const parsedAutoDisableStatusCodes = parseHttpStatusCodeRules(
+    inputs.AutomaticDisableStatusCodes || '',
+  );
 
   function onSubmit() {
     const updateArray = compareObjects(inputs, inputsRow);
     if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
+    if (!parsedAutoDisableStatusCodes.ok) {
+      const details =
+        parsedAutoDisableStatusCodes.invalidTokens &&
+        parsedAutoDisableStatusCodes.invalidTokens.length > 0
+          ? `: ${parsedAutoDisableStatusCodes.invalidTokens.join(', ')}`
+          : '';
+      return showError(`${t('自动禁用状态码格式不正确')}${details}`);
+    }
     const requestQueue = updateArray.map((item) => {
       let value = '';
       if (typeof inputs[item.key] === 'boolean') {
         value = String(inputs[item.key]);
       } else {
-        value = inputs[item.key];
+        if (item.key === 'AutomaticDisableStatusCodes') {
+          value = parsedAutoDisableStatusCodes.normalized;
+        } else {
+          value = inputs[item.key];
+        }
       }
       return API.put('/api/option/', {
         key: item.key,
@@ -207,6 +233,45 @@ export default function SettingsMonitoring(props) {
             </Row>
             <Row gutter={16}>
               <Col xs={24} sm={16}>
+                <Form.Input
+                  label={t('自动禁用状态码')}
+                  placeholder={t('例如:401, 403, 429, 500-599')}
+                  extraText={t(
+                    '支持填写单个状态码或范围(含首尾),使用逗号分隔',
+                  )}
+                  field={'AutomaticDisableStatusCodes'}
+                  onChange={(value) =>
+                    setInputs({ ...inputs, AutomaticDisableStatusCodes: value })
+                  }
+                />
+                {parsedAutoDisableStatusCodes.ok &&
+                  parsedAutoDisableStatusCodes.tokens.length > 0 && (
+                    <div
+                      style={{
+                        display: 'flex',
+                        flexWrap: 'wrap',
+                        gap: 8,
+                        marginTop: 8,
+                      }}
+                    >
+                      {parsedAutoDisableStatusCodes.tokens.map((token) => (
+                        <Tag key={token} size='small'>
+                          {token}
+                        </Tag>
+                      ))}
+                    </div>
+                  )}
+                {!parsedAutoDisableStatusCodes.ok && (
+                  <Text type='danger' style={{ display: 'block', marginTop: 8 }}>
+                    {t('自动禁用状态码格式不正确')}
+                    {parsedAutoDisableStatusCodes.invalidTokens &&
+                    parsedAutoDisableStatusCodes.invalidTokens.length > 0
+                      ? `: ${parsedAutoDisableStatusCodes.invalidTokens.join(
+                          ', ',
+                        )}`
+                      : ''}
+                  </Text>
+                )}
                 <Form.TextArea
                   label={t('自动禁用关键词')}
                   placeholder={t('一行一个,不区分大小写')}