|
|
@@ -30,22 +30,43 @@ import (
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
|
|
|
-func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) {
|
|
|
+type testResult struct {
|
|
|
+ context *gin.Context
|
|
|
+ localErr error
|
|
|
+ newAPIError *types.NewAPIError
|
|
|
+}
|
|
|
+
|
|
|
+func testChannel(channel *model.Channel, testModel string) testResult {
|
|
|
tik := time.Now()
|
|
|
if channel.Type == constant.ChannelTypeMidjourney {
|
|
|
- return errors.New("midjourney channel test is not supported"), nil
|
|
|
+ return testResult{
|
|
|
+ localErr: errors.New("midjourney channel test is not supported"),
|
|
|
+ newAPIError: nil,
|
|
|
+ }
|
|
|
}
|
|
|
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
|
|
- return errors.New("midjourney plus channel test is not supported"), nil
|
|
|
+ return testResult{
|
|
|
+ localErr: errors.New("midjourney plus channel test is not supported"),
|
|
|
+ newAPIError: nil,
|
|
|
+ }
|
|
|
}
|
|
|
if channel.Type == constant.ChannelTypeSunoAPI {
|
|
|
- return errors.New("suno channel test is not supported"), nil
|
|
|
+ return testResult{
|
|
|
+ localErr: errors.New("suno channel test is not supported"),
|
|
|
+ newAPIError: nil,
|
|
|
+ }
|
|
|
}
|
|
|
if channel.Type == constant.ChannelTypeKling {
|
|
|
- return errors.New("kling channel test is not supported"), nil
|
|
|
+ return testResult{
|
|
|
+ localErr: errors.New("kling channel test is not supported"),
|
|
|
+ newAPIError: nil,
|
|
|
+ }
|
|
|
}
|
|
|
if channel.Type == constant.ChannelTypeJimeng {
|
|
|
- return errors.New("jimeng channel test is not supported"), nil
|
|
|
+ return testResult{
|
|
|
+ localErr: errors.New("jimeng channel test is not supported"),
|
|
|
+ newAPIError: nil,
|
|
|
+ }
|
|
|
}
|
|
|
w := httptest.NewRecorder()
|
|
|
c, _ := gin.CreateTestContext(w)
|
|
|
@@ -82,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|
|
|
|
|
cache, err := model.GetUserCache(1)
|
|
|
if err != nil {
|
|
|
- return err, nil
|
|
|
+ return testResult{
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: nil,
|
|
|
+ }
|
|
|
}
|
|
|
cache.WriteContext(c)
|
|
|
|
|
|
@@ -93,20 +117,35 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|
|
group, _ := model.GetUserGroup(1, false)
|
|
|
c.Set("group", group)
|
|
|
|
|
|
- middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
|
|
+ newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
|
|
+ if newAPIError != nil {
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: newAPIError,
|
|
|
+ newAPIError: newAPIError,
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
info := relaycommon.GenRelayInfo(c)
|
|
|
|
|
|
err = helper.ModelMappedHelper(c, info, nil)
|
|
|
if err != nil {
|
|
|
- return err, types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
|
|
+ }
|
|
|
}
|
|
|
testModel = info.UpstreamModelName
|
|
|
|
|
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
|
|
adaptor := relay.GetAdaptor(apiType)
|
|
|
if adaptor == nil {
|
|
|
- return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
|
|
+ newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
request := buildTestRequest(testModel)
|
|
|
@@ -117,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|
|
|
|
|
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
|
|
if err != nil {
|
|
|
- return err, types.NewError(err, types.ErrorCodeModelPriceError)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
adaptor.Init(info)
|
|
|
|
|
|
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
|
|
if err != nil {
|
|
|
- return err, types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
|
|
+ }
|
|
|
}
|
|
|
jsonData, err := json.Marshal(convertedRequest)
|
|
|
if err != nil {
|
|
|
- return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
|
|
+ }
|
|
|
}
|
|
|
requestBody := bytes.NewBuffer(jsonData)
|
|
|
c.Request.Body = io.NopCloser(requestBody)
|
|
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
|
|
if err != nil {
|
|
|
- return err, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
|
|
+ }
|
|
|
}
|
|
|
var httpResp *http.Response
|
|
|
if resp != nil {
|
|
|
httpResp = resp.(*http.Response)
|
|
|
if httpResp.StatusCode != http.StatusOK {
|
|
|
err := service.RelayErrorHandler(httpResp, true)
|
|
|
- return err, types.NewError(err, types.ErrorCodeBadResponse)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
|
|
if respErr != nil {
|
|
|
- return respErr, respErr
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: respErr,
|
|
|
+ newAPIError: respErr,
|
|
|
+ }
|
|
|
}
|
|
|
if usageA == nil {
|
|
|
- return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: errors.New("usage is nil"),
|
|
|
+ newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
|
|
+ }
|
|
|
}
|
|
|
usage := usageA.(*dto.Usage)
|
|
|
result := w.Result()
|
|
|
respBody, err := io.ReadAll(result.Body)
|
|
|
if err != nil {
|
|
|
- return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: err,
|
|
|
+ newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
|
|
+ }
|
|
|
}
|
|
|
info.PromptTokens = usage.PromptTokens
|
|
|
|
|
|
@@ -188,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|
|
Other: other,
|
|
|
})
|
|
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
|
|
- return nil, nil
|
|
|
+ return testResult{
|
|
|
+ context: c,
|
|
|
+ localErr: nil,
|
|
|
+ newAPIError: nil,
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|
|
@@ -247,15 +322,23 @@ func TestChannel(c *gin.Context) {
|
|
|
}
|
|
|
testModel := c.Query("model")
|
|
|
tik := time.Now()
|
|
|
- _, newAPIError := testChannel(channel, testModel)
|
|
|
+ result := testChannel(channel, testModel)
|
|
|
+ if result.localErr != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": result.localErr.Error(),
|
|
|
+ "time": 0.0,
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
tok := time.Now()
|
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
go channel.UpdateResponseTime(milliseconds)
|
|
|
consumedTime := float64(milliseconds) / 1000.0
|
|
|
- if newAPIError != nil {
|
|
|
+ if result.newAPIError != nil {
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
|
"success": false,
|
|
|
- "message": newAPIError.Error(),
|
|
|
+ "message": result.newAPIError.Error(),
|
|
|
"time": consumedTime,
|
|
|
})
|
|
|
return
|
|
|
@@ -280,9 +363,9 @@ func testAllChannels(notify bool) error {
|
|
|
}
|
|
|
testAllChannelsRunning = true
|
|
|
testAllChannelsLock.Unlock()
|
|
|
- channels, err := model.GetAllChannels(0, 0, true, false)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+ channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
|
|
+ if getChannelErr != nil {
|
|
|
+ return getChannelErr
|
|
|
}
|
|
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
|
|
if disableThreshold == 0 {
|
|
|
@@ -299,30 +382,34 @@ func testAllChannels(notify bool) error {
|
|
|
for _, channel := range channels {
|
|
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
|
|
tik := time.Now()
|
|
|
- err, newAPIError := testChannel(channel, "")
|
|
|
+ result := testChannel(channel, "")
|
|
|
tok := time.Now()
|
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
|
|
|
shouldBanChannel := false
|
|
|
-
|
|
|
+ newAPIError := result.newAPIError
|
|
|
// request error disables the channel
|
|
|
- if err != nil {
|
|
|
- shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError)
|
|
|
+ if newAPIError != nil {
|
|
|
+ shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
|
|
}
|
|
|
|
|
|
- if milliseconds > disableThreshold {
|
|
|
- err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
|
|
- shouldBanChannel = true
|
|
|
+ // 当错误检查通过,才检查响应时间
|
|
|
+ if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
|
|
+ if milliseconds > disableThreshold {
|
|
|
+ err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
|
|
+ newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
|
|
+ shouldBanChannel = true
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// disable channel
|
|
|
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
|
|
- service.DisableChannel(channel.Id, channel.Name, err.Error())
|
|
|
+ go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
|
|
}
|
|
|
|
|
|
// enable channel
|
|
|
- if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) {
|
|
|
- service.EnableChannel(channel.Id, channel.Name)
|
|
|
+ if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
|
|
+ service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
|
|
}
|
|
|
|
|
|
channel.UpdateResponseTime(milliseconds)
|