瀏覽代碼

Merge pull request #2926 from seefs001/fix/status_code_mapping

fix: support numeric status code mapping in ResetStatusCode
Calcium-Ion 3 周之前
父節點
當前提交
f77381cc75

+ 3 - 0
controller/channel-test.go

@@ -804,6 +804,9 @@ func testAllChannels(notify bool) error {
 		}()
 
 		for _, channel := range channels {
+			if channel.Status == common.ChannelStatusManuallyDisabled {
+				continue
+			}
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
 			result := testChannel(channel, "", "", false)

+ 31 - 26
relay/channel/api_request.go

@@ -171,35 +171,37 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
 
 	passAll := false
 	var passthroughRegex []*regexp.Regexp
-	for k := range info.HeadersOverride {
-		key := strings.TrimSpace(k)
-		if key == "" {
-			continue
-		}
-		if key == headerPassthroughAllKey {
-			passAll = true
-			continue
-		}
+	if !info.IsChannelTest {
+		for k := range info.HeadersOverride {
+			key := strings.TrimSpace(k)
+			if key == "" {
+				continue
+			}
+			if key == headerPassthroughAllKey {
+				passAll = true
+				continue
+			}
 
-		lower := strings.ToLower(key)
-		var pattern string
-		switch {
-		case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
-			pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
-		case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
-			pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
-		default:
-			continue
-		}
+			lower := strings.ToLower(key)
+			var pattern string
+			switch {
+			case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
+				pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
+			case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
+				pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
+			default:
+				continue
+			}
 
-		if pattern == "" {
-			return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
-		}
-		compiled, err := getHeaderPassthroughRegex(pattern)
-		if err != nil {
-			return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
+			if pattern == "" {
+				return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
+			}
+			compiled, err := getHeaderPassthroughRegex(pattern)
+			if err != nil {
+				return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
+			}
+			passthroughRegex = append(passthroughRegex, compiled)
 		}
-		passthroughRegex = append(passthroughRegex, compiled)
 	}
 
 	if passAll || len(passthroughRegex) > 0 {
@@ -243,6 +245,9 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
 		if !ok {
 			return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
 		}
+		if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) {
+			continue
+		}
 
 		value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
 		if err != nil {

+ 81 - 0
relay/channel/api_request_test.go

@@ -0,0 +1,81 @@
+package channel
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: true,
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"*": "",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	require.Empty(t, headers)
+}
+
+func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: true,
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"X-Upstream-Trace": "{client_header:X-Trace-Id}",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	_, ok := headers["X-Upstream-Trace"]
+	require.False(t, ok)
+}
+
+func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: false,
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"X-Upstream-Trace": "{client_header:X-Trace-Id}",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
+}

+ 40 - 3
service/error.go

@@ -2,9 +2,11 @@ package service
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"strconv"
 	"strings"
@@ -127,10 +129,13 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
 }
 
 func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) {
+	if newApiErr == nil {
+		return
+	}
 	if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
 		return
 	}
-	statusCodeMapping := make(map[string]string)
+	statusCodeMapping := make(map[string]any)
 	err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
 	if err != nil {
 		return
@@ -139,12 +144,44 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string)
 		return
 	}
 	codeStr := strconv.Itoa(newApiErr.StatusCode)
-	if _, ok := statusCodeMapping[codeStr]; ok {
-		intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
+	if value, ok := statusCodeMapping[codeStr]; ok {
+		intCode, ok := parseStatusCodeMappingValue(value)
+		if !ok {
+			return
+		}
 		newApiErr.StatusCode = intCode
 	}
 }
 
+func parseStatusCodeMappingValue(value any) (int, bool) {
+	switch v := value.(type) {
+	case string:
+		if v == "" {
+			return 0, false
+		}
+		statusCode, err := strconv.Atoi(v)
+		if err != nil {
+			return 0, false
+		}
+		return statusCode, true
+	case float64:
+		if v != math.Trunc(v) {
+			return 0, false
+		}
+		return int(v), true
+	case int:
+		return v, true
+	case json.Number:
+		statusCode, err := strconv.Atoi(v.String())
+		if err != nil {
+			return 0, false
+		}
+		return statusCode, true
+	default:
+		return 0, false
+	}
+}
+
 func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
 	openaiErr := TaskErrorWrapper(err, code, statusCode)
 	openaiErr.LocalError = true

+ 57 - 0
service/error_test.go

@@ -0,0 +1,57 @@
+package service
+
+import (
+	"testing"
+
+	"github.com/QuantumNous/new-api/types"
+	"github.com/stretchr/testify/require"
+)
+
+func TestResetStatusCode(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name             string
+		statusCode       int
+		statusCodeConfig string
+		expectedCode     int
+	}{
+		{
+			name:             "map string value",
+			statusCode:       429,
+			statusCodeConfig: `{"429":"503"}`,
+			expectedCode:     503,
+		},
+		{
+			name:             "map int value",
+			statusCode:       429,
+			statusCodeConfig: `{"429":503}`,
+			expectedCode:     503,
+		},
+		{
+			name:             "skip invalid string value",
+			statusCode:       429,
+			statusCodeConfig: `{"429":"bad-code"}`,
+			expectedCode:     429,
+		},
+		{
+			name:             "skip status code 200",
+			statusCode:       200,
+			statusCodeConfig: `{"200":503}`,
+			expectedCode:     200,
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			newAPIError := &types.NewAPIError{
+				StatusCode: tc.statusCode,
+			}
+			ResetStatusCode(newAPIError, tc.statusCodeConfig)
+			require.Equal(t, tc.expectedCode, newAPIError.StatusCode)
+		})
+	}
+}

+ 2 - 2
web/src/components/table/channels/ChannelsActions.jsx

@@ -99,14 +99,14 @@ const ChannelsActions = ({
                     onClick={() => {
                       Modal.confirm({
                         title: t('确定?'),
-                        content: t('确定要测试所有道吗?'),
+                        content: t('确定要测试所有未手动禁用渠道吗?'),
                         onOk: () => testAllChannels(),
                         size: 'small',
                         centered: true,
                       });
                     }}
                   >
-                    {t('测试所有道')}
+                    {t('测试所有未手动禁用渠道')}
                   </Button>
                 </Dropdown.Item>
                 <Dropdown.Item>

+ 1 - 1
web/src/components/table/tokens/TokensFilters.jsx

@@ -47,7 +47,7 @@ const TokensFilters = ({
         setFormApi(api);
         formApiRef.current = api;
       }}
-      onSubmit={searchTokens}
+      onSubmit={() => searchTokens(1)}
       allowEmpty={true}
       autoComplete='off'
       layout='horizontal'

+ 5 - 1
web/src/hooks/tokens/useTokensData.jsx

@@ -191,6 +191,10 @@ export const useTokensData = (openFluentNotification) => {
 
   // Search tokens function
   const searchTokens = async (page = 1, size = pageSize) => {
+    const normalizedPage = Number.isInteger(page) && page > 0 ? page : 1;
+    const normalizedSize =
+      Number.isInteger(size) && size > 0 ? size : pageSize;
+
     const { searchKeyword, searchToken } = getFormValues();
     if (searchKeyword === '' && searchToken === '') {
       setSearchMode(false);
@@ -199,7 +203,7 @@ export const useTokensData = (openFluentNotification) => {
     }
     setSearching(true);
     const res = await API.get(
-      `/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`,
+      `/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${normalizedPage}&size=${normalizedSize}`,
     );
     const { success, message, data } = res.data;
     if (success) {

+ 2 - 0
web/src/i18n/locales/en.json

@@ -1563,6 +1563,7 @@
     "测试失败:": "Test failed: ",
     "测试所有渠道的最长响应时间": "Maximum response time for testing all channels",
     "测试所有通道": "Test all channels",
+    "测试所有未手动禁用渠道": "Test all channels except manually disabled ones",
     "测试模式": "Test Mode",
     "测试连接": "Test Connection",
     "测速": "Speed Test",
@@ -1745,6 +1746,7 @@
     "确定要提升此用户吗?": "Are you sure you want to promote this user?",
     "确定要更新所有已启用通道余额吗?": "Are you sure you want to update the balance of all enabled channels?",
     "确定要测试所有通道吗?": "Are you sure you want to test all channels?",
+    "确定要测试所有未手动禁用渠道吗?": "Are you sure you want to test all channels except manually disabled ones?",
     "确定要禁用所有的密钥吗?": "Are you sure you want to disable all keys?",
     "确定要禁用此用户吗?": "Are you sure you want to disable this user?",
     "确定要降级此用户吗?": "Are you sure you want to demote this user?",

+ 2 - 0
web/src/i18n/locales/fr.json

@@ -1573,6 +1573,7 @@
     "测试失败:": "Test failed: ",
     "测试所有渠道的最长响应时间": "Temps de réponse maximal pour tester tous les canaux",
     "测试所有通道": "Tester tous les canaux",
+    "测试所有未手动禁用渠道": "Tester tous les canaux sauf ceux désactivés manuellement",
     "测试模式": "Mode test",
     "测试连接": "Test Connection",
     "测速": "Test de vitesse",
@@ -1757,6 +1758,7 @@
     "确定要提升此用户吗?": "Êtes-vous sûr de vouloir promouvoir cet utilisateur ?",
     "确定要更新所有已启用通道余额吗?": "Êtes-vous sûr de vouloir mettre à jour le solde de tous les canaux activés ?",
     "确定要测试所有通道吗?": "Êtes-vous sûr de vouloir tester tous les canaux ?",
+    "确定要测试所有未手动禁用渠道吗?": "Êtes-vous sûr de vouloir tester tous les canaux sauf ceux désactivés manuellement ?",
     "确定要禁用所有的密钥吗?": "Êtes-vous sûr de vouloir désactiver toutes les clés ?",
     "确定要禁用此用户吗?": "Êtes-vous sûr de vouloir désactiver cet utilisateur ?",
     "确定要降级此用户吗?": "Êtes-vous sûr de vouloir rétrograder cet utilisateur ?",

+ 2 - 0
web/src/i18n/locales/ja.json

@@ -1558,6 +1558,7 @@
     "测试失败:": "Test failed: ",
     "测试所有渠道的最长响应时间": "すべてのチャネルテストの最大応答時間",
     "测试所有通道": "すべてのチャネルをテスト",
+    "测试所有未手动禁用渠道": "手動で無効化されたものを除くすべてのチャネルをテスト",
     "测试模式": "Test Mode",
     "测试连接": "Test Connection",
     "测速": "スピードテスト",
@@ -1740,6 +1741,7 @@
     "确定要提升此用户吗?": "このユーザーを昇格させてもよろしいですか?",
     "确定要更新所有已启用通道余额吗?": "有効なすべてのチャネルのクォータを更新してもよろしいですか?",
     "确定要测试所有通道吗?": "すべてのチャネルをテストしてもよろしいですか?",
+    "确定要测试所有未手动禁用渠道吗?": "手動で無効化されたチャネルを除くすべてのチャネルをテストしてもよろしいですか?",
     "确定要禁用所有的密钥吗?": "すべてのAPIキーを無効にしてもよろしいですか?",
     "确定要禁用此用户吗?": "このユーザーを無効にしてもよろしいですか?",
     "确定要降级此用户吗?": "このユーザーを降格させてもよろしいですか?",

+ 2 - 0
web/src/i18n/locales/ru.json

@@ -1584,6 +1584,7 @@
     "测试失败:": "Test failed: ",
     "测试所有渠道的最长响应时间": "Максимальное время отклика для тестирования всех каналов",
     "测试所有通道": "Тестировать все каналы",
+    "测试所有未手动禁用渠道": "Тестировать все каналы, кроме отключенных вручную",
     "测试模式": "Тестовый режим",
     "测试连接": "Test Connection",
     "测速": "Измерение скорости",
@@ -1770,6 +1771,7 @@
     "确定要提升此用户吗?": "Подтвердить повышение этого пользователя?",
     "确定要更新所有已启用通道余额吗?": "Подтвердить обновление баланса всех включенных каналов?",
     "确定要测试所有通道吗?": "Подтвердить тестирование всех каналов?",
+    "确定要测试所有未手动禁用渠道吗?": "Вы уверены, что хотите протестировать все каналы, кроме отключенных вручную?",
     "确定要禁用所有的密钥吗?": "Подтвердить отключение всех ключей?",
     "确定要禁用此用户吗?": "Подтвердить отключение этого пользователя?",
     "确定要降级此用户吗?": "Подтвердить понижение этого пользователя?",

+ 2 - 0
web/src/i18n/locales/vi.json

@@ -1620,6 +1620,7 @@
     "测试成功,耗时 ": "Kiểm tra thành công, mất ",
     "测试所有渠道的最长响应时间": "Thời gian phản hồi tối đa để kiểm tra tất cả các kênh",
     "测试所有通道": "Kiểm tra tất cả các kênh",
+    "测试所有未手动禁用渠道": "Kiểm tra tất cả các kênh ngoại trừ các kênh bị vô hiệu hóa thủ công",
     "测试模型": "Mô hình kiểm tra",
     "测试模型耗时": "Thời gian kiểm tra mô hình",
     "测试模式": "Chế độ kiểm tra",
@@ -1971,6 +1972,7 @@
     "确定要提升此用户吗?": "Bạn có chắc chắn muốn thăng cấp người dùng này không?",
     "确定要更新所有已启用通道余额吗?": "Bạn có chắc chắn muốn cập nhật số dư của tất cả các kênh đã bật không?",
     "确定要测试所有通道吗?": "Bạn có chắc chắn muốn kiểm tra tất cả các kênh không?",
+    "确定要测试所有未手动禁用渠道吗?": "Bạn có chắc chắn muốn kiểm tra tất cả các kênh ngoại trừ các kênh bị vô hiệu hóa thủ công không?",
     "确定要禁用所有的密钥吗?": "Bạn có chắc chắn muốn vô hiệu hóa tất cả các khóa không?",
     "确定要禁用此用户吗?": "Bạn có chắc chắn muốn vô hiệu hóa người dùng này không?",
     "确定要降级此用户吗?": "Bạn có chắc chắn muốn hạ cấp người dùng này không?",

+ 2 - 0
web/src/i18n/locales/zh-CN.json

@@ -1553,6 +1553,7 @@
     "测试失败:": "测试失败:",
     "测试所有渠道的最长响应时间": "测试所有渠道的最长响应时间",
     "测试所有通道": "测试所有通道",
+    "测试所有未手动禁用渠道": "测试所有未手动禁用渠道",
     "测试模式": "测试模式",
     "测试连接": "测试连接",
     "测速": "测速",
@@ -1733,6 +1734,7 @@
     "确定要提升此用户吗?": "确定要提升此用户吗?",
     "确定要更新所有已启用通道余额吗?": "确定要更新所有已启用通道余额吗?",
     "确定要测试所有通道吗?": "确定要测试所有通道吗?",
+    "确定要测试所有未手动禁用渠道吗?": "确定要测试所有未手动禁用渠道吗?",
     "确定要禁用所有的密钥吗?": "确定要禁用所有的密钥吗?",
     "确定要禁用此用户吗?": "确定要禁用此用户吗?",
     "确定要降级此用户吗?": "确定要降级此用户吗?",

+ 2 - 0
web/src/i18n/locales/zh-TW.json

@@ -1553,6 +1553,7 @@
     "测试失败:": "測試失敗:",
     "测试所有渠道的最长响应时间": "測試所有管道的最長響應時間",
     "测试所有通道": "測試所有通道",
+    "测试所有未手动禁用渠道": "測試所有未手動停用通道",
     "测试模式": "測試模式",
     "测试连接": "測試連接",
     "测速": "測速",
@@ -1733,6 +1734,7 @@
     "确定要提升此用户吗?": "確定要提升此使用者嗎?",
     "确定要更新所有已启用通道余额吗?": "確定要更新所有已啟用通道餘額嗎?",
     "确定要测试所有通道吗?": "確定要測試所有通道嗎?",
+    "确定要测试所有未手动禁用渠道吗?": "確定要測試所有未手動停用通道嗎?",
     "确定要禁用所有的密钥吗?": "確定要禁用所有的密鑰嗎?",
     "确定要禁用此用户吗?": "確定要禁用此使用者嗎?",
     "确定要降级此用户吗?": "確定要降級此使用者嗎?",