|
|
@@ -31,6 +31,7 @@ import (
|
|
|
|
|
|
"github.com/bytedance/gopkg/util/gopool"
|
|
|
"github.com/samber/lo"
|
|
|
+ "github.com/tidwall/gjson"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
@@ -41,7 +42,21 @@ type testResult struct {
|
|
|
newAPIError *types.NewAPIError
|
|
|
}
|
|
|
|
|
|
-func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
|
|
|
+func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string {
|
|
|
+ normalized := strings.TrimSpace(endpointType)
|
|
|
+ if normalized != "" {
|
|
|
+ return normalized
|
|
|
+ }
|
|
|
+ if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) {
|
|
|
+ return string(constant.EndpointTypeOpenAIResponseCompact)
|
|
|
+ }
|
|
|
+ if channel != nil && channel.Type == constant.ChannelTypeCodex {
|
|
|
+ return string(constant.EndpointTypeOpenAIResponse)
|
|
|
+ }
|
|
|
+ return normalized
|
|
|
+}
|
|
|
+
|
|
|
+func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
|
|
|
tik := time.Now()
|
|
|
var unsupportedTestChannelTypes = []int{
|
|
|
constant.ChannelTypeMidjourney,
|
|
|
@@ -76,6 +91,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType)
|
|
|
+
|
|
|
requestPath := "/v1/chat/completions"
|
|
|
|
|
|
// 如果指定了端点类型,使用指定的端点类型
|
|
|
@@ -200,7 +217,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- request := buildTestRequest(testModel, endpointType, channel)
|
|
|
+ request := buildTestRequest(testModel, endpointType, channel, isStream)
|
|
|
|
|
|
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
|
|
|
|
|
@@ -418,16 +435,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
|
|
newAPIError: respErr,
|
|
|
}
|
|
|
}
|
|
|
- if usageA == nil {
|
|
|
+ usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
|
|
|
+ if usageErr != nil {
|
|
|
return testResult{
|
|
|
context: c,
|
|
|
- localErr: errors.New("usage is nil"),
|
|
|
- newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
|
|
+ localErr: usageErr,
|
|
|
+ newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
|
|
}
|
|
|
}
|
|
|
- usage := usageA.(*dto.Usage)
|
|
|
result := w.Result()
|
|
|
- respBody, err := io.ReadAll(result.Body)
|
|
|
+ respBody, err := readTestResponseBody(result.Body, isStream)
|
|
|
if err != nil {
|
|
|
return testResult{
|
|
|
context: c,
|
|
|
@@ -435,6 +452,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
|
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
|
|
}
|
|
|
}
|
|
|
+ if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: bodyErr,
|
|
|
+ newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
|
|
+ }
|
|
|
+ }
|
|
|
info.SetEstimatePromptTokens(usage.PromptTokens)
|
|
|
|
|
|
quota := 0
|
|
|
@@ -473,7 +497,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
|
|
|
+func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
|
|
|
+ switch u := usageAny.(type) {
|
|
|
+ case *dto.Usage:
|
|
|
+ return u, nil
|
|
|
+ case dto.Usage:
|
|
|
+ return &u, nil
|
|
|
+ case nil:
|
|
|
+ if !isStream {
|
|
|
+ return nil, errors.New("usage is nil")
|
|
|
+ }
|
|
|
+ usage := &dto.Usage{
|
|
|
+ PromptTokens: estimatePromptTokens,
|
|
|
+ }
|
|
|
+ usage.TotalTokens = usage.PromptTokens
|
|
|
+ return usage, nil
|
|
|
+ default:
|
|
|
+ if !isStream {
|
|
|
+ return nil, fmt.Errorf("invalid usage type: %T", usageAny)
|
|
|
+ }
|
|
|
+ usage := &dto.Usage{
|
|
|
+ PromptTokens: estimatePromptTokens,
|
|
|
+ }
|
|
|
+ usage.TotalTokens = usage.PromptTokens
|
|
|
+ return usage, nil
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// readTestResponseBody reads the recorded test response and always closes
// body before returning (the Close error is intentionally ignored).
//
// For stream tests at most 8 KiB are read; anything beyond that is left
// unread. Non-stream bodies are read in full.
func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) {
	const streamReadCap = 8 << 10 // 8 KiB

	defer func() { _ = body.Close() }()

	var src io.Reader = body
	if isStream {
		src = io.LimitReader(body, streamReadCap)
	}
	return io.ReadAll(src)
}
|
|
|
+
|
|
|
+func detectErrorFromTestResponseBody(respBody []byte) error {
|
|
|
+ b := bytes.TrimSpace(respBody)
|
|
|
+ if len(b) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ if message := detectErrorMessageFromJSONBytes(b); message != "" {
|
|
|
+ return fmt.Errorf("upstream error: %s", message)
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, line := range bytes.Split(b, []byte{'\n'}) {
|
|
|
+ line = bytes.TrimSpace(line)
|
|
|
+ if len(line) == 0 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if !bytes.HasPrefix(line, []byte("data:")) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
|
|
|
+ if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if message := detectErrorMessageFromJSONBytes(payload); message != "" {
|
|
|
+ return fmt.Errorf("upstream error: %s", message)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
|
|
|
+ if len(jsonBytes) == 0 {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+ if jsonBytes[0] != '{' && jsonBytes[0] != '[' {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+ errVal := gjson.GetBytes(jsonBytes, "error")
|
|
|
+ if !errVal.Exists() || errVal.Type == gjson.Null {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ message := gjson.GetBytes(jsonBytes, "error.message").String()
|
|
|
+ if message == "" {
|
|
|
+ message = gjson.GetBytes(jsonBytes, "error.error.message").String()
|
|
|
+ }
|
|
|
+ if message == "" && errVal.Type == gjson.String {
|
|
|
+ message = errVal.String()
|
|
|
+ }
|
|
|
+ if message == "" {
|
|
|
+ message = errVal.Raw
|
|
|
+ }
|
|
|
+ message = strings.TrimSpace(message)
|
|
|
+ if message == "" {
|
|
|
+ return "upstream returned error payload"
|
|
|
+ }
|
|
|
+ return message
|
|
|
+}
|
|
|
+
|
|
|
+func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request {
|
|
|
testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
|
|
|
|
|
|
// 根据端点类型构建不同的测试请求
|
|
|
@@ -504,8 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
|
|
case constant.EndpointTypeOpenAIResponse:
|
|
|
// 返回 OpenAIResponsesRequest
|
|
|
return &dto.OpenAIResponsesRequest{
|
|
|
- Model: model,
|
|
|
- Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
|
|
+ Model: model,
|
|
|
+ Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
|
|
+ Stream: isStream,
|
|
|
}
|
|
|
case constant.EndpointTypeOpenAIResponseCompact:
|
|
|
// 返回 OpenAIResponsesCompactionRequest
|
|
|
@@ -519,9 +638,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
|
|
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
|
|
|
maxTokens = 3000
|
|
|
}
|
|
|
- return &dto.GeneralOpenAIRequest{
|
|
|
+ req := &dto.GeneralOpenAIRequest{
|
|
|
Model: model,
|
|
|
- Stream: false,
|
|
|
+ Stream: isStream,
|
|
|
Messages: []dto.Message{
|
|
|
{
|
|
|
Role: "user",
|
|
|
@@ -530,6 +649,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
|
|
},
|
|
|
MaxTokens: maxTokens,
|
|
|
}
|
|
|
+ if isStream {
|
|
|
+ req.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
|
|
|
+ }
|
|
|
+ return req
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -565,15 +688,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
|
|
// Responses-only models (e.g. codex series)
|
|
|
if strings.Contains(strings.ToLower(model), "codex") {
|
|
|
return &dto.OpenAIResponsesRequest{
|
|
|
- Model: model,
|
|
|
- Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
|
|
+ Model: model,
|
|
|
+ Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
|
|
+ Stream: isStream,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
|
|
|
testRequest := &dto.GeneralOpenAIRequest{
|
|
|
Model: model,
|
|
|
- Stream: false,
|
|
|
+ Stream: isStream,
|
|
|
Messages: []dto.Message{
|
|
|
{
|
|
|
Role: "user",
|
|
|
@@ -581,6 +705,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
|
|
},
|
|
|
},
|
|
|
}
|
|
|
+ if isStream {
|
|
|
+ testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
|
|
|
+ }
|
|
|
|
|
|
if strings.HasPrefix(model, "o") {
|
|
|
testRequest.MaxCompletionTokens = 16
|
|
|
@@ -618,8 +745,9 @@ func TestChannel(c *gin.Context) {
|
|
|
//}()
|
|
|
testModel := c.Query("model")
|
|
|
endpointType := c.Query("endpoint_type")
|
|
|
+ isStream, _ := strconv.ParseBool(c.Query("stream"))
|
|
|
tik := time.Now()
|
|
|
- result := testChannel(channel, testModel, endpointType)
|
|
|
+ result := testChannel(channel, testModel, endpointType, isStream)
|
|
|
if result.localErr != nil {
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
|
"success": false,
|
|
|
@@ -678,7 +806,7 @@ func testAllChannels(notify bool) error {
|
|
|
for _, channel := range channels {
|
|
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
|
|
tik := time.Now()
|
|
|
- result := testChannel(channel, "", "")
|
|
|
+ result := testChannel(channel, "", "", false)
|
|
|
tok := time.Now()
|
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
|