Przeglądaj źródła

feat: channel test stream

Seefs 4 tygodni temu
rodzic
commit
23227e18f9

+ 129 - 17
controller/channel-test.go

@@ -31,6 +31,7 @@ import (
 
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/samber/lo"
+	"github.com/tidwall/gjson"
 
 	"github.com/gin-gonic/gin"
 )
@@ -41,7 +42,7 @@ type testResult struct {
 	newAPIError *types.NewAPIError
 }
 
-func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
+func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
 	tik := time.Now()
 	var unsupportedTestChannelTypes = []int{
 		constant.ChannelTypeMidjourney,
@@ -200,7 +201,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
 		}
 	}
 
-	request := buildTestRequest(testModel, endpointType, channel)
+	request := buildTestRequest(testModel, endpointType, channel, isStream)
 
 	info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
 
@@ -418,16 +419,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
 			newAPIError: respErr,
 		}
 	}
-	if usageA == nil {
+	usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
+	if usageErr != nil {
 		return testResult{
 			context:     c,
-			localErr:    errors.New("usage is nil"),
-			newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
+			localErr:    usageErr,
+			newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
 		}
 	}
-	usage := usageA.(*dto.Usage)
 	result := w.Result()
-	respBody, err := io.ReadAll(result.Body)
+	respBody, err := readTestResponseBody(result.Body, isStream)
 	if err != nil {
 		return testResult{
 			context:     c,
@@ -435,6 +436,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
 			newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
 		}
 	}
+	if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
+		return testResult{
+			context:     c,
+			localErr:    bodyErr,
+			newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
+		}
+	}
 	info.SetEstimatePromptTokens(usage.PromptTokens)
 
 	quota := 0
@@ -473,7 +481,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
 	}
 }
 
-func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
+// coerceTestUsage converts the loosely typed usage value produced by a channel
+// test into a non-nil *dto.Usage.
+//
+// A non-nil *dto.Usage is returned unchanged and a dto.Usage value is returned
+// by address. Anything else (an untyped nil, a nil *dto.Usage, or an unexpected
+// type) counts as "no usable usage":
+//   - for non-stream tests an error is returned;
+//   - for stream tests a substitute usage is built in which both PromptTokens
+//     and TotalTokens are set to estimatePromptTokens.
+//
+// On success the returned pointer is never nil, so callers may dereference it
+// without an additional check.
+func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
+	var usageErr error
+	switch u := usageAny.(type) {
+	case *dto.Usage:
+		// An interface holding a nil *dto.Usage is not equal to nil, so it lands
+		// here rather than in the nil case. Returning it as-is would hand the
+		// caller a (nil, nil) result that is dereferenced immediately afterwards.
+		if u != nil {
+			return u, nil
+		}
+		usageErr = errors.New("usage is nil")
+	case dto.Usage:
+		return &u, nil
+	case nil:
+		usageErr = errors.New("usage is nil")
+	default:
+		usageErr = fmt.Errorf("invalid usage type: %T", usageAny)
+	}
+
+	if !isStream {
+		return nil, usageErr
+	}
+	// Stream tests tolerate missing usage: fall back to the prompt estimate.
+	return &dto.Usage{
+		PromptTokens: estimatePromptTokens,
+		TotalTokens:  estimatePromptTokens,
+	}, nil
+}
+
+// readTestResponseBody reads the recorded test response and closes body before
+// returning (the Close error is deliberately discarded).
+//
+// For stream tests at most maxStreamLogBytes (8 KiB) are read; non-stream
+// bodies are read in full.
+func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) {
+	const maxStreamLogBytes = 8 << 10
+	defer func() { _ = body.Close() }()
+
+	var src io.Reader = body
+	if isStream {
+		src = io.LimitReader(body, maxStreamLogBytes)
+	}
+	return io.ReadAll(src)
+}
+
+func detectErrorFromTestResponseBody(respBody []byte) error {
+	b := bytes.TrimSpace(respBody)
+	if len(b) == 0 {
+		return nil
+	}
+	if message := detectErrorMessageFromJSONBytes(b); message != "" {
+		return fmt.Errorf("upstream error: %s", message)
+	}
+
+	for _, line := range bytes.Split(b, []byte{'\n'}) {
+		line = bytes.TrimSpace(line)
+		if len(line) == 0 {
+			continue
+		}
+		if !bytes.HasPrefix(line, []byte("data:")) {
+			continue
+		}
+		payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
+		if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
+			continue
+		}
+		if message := detectErrorMessageFromJSONBytes(payload); message != "" {
+			return fmt.Errorf("upstream error: %s", message)
+		}
+	}
+
+	return nil
+}
+
+func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
+	if len(jsonBytes) == 0 {
+		return ""
+	}
+	if jsonBytes[0] != '{' && jsonBytes[0] != '[' {
+		return ""
+	}
+	errVal := gjson.GetBytes(jsonBytes, "error")
+	if !errVal.Exists() || errVal.Type == gjson.Null {
+		return ""
+	}
+
+	message := gjson.GetBytes(jsonBytes, "error.message").String()
+	if message == "" {
+		message = gjson.GetBytes(jsonBytes, "error.error.message").String()
+	}
+	if message == "" && errVal.Type == gjson.String {
+		message = errVal.String()
+	}
+	if message == "" {
+		message = errVal.Raw
+	}
+	message = strings.TrimSpace(message)
+	if message == "" {
+		return "upstream returned error payload"
+	}
+	return message
+}
+
+func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request {
 	testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
 
 	// 根据端点类型构建不同的测试请求
@@ -504,8 +606,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
 		case constant.EndpointTypeOpenAIResponse:
 			// 返回 OpenAIResponsesRequest
 			return &dto.OpenAIResponsesRequest{
-				Model: model,
-				Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
+				Model:  model,
+				Input:  json.RawMessage(`[{"role":"user","content":"hi"}]`),
+				Stream: isStream,
 			}
 		case constant.EndpointTypeOpenAIResponseCompact:
 			// 返回 OpenAIResponsesCompactionRequest
@@ -519,9 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
 			if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
 				maxTokens = 3000
 			}
-			return &dto.GeneralOpenAIRequest{
+			req := &dto.GeneralOpenAIRequest{
 				Model:  model,
-				Stream: false,
+				Stream: isStream,
 				Messages: []dto.Message{
 					{
 						Role:    "user",
@@ -530,6 +633,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
 				},
 				MaxTokens: maxTokens,
 			}
+			if isStream {
+				req.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
+			}
+			return req
 		}
 	}
 
@@ -565,15 +672,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
 	// Responses-only models (e.g. codex series)
 	if strings.Contains(strings.ToLower(model), "codex") {
 		return &dto.OpenAIResponsesRequest{
-			Model: model,
-			Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
+			Model:  model,
+			Input:  json.RawMessage(`[{"role":"user","content":"hi"}]`),
+			Stream: isStream,
 		}
 	}
 
 	// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
 	testRequest := &dto.GeneralOpenAIRequest{
 		Model:  model,
-		Stream: false,
+		Stream: isStream,
 		Messages: []dto.Message{
 			{
 				Role:    "user",
@@ -581,6 +689,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
 			},
 		},
 	}
+	if isStream {
+		testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
+	}
 
 	if strings.HasPrefix(model, "o") {
 		testRequest.MaxCompletionTokens = 16
@@ -618,8 +729,9 @@ func TestChannel(c *gin.Context) {
 	//}()
 	testModel := c.Query("model")
 	endpointType := c.Query("endpoint_type")
+	isStream, _ := strconv.ParseBool(c.Query("stream"))
 	tik := time.Now()
-	result := testChannel(channel, testModel, endpointType)
+	result := testChannel(channel, testModel, endpointType, isStream)
 	if result.localErr != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -678,7 +790,7 @@ func testAllChannels(notify bool) error {
 		for _, channel := range channels {
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
-			result := testChannel(channel, "", "")
+			result := testChannel(channel, "", "", false)
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
 

+ 60 - 21
web/src/components/table/channels/modals/ModelTestModal.jsx

@@ -26,8 +26,10 @@ import {
   Tag,
   Typography,
   Select,
+  Switch,
+  Banner,
 } from '@douyinfe/semi-ui';
-import { IconSearch } from '@douyinfe/semi-icons';
+import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
 import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
 import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants';
 
@@ -48,11 +50,25 @@ const ModelTestModal = ({
   setModelTablePage,
   selectedEndpointType,
   setSelectedEndpointType,
+  isStreamTest,
+  setIsStreamTest,
   allSelectingRef,
   isMobile,
   t,
 }) => {
   const hasChannel = Boolean(currentTestChannel);
+  const streamToggleDisabled = [
+    'embeddings',
+    'image-generation',
+    'jina-rerank',
+    'openai-response-compact',
+  ].includes(selectedEndpointType);
+
+  React.useEffect(() => {
+    if (streamToggleDisabled && isStreamTest) {
+      setIsStreamTest(false);
+    }
+  }, [streamToggleDisabled, isStreamTest, setIsStreamTest]);
 
   const filteredModels = hasChannel
     ? currentTestChannel.models
@@ -181,6 +197,7 @@ const ModelTestModal = ({
                 currentTestChannel,
                 record.model,
                 selectedEndpointType,
+                isStreamTest,
               )
             }
             loading={isTesting}
@@ -258,25 +275,46 @@ const ModelTestModal = ({
     >
       {hasChannel && (
         <div className='model-test-scroll'>
-          {/* 端点类型选择器 */}
-          <div className='flex items-center gap-2 w-full mb-2'>
-            <Typography.Text strong>{t('端点类型')}:</Typography.Text>
-            <Select
-              value={selectedEndpointType}
-              onChange={setSelectedEndpointType}
-              optionList={endpointTypeOptions}
-              className='!w-full'
-              placeholder={t('选择端点类型')}
-            />
+          {/* Endpoint toolbar */}
+          <div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
+            <div className='flex items-center gap-2 flex-1 min-w-0'>
+              <Typography.Text strong className='shrink-0'>
+                {t('端点类型')}:
+              </Typography.Text>
+              <Select
+                value={selectedEndpointType}
+                onChange={setSelectedEndpointType}
+                optionList={endpointTypeOptions}
+                className='!w-full min-w-0'
+                placeholder={t('选择端点类型')}
+              />
+            </div>
+            <div className='flex items-center justify-between sm:justify-end gap-2 shrink-0'>
+              <Typography.Text strong className='shrink-0'>
+                {t('流式')}:
+              </Typography.Text>
+              <Switch
+                checked={isStreamTest}
+                onChange={setIsStreamTest}
+                size='small'
+                disabled={streamToggleDisabled}
+                aria-label={t('流式')}
+              />
+            </div>
           </div>
-          <Typography.Text type='tertiary' size='small' className='block mb-2'>
-            {t(
+
+          <Banner
+            type='info'
+            closeIcon={null}
+            icon={<IconInfoCircle />}
+            className='!rounded-lg mb-2'
+            description={t(
               '说明:本页测试为非流式请求;若渠道仅支持流式返回,可能出现测试失败,请以实际使用为准。',
             )}
-          </Typography.Text>
+          />
 
           {/* 搜索与操作按钮 */}
-          <div className='flex items-center justify-end gap-2 w-full mb-2'>
+          <div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
             <Input
               placeholder={t('搜索模型...')}
               value={modelSearchKeyword}
@@ -284,16 +322,17 @@ const ModelTestModal = ({
                 setModelSearchKeyword(v);
                 setModelTablePage(1);
               }}
-              className='!w-full'
+              className='!w-full sm:!flex-1'
               prefix={<IconSearch />}
               showClear
             />
 
-            <Button onClick={handleCopySelected}>{t('复制已选')}</Button>
-
-            <Button type='tertiary' onClick={handleSelectSuccess}>
-              {t('选择成功')}
-            </Button>
+            <div className='flex items-center justify-end gap-2'>
+              <Button onClick={handleCopySelected}>{t('复制已选')}</Button>
+              <Button type='tertiary' onClick={handleSelectSuccess}>
+                {t('选择成功')}
+              </Button>
+            </div>
           </div>
 
           <Table

+ 19 - 2
web/src/hooks/channels/useChannelsData.jsx

@@ -87,6 +87,7 @@ export const useChannelsData = () => {
   const [isBatchTesting, setIsBatchTesting] = useState(false);
   const [modelTablePage, setModelTablePage] = useState(1);
   const [selectedEndpointType, setSelectedEndpointType] = useState('');
+  const [isStreamTest, setIsStreamTest] = useState(false);
   const [globalPassThroughEnabled, setGlobalPassThroughEnabled] =
     useState(false);
 
@@ -851,7 +852,12 @@ export const useChannelsData = () => {
   };
 
   // Test channel - 单个模型测试,参考旧版实现
-  const testChannel = async (record, model, endpointType = '') => {
+  const testChannel = async (
+    record,
+    model,
+    endpointType = '',
+    stream = false,
+  ) => {
     const testKey = `${record.id}-${model}`;
 
     // 检查是否应该停止批量测试
@@ -867,6 +873,9 @@ export const useChannelsData = () => {
       if (endpointType) {
         url += `&endpoint_type=${endpointType}`;
       }
+      if (stream) {
+        url += `&stream=true`;
+      }
       const res = await API.get(url);
 
       // 检查是否在请求期间被停止
@@ -995,7 +1004,12 @@ export const useChannelsData = () => {
         );
 
         const batchPromises = batch.map((model) =>
-          testChannel(currentTestChannel, model, selectedEndpointType),
+          testChannel(
+            currentTestChannel,
+            model,
+            selectedEndpointType,
+            isStreamTest,
+          ),
         );
         const batchResults = await Promise.allSettled(batchPromises);
         results.push(...batchResults);
@@ -1080,6 +1094,7 @@ export const useChannelsData = () => {
     setSelectedModelKeys([]);
     setModelTablePage(1);
     setSelectedEndpointType('');
+    setIsStreamTest(false);
     // 可选择性保留测试结果,这里不清空以便用户查看
   };
 
@@ -1170,6 +1185,8 @@ export const useChannelsData = () => {
     setModelTablePage,
     selectedEndpointType,
     setSelectedEndpointType,
+    isStreamTest,
+    setIsStreamTest,
     allSelectingRef,
 
     // Multi-key management states

+ 1 - 0
web/src/i18n/locales/en.json

@@ -1547,6 +1547,7 @@
     "流": "stream",
     "流式响应完成": "Streaming response completed",
     "流式输出": "Streaming Output",
+    "流式": "Streaming",
     "流量端口": "Traffic Port",
     "浅色": "Light",
     "浅色模式": "Light Mode",

+ 1 - 0
web/src/i18n/locales/fr.json

@@ -1557,6 +1557,7 @@
     "流": "Flux",
     "流式响应完成": "Flux terminé",
     "流式输出": "Sortie en flux",
+    "流式": "Streaming",
     "流量端口": "Traffic Port",
     "浅色": "Clair",
     "浅色模式": "Mode clair",

+ 1 - 0
web/src/i18n/locales/ja.json

@@ -1542,6 +1542,7 @@
     "流": "ストリーム",
     "流式响应完成": "ストリーム完了",
     "流式输出": "ストリーム出力",
+    "流式": "ストリーミング",
     "流量端口": "Traffic Port",
     "浅色": "ライト",
     "浅色模式": "ライトモード",

+ 1 - 0
web/src/i18n/locales/ru.json

@@ -1568,6 +1568,7 @@
     "流": "Поток",
     "流式响应完成": "Поток завершён",
     "流式输出": "Потоковый вывод",
+    "流式": "Стриминг",
     "流量端口": "Traffic Port",
     "浅色": "Светлая",
     "浅色模式": "Светлый режим",

+ 1 - 0
web/src/i18n/locales/vi.json

@@ -1596,6 +1596,7 @@
     "流": "luồng",
     "流式响应完成": "Luồng hoàn tất",
     "流式输出": "Đầu ra luồng",
+    "流式": "Streaming",
     "流量端口": "Traffic Port",
     "浅色": "Sáng",
     "浅色模式": "Chế độ sáng",

+ 1 - 0
web/src/i18n/locales/zh.json

@@ -1537,6 +1537,7 @@
     "流": "流",
     "流式响应完成": "流式响应完成",
     "流式输出": "流式输出",
+    "流式": "流式",
     "流量端口": "流量端口",
     "浅色": "浅色",
     "浅色模式": "浅色模式",