channel-test.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/QuantumNous/new-api/common"
  17. "github.com/QuantumNous/new-api/constant"
  18. "github.com/QuantumNous/new-api/dto"
  19. "github.com/QuantumNous/new-api/middleware"
  20. "github.com/QuantumNous/new-api/model"
  21. "github.com/QuantumNous/new-api/relay"
  22. relaycommon "github.com/QuantumNous/new-api/relay/common"
  23. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  24. "github.com/QuantumNous/new-api/relay/helper"
  25. "github.com/QuantumNous/new-api/service"
  26. "github.com/QuantumNous/new-api/setting/operation_setting"
  27. "github.com/QuantumNous/new-api/types"
  28. "github.com/bytedance/gopkg/util/gopool"
  29. "github.com/samber/lo"
  30. "github.com/gin-gonic/gin"
  31. )
  32. type testResult struct {
  33. context *gin.Context
  34. localErr error
  35. newAPIError *types.NewAPIError
  36. }
  37. func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
  38. tik := time.Now()
  39. var unsupportedTestChannelTypes = []int{
  40. constant.ChannelTypeMidjourney,
  41. constant.ChannelTypeMidjourneyPlus,
  42. constant.ChannelTypeSunoAPI,
  43. constant.ChannelTypeKling,
  44. constant.ChannelTypeJimeng,
  45. constant.ChannelTypeDoubaoVideo,
  46. constant.ChannelTypeVidu,
  47. }
  48. if lo.Contains(unsupportedTestChannelTypes, channel.Type) {
  49. channelTypeName := constant.GetChannelTypeName(channel.Type)
  50. return testResult{
  51. localErr: fmt.Errorf("%s channel test is not supported", channelTypeName),
  52. }
  53. }
  54. w := httptest.NewRecorder()
  55. c, _ := gin.CreateTestContext(w)
  56. testModel = strings.TrimSpace(testModel)
  57. if testModel == "" {
  58. if channel.TestModel != nil && *channel.TestModel != "" {
  59. testModel = strings.TrimSpace(*channel.TestModel)
  60. } else {
  61. models := channel.GetModels()
  62. if len(models) > 0 {
  63. testModel = strings.TrimSpace(models[0])
  64. }
  65. if testModel == "" {
  66. testModel = "gpt-4o-mini"
  67. }
  68. }
  69. }
  70. requestPath := "/v1/chat/completions"
  71. // 如果指定了端点类型,使用指定的端点类型
  72. if endpointType != "" {
  73. if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
  74. requestPath = endpointInfo.Path
  75. }
  76. } else {
  77. // 如果没有指定端点类型,使用原有的自动检测逻辑
  78. if strings.Contains(strings.ToLower(testModel), "rerank") {
  79. requestPath = "/v1/rerank"
  80. }
  81. // 先判断是否为 Embedding 模型
  82. if strings.Contains(strings.ToLower(testModel), "embedding") ||
  83. strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
  84. strings.Contains(testModel, "bge-") || // bge 系列模型
  85. strings.Contains(testModel, "embed") ||
  86. channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
  87. requestPath = "/v1/embeddings" // 修改请求路径
  88. }
  89. // VolcEngine 图像生成模型
  90. if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
  91. requestPath = "/v1/images/generations"
  92. }
  93. // responses-only models
  94. if strings.Contains(strings.ToLower(testModel), "codex") {
  95. requestPath = "/v1/responses"
  96. }
  97. }
  98. c.Request = &http.Request{
  99. Method: "POST",
  100. URL: &url.URL{Path: requestPath}, // 使用动态路径
  101. Body: nil,
  102. Header: make(http.Header),
  103. }
  104. cache, err := model.GetUserCache(1)
  105. if err != nil {
  106. return testResult{
  107. localErr: err,
  108. newAPIError: nil,
  109. }
  110. }
  111. cache.WriteContext(c)
  112. //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
  113. c.Request.Header.Set("Content-Type", "application/json")
  114. c.Set("channel", channel.Type)
  115. c.Set("base_url", channel.GetBaseURL())
  116. group, _ := model.GetUserGroup(1, false)
  117. c.Set("group", group)
  118. newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
  119. if newAPIError != nil {
  120. return testResult{
  121. context: c,
  122. localErr: newAPIError,
  123. newAPIError: newAPIError,
  124. }
  125. }
  126. // Determine relay format based on endpoint type or request path
  127. var relayFormat types.RelayFormat
  128. if endpointType != "" {
  129. // 根据指定的端点类型设置 relayFormat
  130. switch constant.EndpointType(endpointType) {
  131. case constant.EndpointTypeOpenAI:
  132. relayFormat = types.RelayFormatOpenAI
  133. case constant.EndpointTypeOpenAIResponse:
  134. relayFormat = types.RelayFormatOpenAIResponses
  135. case constant.EndpointTypeAnthropic:
  136. relayFormat = types.RelayFormatClaude
  137. case constant.EndpointTypeGemini:
  138. relayFormat = types.RelayFormatGemini
  139. case constant.EndpointTypeJinaRerank:
  140. relayFormat = types.RelayFormatRerank
  141. case constant.EndpointTypeImageGeneration:
  142. relayFormat = types.RelayFormatOpenAIImage
  143. case constant.EndpointTypeEmbeddings:
  144. relayFormat = types.RelayFormatEmbedding
  145. default:
  146. relayFormat = types.RelayFormatOpenAI
  147. }
  148. } else {
  149. // 根据请求路径自动检测
  150. relayFormat = types.RelayFormatOpenAI
  151. if c.Request.URL.Path == "/v1/embeddings" {
  152. relayFormat = types.RelayFormatEmbedding
  153. }
  154. if c.Request.URL.Path == "/v1/images/generations" {
  155. relayFormat = types.RelayFormatOpenAIImage
  156. }
  157. if c.Request.URL.Path == "/v1/messages" {
  158. relayFormat = types.RelayFormatClaude
  159. }
  160. if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
  161. relayFormat = types.RelayFormatGemini
  162. }
  163. if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
  164. relayFormat = types.RelayFormatRerank
  165. }
  166. if c.Request.URL.Path == "/v1/responses" {
  167. relayFormat = types.RelayFormatOpenAIResponses
  168. }
  169. }
  170. request := buildTestRequest(testModel, endpointType, channel)
  171. info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
  172. if err != nil {
  173. return testResult{
  174. context: c,
  175. localErr: err,
  176. newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
  177. }
  178. }
  179. info.IsChannelTest = true
  180. info.InitChannelMeta(c)
  181. err = helper.ModelMappedHelper(c, info, request)
  182. if err != nil {
  183. return testResult{
  184. context: c,
  185. localErr: err,
  186. newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
  187. }
  188. }
  189. testModel = info.UpstreamModelName
  190. // 更新请求中的模型名称
  191. request.SetModelName(testModel)
  192. apiType, _ := common.ChannelType2APIType(channel.Type)
  193. adaptor := relay.GetAdaptor(apiType)
  194. if adaptor == nil {
  195. return testResult{
  196. context: c,
  197. localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
  198. newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
  199. }
  200. }
  201. //// 创建一个用于日志的 info 副本,移除 ApiKey
  202. //logInfo := info
  203. //logInfo.ApiKey = ""
  204. common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
  205. priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
  206. if err != nil {
  207. return testResult{
  208. context: c,
  209. localErr: err,
  210. newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
  211. }
  212. }
  213. adaptor.Init(info)
  214. var convertedRequest any
  215. // 根据 RelayMode 选择正确的转换函数
  216. switch info.RelayMode {
  217. case relayconstant.RelayModeEmbeddings:
  218. // Embedding 请求 - request 已经是正确的类型
  219. if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
  220. convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
  221. } else {
  222. return testResult{
  223. context: c,
  224. localErr: errors.New("invalid embedding request type"),
  225. newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
  226. }
  227. }
  228. case relayconstant.RelayModeImagesGenerations:
  229. // 图像生成请求 - request 已经是正确的类型
  230. if imageReq, ok := request.(*dto.ImageRequest); ok {
  231. convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
  232. } else {
  233. return testResult{
  234. context: c,
  235. localErr: errors.New("invalid image request type"),
  236. newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
  237. }
  238. }
  239. case relayconstant.RelayModeRerank:
  240. // Rerank 请求 - request 已经是正确的类型
  241. if rerankReq, ok := request.(*dto.RerankRequest); ok {
  242. convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
  243. } else {
  244. return testResult{
  245. context: c,
  246. localErr: errors.New("invalid rerank request type"),
  247. newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
  248. }
  249. }
  250. case relayconstant.RelayModeResponses:
  251. // Response 请求 - request 已经是正确的类型
  252. if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
  253. convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
  254. } else {
  255. return testResult{
  256. context: c,
  257. localErr: errors.New("invalid response request type"),
  258. newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
  259. }
  260. }
  261. default:
  262. // Chat/Completion 等其他请求类型
  263. if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
  264. convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
  265. } else {
  266. return testResult{
  267. context: c,
  268. localErr: errors.New("invalid general request type"),
  269. newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
  270. }
  271. }
  272. }
  273. if err != nil {
  274. return testResult{
  275. context: c,
  276. localErr: err,
  277. newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
  278. }
  279. }
  280. jsonData, err := json.Marshal(convertedRequest)
  281. if err != nil {
  282. return testResult{
  283. context: c,
  284. localErr: err,
  285. newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
  286. }
  287. }
  288. //jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
  289. //if err != nil {
  290. // return testResult{
  291. // context: c,
  292. // localErr: err,
  293. // newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
  294. // }
  295. //}
  296. if len(info.ParamOverride) > 0 {
  297. jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
  298. if err != nil {
  299. return testResult{
  300. context: c,
  301. localErr: err,
  302. newAPIError: types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid),
  303. }
  304. }
  305. }
  306. requestBody := bytes.NewBuffer(jsonData)
  307. c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
  308. resp, err := adaptor.DoRequest(c, info, requestBody)
  309. if err != nil {
  310. return testResult{
  311. context: c,
  312. localErr: err,
  313. newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
  314. }
  315. }
  316. var httpResp *http.Response
  317. if resp != nil {
  318. httpResp = resp.(*http.Response)
  319. if httpResp.StatusCode != http.StatusOK {
  320. err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
  321. common.SysError(fmt.Sprintf(
  322. "channel test bad response: channel_id=%d name=%s type=%d model=%s endpoint_type=%s status=%d err=%v",
  323. channel.Id,
  324. channel.Name,
  325. channel.Type,
  326. testModel,
  327. endpointType,
  328. httpResp.StatusCode,
  329. err,
  330. ))
  331. return testResult{
  332. context: c,
  333. localErr: err,
  334. newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
  335. }
  336. }
  337. }
  338. usageA, respErr := adaptor.DoResponse(c, httpResp, info)
  339. if respErr != nil {
  340. return testResult{
  341. context: c,
  342. localErr: respErr,
  343. newAPIError: respErr,
  344. }
  345. }
  346. if usageA == nil {
  347. return testResult{
  348. context: c,
  349. localErr: errors.New("usage is nil"),
  350. newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
  351. }
  352. }
  353. usage := usageA.(*dto.Usage)
  354. result := w.Result()
  355. respBody, err := io.ReadAll(result.Body)
  356. if err != nil {
  357. return testResult{
  358. context: c,
  359. localErr: err,
  360. newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
  361. }
  362. }
  363. info.SetEstimatePromptTokens(usage.PromptTokens)
  364. quota := 0
  365. if !priceData.UsePrice {
  366. quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
  367. quota = int(math.Round(float64(quota) * priceData.ModelRatio))
  368. if priceData.ModelRatio != 0 && quota <= 0 {
  369. quota = 1
  370. }
  371. } else {
  372. quota = int(priceData.ModelPrice * common.QuotaPerUnit)
  373. }
  374. tok := time.Now()
  375. milliseconds := tok.Sub(tik).Milliseconds()
  376. consumedTime := float64(milliseconds) / 1000.0
  377. other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
  378. usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
  379. model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
  380. ChannelId: channel.Id,
  381. PromptTokens: usage.PromptTokens,
  382. CompletionTokens: usage.CompletionTokens,
  383. ModelName: info.OriginModelName,
  384. TokenName: "模型测试",
  385. Quota: quota,
  386. Content: "模型测试",
  387. UseTimeSeconds: int(consumedTime),
  388. IsStream: info.IsStream,
  389. Group: info.UsingGroup,
  390. Other: other,
  391. })
  392. common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
  393. return testResult{
  394. context: c,
  395. localErr: nil,
  396. newAPIError: nil,
  397. }
  398. }
  399. func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
  400. // 根据端点类型构建不同的测试请求
  401. if endpointType != "" {
  402. switch constant.EndpointType(endpointType) {
  403. case constant.EndpointTypeEmbeddings:
  404. // 返回 EmbeddingRequest
  405. return &dto.EmbeddingRequest{
  406. Model: model,
  407. Input: []any{"hello world"},
  408. }
  409. case constant.EndpointTypeImageGeneration:
  410. // 返回 ImageRequest
  411. return &dto.ImageRequest{
  412. Model: model,
  413. Prompt: "a cute cat",
  414. N: 1,
  415. Size: "1024x1024",
  416. }
  417. case constant.EndpointTypeJinaRerank:
  418. // 返回 RerankRequest
  419. return &dto.RerankRequest{
  420. Model: model,
  421. Query: "What is Deep Learning?",
  422. Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
  423. TopN: 2,
  424. }
  425. case constant.EndpointTypeOpenAIResponse:
  426. // 返回 OpenAIResponsesRequest
  427. return &dto.OpenAIResponsesRequest{
  428. Model: model,
  429. Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
  430. }
  431. case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
  432. // 返回 GeneralOpenAIRequest
  433. maxTokens := uint(16)
  434. if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
  435. maxTokens = 3000
  436. }
  437. return &dto.GeneralOpenAIRequest{
  438. Model: model,
  439. Stream: false,
  440. Messages: []dto.Message{
  441. {
  442. Role: "user",
  443. Content: "hi",
  444. },
  445. },
  446. MaxTokens: maxTokens,
  447. }
  448. }
  449. }
  450. // 自动检测逻辑(保持原有行为)
  451. if strings.Contains(strings.ToLower(model), "rerank") {
  452. return &dto.RerankRequest{
  453. Model: model,
  454. Query: "What is Deep Learning?",
  455. Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
  456. TopN: 2,
  457. }
  458. }
  459. // 先判断是否为 Embedding 模型
  460. if strings.Contains(strings.ToLower(model), "embedding") ||
  461. strings.HasPrefix(model, "m3e") ||
  462. strings.Contains(model, "bge-") {
  463. // 返回 EmbeddingRequest
  464. return &dto.EmbeddingRequest{
  465. Model: model,
  466. Input: []any{"hello world"},
  467. }
  468. }
  469. // Responses-only models (e.g. codex series)
  470. if strings.Contains(strings.ToLower(model), "codex") {
  471. return &dto.OpenAIResponsesRequest{
  472. Model: model,
  473. Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
  474. }
  475. }
  476. // Chat/Completion 请求 - 返回 GeneralOpenAIRequest
  477. testRequest := &dto.GeneralOpenAIRequest{
  478. Model: model,
  479. Stream: false,
  480. Messages: []dto.Message{
  481. {
  482. Role: "user",
  483. Content: "hi",
  484. },
  485. },
  486. }
  487. if strings.HasPrefix(model, "o") {
  488. testRequest.MaxCompletionTokens = 16
  489. } else if strings.Contains(model, "thinking") {
  490. if !strings.Contains(model, "claude") {
  491. testRequest.MaxTokens = 50
  492. }
  493. } else if strings.Contains(model, "gemini") {
  494. testRequest.MaxTokens = 3000
  495. } else {
  496. testRequest.MaxTokens = 16
  497. }
  498. return testRequest
  499. }
  500. func TestChannel(c *gin.Context) {
  501. channelId, err := strconv.Atoi(c.Param("id"))
  502. if err != nil {
  503. common.ApiError(c, err)
  504. return
  505. }
  506. channel, err := model.CacheGetChannel(channelId)
  507. if err != nil {
  508. channel, err = model.GetChannelById(channelId, true)
  509. if err != nil {
  510. common.ApiError(c, err)
  511. return
  512. }
  513. }
  514. //defer func() {
  515. // if channel.ChannelInfo.IsMultiKey {
  516. // go func() { _ = channel.SaveChannelInfo() }()
  517. // }
  518. //}()
  519. testModel := c.Query("model")
  520. endpointType := c.Query("endpoint_type")
  521. tik := time.Now()
  522. result := testChannel(channel, testModel, endpointType)
  523. if result.localErr != nil {
  524. c.JSON(http.StatusOK, gin.H{
  525. "success": false,
  526. "message": result.localErr.Error(),
  527. "time": 0.0,
  528. })
  529. return
  530. }
  531. tok := time.Now()
  532. milliseconds := tok.Sub(tik).Milliseconds()
  533. go channel.UpdateResponseTime(milliseconds)
  534. consumedTime := float64(milliseconds) / 1000.0
  535. if result.newAPIError != nil {
  536. c.JSON(http.StatusOK, gin.H{
  537. "success": false,
  538. "message": result.newAPIError.Error(),
  539. "time": consumedTime,
  540. })
  541. return
  542. }
  543. c.JSON(http.StatusOK, gin.H{
  544. "success": true,
  545. "message": "",
  546. "time": consumedTime,
  547. })
  548. }
  549. var testAllChannelsLock sync.Mutex
  550. var testAllChannelsRunning bool = false
  551. func testAllChannels(notify bool) error {
  552. testAllChannelsLock.Lock()
  553. if testAllChannelsRunning {
  554. testAllChannelsLock.Unlock()
  555. return errors.New("测试已在运行中")
  556. }
  557. testAllChannelsRunning = true
  558. testAllChannelsLock.Unlock()
  559. channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
  560. if getChannelErr != nil {
  561. return getChannelErr
  562. }
  563. var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
  564. if disableThreshold == 0 {
  565. disableThreshold = 10000000 // a impossible value
  566. }
  567. gopool.Go(func() {
  568. // 使用 defer 确保无论如何都会重置运行状态,防止死锁
  569. defer func() {
  570. testAllChannelsLock.Lock()
  571. testAllChannelsRunning = false
  572. testAllChannelsLock.Unlock()
  573. }()
  574. for _, channel := range channels {
  575. isChannelEnabled := channel.Status == common.ChannelStatusEnabled
  576. tik := time.Now()
  577. result := testChannel(channel, "", "")
  578. tok := time.Now()
  579. milliseconds := tok.Sub(tik).Milliseconds()
  580. shouldBanChannel := false
  581. newAPIError := result.newAPIError
  582. // request error disables the channel
  583. if newAPIError != nil {
  584. shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
  585. }
  586. // 当错误检查通过,才检查响应时间
  587. if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
  588. if milliseconds > disableThreshold {
  589. err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
  590. newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
  591. shouldBanChannel = true
  592. }
  593. }
  594. // disable channel
  595. if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
  596. processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
  597. }
  598. // enable channel
  599. if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
  600. service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
  601. }
  602. channel.UpdateResponseTime(milliseconds)
  603. time.Sleep(common.RequestInterval)
  604. }
  605. if notify {
  606. service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
  607. }
  608. })
  609. return nil
  610. }
  611. func TestAllChannels(c *gin.Context) {
  612. err := testAllChannels(true)
  613. if err != nil {
  614. common.ApiError(c, err)
  615. return
  616. }
  617. c.JSON(http.StatusOK, gin.H{
  618. "success": true,
  619. "message": "",
  620. })
  621. }
  622. var autoTestChannelsOnce sync.Once
  623. func AutomaticallyTestChannels() {
  624. // 只在Master节点定时测试渠道
  625. if !common.IsMasterNode {
  626. return
  627. }
  628. autoTestChannelsOnce.Do(func() {
  629. for {
  630. if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
  631. time.Sleep(1 * time.Minute)
  632. continue
  633. }
  634. for {
  635. frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
  636. time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute)
  637. common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency))
  638. common.SysLog("automatically testing all channels")
  639. _ = testAllChannels(false)
  640. common.SysLog("automatically channel test finished")
  641. if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
  642. break
  643. }
  644. }
  645. }
  646. })
  647. }