Просмотр исходного кода

feat: add endpoint type selection to channel testing functionality

CaIon 5 месяцев назад
Родитель
Сommit
6bc3e62fd5

+ 1 - 0
common/endpoint_defaults.go

@@ -23,6 +23,7 @@ var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
 	constant.EndpointTypeGemini:          {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
 	constant.EndpointTypeJinaRerank:      {Path: "/rerank", Method: "POST"},
 	constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
+	constant.EndpointTypeEmbeddings:      {Path: "/v1/embeddings", Method: "POST"},
 }
 
 // GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在

+ 1 - 0
constant/endpoint_type.go

@@ -9,6 +9,7 @@ const (
 	EndpointTypeGemini          EndpointType = "gemini"
 	EndpointTypeJinaRerank      EndpointType = "jina-rerank"
 	EndpointTypeImageGeneration EndpointType = "image-generation"
+	EndpointTypeEmbeddings      EndpointType = "embeddings"
 	//EndpointTypeMidjourney     EndpointType = "midjourney-proxy"
 	//EndpointTypeSuno           EndpointType = "suno-proxy"
 	//EndpointTypeKling          EndpointType = "kling"

+ 195 - 82
controller/channel-test.go

@@ -38,7 +38,7 @@ type testResult struct {
 	newAPIError *types.NewAPIError
 }
 
-func testChannel(channel *model.Channel, testModel string) testResult {
+func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
 	tik := time.Now()
 	if channel.Type == constant.ChannelTypeMidjourney {
 		return testResult{
@@ -81,18 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 
 	requestPath := "/v1/chat/completions"
 
-	// 先判断是否为 Embedding 模型
-	if strings.Contains(strings.ToLower(testModel), "embedding") ||
-		strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
-		strings.Contains(testModel, "bge-") || // bge 系列模型
-		strings.Contains(testModel, "embed") ||
-		channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
-		requestPath = "/v1/embeddings" // 修改请求路径
-	}
+	// 如果指定了端点类型,使用指定的端点类型
+	if endpointType != "" {
+		if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
+			requestPath = endpointInfo.Path
+		}
+	} else {
+		// 如果没有指定端点类型,使用原有的自动检测逻辑
+		// 先判断是否为 Embedding 模型
+		if strings.Contains(strings.ToLower(testModel), "embedding") ||
+			strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
+			strings.Contains(testModel, "bge-") || // bge 系列模型
+			strings.Contains(testModel, "embed") ||
+			channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
+			requestPath = "/v1/embeddings" // 修改请求路径
+		}
 
-	// VolcEngine 图像生成模型
-	if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
-		requestPath = "/v1/images/generations"
+		// VolcEngine 图像生成模型
+		if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+			requestPath = "/v1/images/generations"
+		}
 	}
 
 	c.Request = &http.Request{
@@ -114,21 +122,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		}
 	}
 
-	// 重新检查模型类型并更新请求路径
-	if strings.Contains(strings.ToLower(testModel), "embedding") ||
-		strings.HasPrefix(testModel, "m3e") ||
-		strings.Contains(testModel, "bge-") ||
-		strings.Contains(testModel, "embed") ||
-		channel.Type == constant.ChannelTypeMokaAI {
-		requestPath = "/v1/embeddings"
-		c.Request.URL.Path = requestPath
-	}
-
-	if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
-		requestPath = "/v1/images/generations"
-		c.Request.URL.Path = requestPath
-	}
-
 	cache, err := model.GetUserCache(1)
 	if err != nil {
 		return testResult{
@@ -153,17 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 			newAPIError: newAPIError,
 		}
 	}
-	request := buildTestRequest(testModel)
 
-	// Determine relay format based on request path
-	relayFormat := types.RelayFormatOpenAI
-	if c.Request.URL.Path == "/v1/embeddings" {
-		relayFormat = types.RelayFormatEmbedding
-	}
-	if c.Request.URL.Path == "/v1/images/generations" {
-		relayFormat = types.RelayFormatOpenAIImage
+	// Determine relay format based on endpoint type or request path
+	var relayFormat types.RelayFormat
+	if endpointType != "" {
+		// 根据指定的端点类型设置 relayFormat
+		switch constant.EndpointType(endpointType) {
+		case constant.EndpointTypeOpenAI:
+			relayFormat = types.RelayFormatOpenAI
+		case constant.EndpointTypeOpenAIResponse:
+			relayFormat = types.RelayFormatOpenAIResponses
+		case constant.EndpointTypeAnthropic:
+			relayFormat = types.RelayFormatClaude
+		case constant.EndpointTypeGemini:
+			relayFormat = types.RelayFormatGemini
+		case constant.EndpointTypeJinaRerank:
+			relayFormat = types.RelayFormatRerank
+		case constant.EndpointTypeImageGeneration:
+			relayFormat = types.RelayFormatOpenAIImage
+		case constant.EndpointTypeEmbeddings:
+			relayFormat = types.RelayFormatEmbedding
+		default:
+			relayFormat = types.RelayFormatOpenAI
+		}
+	} else {
+		// 根据请求路径自动检测
+		relayFormat = types.RelayFormatOpenAI
+		if c.Request.URL.Path == "/v1/embeddings" {
+			relayFormat = types.RelayFormatEmbedding
+		}
+		if c.Request.URL.Path == "/v1/images/generations" {
+			relayFormat = types.RelayFormatOpenAIImage
+		}
+		if c.Request.URL.Path == "/v1/messages" {
+			relayFormat = types.RelayFormatClaude
+		}
+		if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
+			relayFormat = types.RelayFormatGemini
+		}
+		if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
+			relayFormat = types.RelayFormatRerank
+		}
+		if c.Request.URL.Path == "/v1/responses" {
+			relayFormat = types.RelayFormatOpenAIResponses
+		}
 	}
 
+	request := buildTestRequest(testModel, endpointType)
+
 	info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
 
 	if err != nil {
@@ -186,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	}
 
 	testModel = info.UpstreamModelName
-	request.Model = testModel
+	// 更新请求中的模型名称
+	request.SetModelName(testModel)
 
 	apiType, _ := common.ChannelType2APIType(channel.Type)
 	adaptor := relay.GetAdaptor(apiType)
@@ -216,33 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 
 	var convertedRequest any
 	// 根据 RelayMode 选择正确的转换函数
-	if info.RelayMode == relayconstant.RelayModeEmbeddings {
-		// 创建一个 EmbeddingRequest
-		embeddingRequest := dto.EmbeddingRequest{
-			Input: request.Input,
-			Model: request.Model,
-		}
-		// 调用专门用于 Embedding 的转换函数
-		convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
-	} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
-		// 创建一个 ImageRequest
-		prompt := "cat"
-		if request.Prompt != nil {
-			if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
-				prompt = promptStr
+	switch info.RelayMode {
+	case relayconstant.RelayModeEmbeddings:
+		// Embedding 请求 - request 已经是正确的类型
+		if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
+			convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
+		} else {
+			return testResult{
+				context:     c,
+				localErr:    errors.New("invalid embedding request type"),
+				newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
 			}
 		}
-		imageRequest := dto.ImageRequest{
-			Prompt: prompt,
-			Model:  request.Model,
-			N:      uint(request.N),
-			Size:   request.Size,
+	case relayconstant.RelayModeImagesGenerations:
+		// 图像生成请求 - request 已经是正确的类型
+		if imageReq, ok := request.(*dto.ImageRequest); ok {
+			convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
+		} else {
+			return testResult{
+				context:     c,
+				localErr:    errors.New("invalid image request type"),
+				newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
+			}
+		}
+	case relayconstant.RelayModeRerank:
+		// Rerank 请求 - request 已经是正确的类型
+		if rerankReq, ok := request.(*dto.RerankRequest); ok {
+			convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
+		} else {
+			return testResult{
+				context:     c,
+				localErr:    errors.New("invalid rerank request type"),
+				newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
+			}
+		}
+	case relayconstant.RelayModeResponses:
+		// Response 请求 - request 已经是正确的类型
+		if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
+			convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
+		} else {
+			return testResult{
+				context:     c,
+				localErr:    errors.New("invalid response request type"),
+				newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
+			}
+		}
+	default:
+		// Chat/Completion 等其他请求类型
+		if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
+			convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
+		} else {
+			return testResult{
+				context:     c,
+				localErr:    errors.New("invalid general request type"),
+				newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
+			}
 		}
-		// 调用专门用于图像生成的转换函数
-		convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
-	} else {
-		// 对其他所有请求类型(如 Chat),保持原有逻辑
-		convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
 	}
 
 	if err != nil {
@@ -345,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	}
 }
 
-func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
-	testRequest := &dto.GeneralOpenAIRequest{
-		Model:  "", // this will be set later
-		Stream: false,
+func buildTestRequest(model string, endpointType string) dto.Request {
+	// 根据端点类型构建不同的测试请求
+	if endpointType != "" {
+		switch constant.EndpointType(endpointType) {
+		case constant.EndpointTypeEmbeddings:
+			// 返回 EmbeddingRequest
+			return &dto.EmbeddingRequest{
+				Model: model,
+				Input: []any{"hello world"},
+			}
+		case constant.EndpointTypeImageGeneration:
+			// 返回 ImageRequest
+			return &dto.ImageRequest{
+				Model:  model,
+				Prompt: "a cute cat",
+				N:      1,
+				Size:   "1024x1024",
+			}
+		case constant.EndpointTypeJinaRerank:
+			// 返回 RerankRequest
+			return &dto.RerankRequest{
+				Model:     model,
+				Query:     "What is Deep Learning?",
+				Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
+				TopN:      2,
+			}
+		case constant.EndpointTypeOpenAIResponse:
+			// 返回 OpenAIResponsesRequest
+			return &dto.OpenAIResponsesRequest{
+				Model: model,
+				Input: json.RawMessage("\"hi\""),
+			}
+		case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
+			// 返回 GeneralOpenAIRequest
+			maxTokens := uint(10)
+			if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
+				maxTokens = 3000
+			}
+			return &dto.GeneralOpenAIRequest{
+				Model:  model,
+				Stream: false,
+				Messages: []dto.Message{
+					{
+						Role:    "user",
+						Content: "hi",
+					},
+				},
+				MaxTokens: maxTokens,
+			}
+		}
 	}
 
+	// 自动检测逻辑(保持原有行为)
 	// 先判断是否为 Embedding 模型
-	if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
-		strings.HasPrefix(model, "m3e") || // m3e 系列模型
+	if strings.Contains(strings.ToLower(model), "embedding") ||
+		strings.HasPrefix(model, "m3e") ||
 		strings.Contains(model, "bge-") {
-		testRequest.Model = model
-		// Embedding 请求
-		testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
-		return testRequest
+		// 返回 EmbeddingRequest
+		return &dto.EmbeddingRequest{
+			Model: model,
+			Input: []any{"hello world"},
+		}
 	}
-	// 并非Embedding 模型
+
+	// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
+	testRequest := &dto.GeneralOpenAIRequest{
+		Model:  model,
+		Stream: false,
+		Messages: []dto.Message{
+			{
+				Role:    "user",
+				Content: "hi",
+			},
+		},
+	}
+
 	if strings.HasPrefix(model, "o") {
 		testRequest.MaxCompletionTokens = 10
 	} else if strings.Contains(model, "thinking") {
@@ -373,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 		testRequest.MaxTokens = 10
 	}
 
-	testMessage := dto.Message{
-		Role:    "user",
-		Content: "hi",
-	}
-	testRequest.Model = model
-	testRequest.Messages = append(testRequest.Messages, testMessage)
 	return testRequest
 }
 
@@ -402,8 +516,9 @@ func TestChannel(c *gin.Context) {
 	//	}
 	//}()
 	testModel := c.Query("model")
+	endpointType := c.Query("endpoint_type")
 	tik := time.Now()
-	result := testChannel(channel, testModel)
+	result := testChannel(channel, testModel, endpointType)
 	if result.localErr != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -429,7 +544,6 @@ func TestChannel(c *gin.Context) {
 		"message": "",
 		"time":    consumedTime,
 	})
-	return
 }
 
 var testAllChannelsLock sync.Mutex
@@ -463,7 +577,7 @@ func testAllChannels(notify bool) error {
 		for _, channel := range channels {
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
-			result := testChannel(channel, "")
+			result := testChannel(channel, "", "")
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
 
@@ -477,7 +591,7 @@ func testAllChannels(notify bool) error {
 			// 当错误检查通过,才检查响应时间
 			if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
 				if milliseconds > disableThreshold {
-					err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+					err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
 					newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
 					shouldBanChannel = true
 				}
@@ -514,7 +628,6 @@ func TestAllChannels(c *gin.Context) {
 		"success": true,
 		"message": "",
 	})
-	return
 }
 
 var autoTestChannelsOnce sync.Once

+ 27 - 1
web/src/components/table/channels/modals/ModelTestModal.jsx

@@ -25,6 +25,7 @@ import {
   Table,
   Tag,
   Typography,
+  Select,
 } from '@douyinfe/semi-ui';
 import { IconSearch } from '@douyinfe/semi-icons';
 import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
@@ -45,6 +46,8 @@ const ModelTestModal = ({
   testChannel,
   modelTablePage,
   setModelTablePage,
+  selectedEndpointType,
+  setSelectedEndpointType,
   allSelectingRef,
   isMobile,
   t,
@@ -59,6 +62,17 @@ const ModelTestModal = ({
         )
     : [];
 
+  const endpointTypeOptions = [
+    { value: '', label: t('自动检测') },
+    { value: 'openai', label: 'OpenAI (/v1/chat/completions)' },
+    { value: 'openai-response', label: 'OpenAI Response (/v1/responses)' },
+    { value: 'anthropic', label: 'Anthropic (/v1/messages)' },
+    { value: 'gemini', label: 'Gemini (/v1beta/models/{model}:generateContent)' },
+    { value: 'jina-rerank', label: 'Jina Rerank (/rerank)' },
+    { value: 'image-generation', label: t('图像生成') + ' (/v1/images/generations)' },
+    { value: 'embeddings', label: 'Embeddings (/v1/embeddings)' },
+  ];
+
   const handleCopySelected = () => {
     if (selectedModelKeys.length === 0) {
       showError(t('请先选择模型!'));
@@ -152,7 +166,7 @@ const ModelTestModal = ({
         return (
           <Button
             type='tertiary'
-            onClick={() => testChannel(currentTestChannel, record.model)}
+            onClick={() => testChannel(currentTestChannel, record.model, selectedEndpointType)}
             loading={isTesting}
             size='small'
           >
@@ -228,6 +242,18 @@ const ModelTestModal = ({
     >
       {hasChannel && (
         <div className='model-test-scroll'>
+          {/* 端点类型选择器 */}
+          <div className='flex items-center gap-2 w-full mb-2'>
+            <Typography.Text strong>{t('端点类型')}:</Typography.Text>
+            <Select
+              value={selectedEndpointType}
+              onChange={setSelectedEndpointType}
+              optionList={endpointTypeOptions}
+              className='!w-full'
+              placeholder={t('选择端点类型')}
+            />
+          </div>
+
           {/* 搜索与操作按钮 */}
           <div className='flex items-center justify-end gap-2 w-full mb-2'>
             <Input

+ 11 - 3
web/src/hooks/channels/useChannelsData.jsx

@@ -80,6 +80,7 @@ export const useChannelsData = () => {
   const [selectedModelKeys, setSelectedModelKeys] = useState([]);
   const [isBatchTesting, setIsBatchTesting] = useState(false);
   const [modelTablePage, setModelTablePage] = useState(1);
+  const [selectedEndpointType, setSelectedEndpointType] = useState('');
   
   // 使用 ref 来避免闭包问题,类似旧版实现
   const shouldStopBatchTestingRef = useRef(false);
@@ -691,7 +692,7 @@ export const useChannelsData = () => {
   };
 
   // Test channel - 单个模型测试,参考旧版实现
-  const testChannel = async (record, model) => {
+  const testChannel = async (record, model, endpointType = '') => {
     const testKey = `${record.id}-${model}`;
 
     // 检查是否应该停止批量测试
@@ -703,7 +704,11 @@ export const useChannelsData = () => {
     setTestingModels(prev => new Set([...prev, model]));
 
     try {
-      const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
+      let url = `/api/channel/test/${record.id}?model=${model}`;
+      if (endpointType) {
+        url += `&endpoint_type=${endpointType}`;
+      }
+      const res = await API.get(url);
 
       // 检查是否在请求期间被停止
       if (shouldStopBatchTestingRef.current && isBatchTesting) {
@@ -820,7 +825,7 @@ export const useChannelsData = () => {
           .replace('${total}', models.length)
         );
 
-        const batchPromises = batch.map(model => testChannel(currentTestChannel, model));
+        const batchPromises = batch.map(model => testChannel(currentTestChannel, model, selectedEndpointType));
         const batchResults = await Promise.allSettled(batchPromises);
         results.push(...batchResults);
 
@@ -902,6 +907,7 @@ export const useChannelsData = () => {
     setTestingModels(new Set());
     setSelectedModelKeys([]);
     setModelTablePage(1);
+    setSelectedEndpointType('');
     // 可选择性保留测试结果,这里不清空以便用户查看
   };
 
@@ -989,6 +995,8 @@ export const useChannelsData = () => {
     isBatchTesting,
     modelTablePage,
     setModelTablePage,
+    selectedEndpointType,
+    setSelectedEndpointType,
     allSelectingRef,
 
     // Multi-key management states