channel_upstream_update.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "regexp"
  6. "slices"
  7. "strings"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/model"
  15. "github.com/QuantumNous/new-api/relay/channel/gemini"
  16. "github.com/QuantumNous/new-api/relay/channel/ollama"
  17. "github.com/QuantumNous/new-api/service"
  18. "github.com/gin-gonic/gin"
  19. "github.com/samber/lo"
  20. )
  // Tunables for the background upstream model update task and its
  // notifications. All values are untyped integer constants.
  const (
  	// Default period between two task runs, in minutes.
  	channelUpstreamModelUpdateTaskDefaultIntervalMinutes = 30
  	// Number of channels loaded per database page.
  	channelUpstreamModelUpdateTaskBatchSize = 100
  	// Default minimum gap between two non-forced checks of one channel: 5 minutes.
  	channelUpstreamModelUpdateMinCheckIntervalSeconds = 5 * 60
  	// Window in which a notification with unchanged counts is suppressed: 24 hours.
  	channelUpstreamModelUpdateNotifySuppressWindowSeconds = 24 * 60 * 60
  	// Maximum number of entries listed in each notification section.
  	channelUpstreamModelUpdateNotifyMaxChannelDetails   = 8
  	channelUpstreamModelUpdateNotifyMaxModelDetails     = 12
  	channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10
  )
  // channelUpstreamModelUpdateSelectFields is the column list passed to Select
  // when paging enabled channels for upstream model checks, so only these
  // columns are loaded.
  var channelUpstreamModelUpdateSelectFields = []string{
  	"id", "name", "type", "key", "status",
  	"base_url", "models", "model_mapping",
  	"settings", "setting", "other",
  	"group", "priority", "weight", "tag",
  	"channel_info", "header_override",
  }
  49. var (
  50. channelUpstreamModelUpdateTaskOnce sync.Once
  51. channelUpstreamModelUpdateTaskRunning atomic.Bool
  52. channelUpstreamModelUpdateNotifyState = struct {
  53. sync.Mutex
  54. lastNotifiedAt int64
  55. lastChangedChannels int
  56. lastFailedChannels int
  57. }{}
  58. )
  // Payloads and summaries used by the upstream model update endpoints and the
  // background task.
  type (
  	// applyChannelUpstreamModelUpdatesRequest is the JSON body bound by both
  	// ApplyChannelUpstreamModelUpdates and DetectChannelUpstreamModelUpdates
  	// (the detect endpoint only reads ID).
  	applyChannelUpstreamModelUpdatesRequest struct {
  		ID           int      `json:"id"`
  		AddModels    []string `json:"add_models"`
  		RemoveModels []string `json:"remove_models"`
  		IgnoreModels []string `json:"ignore_models"`
  	}

  	// applyAllChannelUpstreamModelUpdatesResult is one per-channel row of the
  	// result list built by ApplyAllChannelUpstreamModelUpdates; its fields
  	// mirror the return values of applyChannelUpstreamModelUpdates.
  	applyAllChannelUpstreamModelUpdatesResult struct {
  		ChannelID             int      `json:"channel_id"`
  		ChannelName           string   `json:"channel_name"`
  		AddedModels           []string `json:"added_models"`
  		RemovedModels         []string `json:"removed_models"`
  		RemainingModels       []string `json:"remaining_models"`
  		RemainingRemoveModels []string `json:"remaining_remove_models"`
  	}

  	// detectChannelUpstreamModelUpdatesResult is the "data" payload returned by
  	// DetectChannelUpstreamModelUpdates.
  	detectChannelUpstreamModelUpdatesResult struct {
  		ChannelID       int      `json:"channel_id"`
  		ChannelName     string   `json:"channel_name"`
  		AddModels       []string `json:"add_models"`
  		RemoveModels    []string `json:"remove_models"`
  		LastCheckTime   int64    `json:"last_check_time"`
  		AutoAddedModels int      `json:"auto_added_models"`
  	}

  	// upstreamModelUpdateChannelSummary is one "changed channel" line of the
  	// background task notification.
  	upstreamModelUpdateChannelSummary struct {
  		ChannelName string
  		AddCount    int
  		RemoveCount int
  	}
  )
  // normalizeModelNames trims surrounding whitespace from every name, drops
  // names that end up empty and removes duplicates, keeping the first
  // occurrence of each name in its original position. The returned slice is
  // freshly allocated and never nil, so it encodes as [] rather than null.
  func normalizeModelNames(models []string) []string {
  	result := make([]string, 0, len(models))
  	seen := make(map[string]struct{}, len(models))
  	for _, raw := range models {
  		name := strings.TrimSpace(raw)
  		if name == "" {
  			continue
  		}
  		if _, duplicate := seen[name]; duplicate {
  			continue
  		}
  		seen[name] = struct{}{}
  		result = append(result, name)
  	}
  	return result
  }
  92. func mergeModelNames(base []string, appended []string) []string {
  93. merged := normalizeModelNames(base)
  94. seen := make(map[string]struct{}, len(merged))
  95. for _, model := range merged {
  96. seen[model] = struct{}{}
  97. }
  98. for _, model := range normalizeModelNames(appended) {
  99. if _, ok := seen[model]; ok {
  100. continue
  101. }
  102. seen[model] = struct{}{}
  103. merged = append(merged, model)
  104. }
  105. return merged
  106. }
  // subtractModelNames returns the normalized base list without any name that
  // also appears (after normalization) in removed, keeping base's order.
  func subtractModelNames(base []string, removed []string) []string {
  	drop := make(map[string]bool, len(removed))
  	for _, name := range normalizeModelNames(removed) {
  		drop[name] = true
  	}
  	// normalizeModelNames returns a fresh slice, so filtering it in place is safe.
  	return slices.DeleteFunc(normalizeModelNames(base), func(name string) bool {
  		return drop[name]
  	})
  }
  // intersectModelNames returns the normalized base names that also appear
  // (after normalization) in allowed, keeping base's order.
  func intersectModelNames(base []string, allowed []string) []string {
  	keep := make(map[string]bool, len(allowed))
  	for _, name := range normalizeModelNames(allowed) {
  		keep[name] = true
  	}
  	// normalizeModelNames returns a fresh slice, so filtering it in place is safe.
  	return slices.DeleteFunc(normalizeModelNames(base), func(name string) bool {
  		return !keep[name]
  	})
  }
  127. func applySelectedModelChanges(originModels []string, addModels []string, removeModels []string) []string {
  128. // Add wins when the same model appears in both selected lists.
  129. normalizedAdd := normalizeModelNames(addModels)
  130. normalizedRemove := subtractModelNames(normalizeModelNames(removeModels), normalizedAdd)
  131. return subtractModelNames(mergeModelNames(originModels, normalizedAdd), normalizedRemove)
  132. }
  // normalizeChannelModelMapping decodes the channel's model_mapping JSON into a
  // source -> target map with whitespace trimmed from both sides. Entries whose
  // source or target is blank after trimming are dropped.
  //
  // It returns nil when the channel or its mapping is missing, when the mapping
  // is blank or "{}", when common.UnmarshalJsonStr fails to decode it into
  // map[string]string, or when no entry survives trimming.
  //
  // NOTE(review): if two keys collapse to the same name after trimming, which
  // target wins depends on map iteration order.
  func normalizeChannelModelMapping(channel *model.Channel) map[string]string {
  	if channel == nil || channel.ModelMapping == nil {
  		return nil
  	}
  	raw := strings.TrimSpace(*channel.ModelMapping)
  	switch raw {
  	case "", "{}":
  		return nil
  	}
  	decoded := make(map[string]string)
  	if err := common.UnmarshalJsonStr(raw, &decoded); err != nil {
  		return nil
  	}
  	cleaned := make(map[string]string, len(decoded))
  	for from, to := range decoded {
  		from, to = strings.TrimSpace(from), strings.TrimSpace(to)
  		if from != "" && to != "" {
  			cleaned[from] = to
  		}
  	}
  	if len(cleaned) > 0 {
  		return cleaned
  	}
  	return nil
  }
  159. func collectPendingUpstreamModelChangesFromModels(
  160. localModels []string,
  161. upstreamModels []string,
  162. ignoredModels []string,
  163. modelMapping map[string]string,
  164. ) (pendingAddModels []string, pendingRemoveModels []string) {
  165. localSet := make(map[string]struct{})
  166. localModels = normalizeModelNames(localModels)
  167. upstreamModels = normalizeModelNames(upstreamModels)
  168. for _, modelName := range localModels {
  169. localSet[modelName] = struct{}{}
  170. }
  171. upstreamSet := make(map[string]struct{}, len(upstreamModels))
  172. for _, modelName := range upstreamModels {
  173. upstreamSet[modelName] = struct{}{}
  174. }
  175. normalizedIgnoredModels := normalizeModelNames(ignoredModels)
  176. redirectSourceSet := make(map[string]struct{}, len(modelMapping))
  177. redirectTargetSet := make(map[string]struct{}, len(modelMapping))
  178. for source, target := range modelMapping {
  179. redirectSourceSet[source] = struct{}{}
  180. redirectTargetSet[target] = struct{}{}
  181. }
  182. coveredUpstreamSet := make(map[string]struct{}, len(localSet)+len(redirectTargetSet))
  183. for modelName := range localSet {
  184. coveredUpstreamSet[modelName] = struct{}{}
  185. }
  186. for modelName := range redirectTargetSet {
  187. coveredUpstreamSet[modelName] = struct{}{}
  188. }
  189. pendingAdd := lo.Filter(upstreamModels, func(modelName string, _ int) bool {
  190. if _, ok := coveredUpstreamSet[modelName]; ok {
  191. return false
  192. }
  193. if lo.ContainsBy(normalizedIgnoredModels, func(ignoredModel string) bool {
  194. if regexBody, ok := strings.CutPrefix(ignoredModel, "regex:"); ok {
  195. matched, err := regexp.MatchString(strings.TrimSpace(regexBody), modelName)
  196. return err == nil && matched
  197. }
  198. return ignoredModel == modelName
  199. }) {
  200. return false
  201. }
  202. return true
  203. })
  204. pendingRemove := lo.Filter(localModels, func(modelName string, _ int) bool {
  205. // Redirect source models are virtual aliases and should not be removed
  206. // only because they are absent from upstream model list.
  207. if _, ok := redirectSourceSet[modelName]; ok {
  208. return false
  209. }
  210. _, ok := upstreamSet[modelName]
  211. return !ok
  212. })
  213. return normalizeModelNames(pendingAdd), normalizeModelNames(pendingRemove)
  214. }
  215. func collectPendingUpstreamModelChanges(channel *model.Channel, settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string, err error) {
  216. upstreamModels, err := fetchChannelUpstreamModelIDs(channel)
  217. if err != nil {
  218. return nil, nil, err
  219. }
  220. pendingAddModels, pendingRemoveModels = collectPendingUpstreamModelChangesFromModels(
  221. channel.GetModels(),
  222. upstreamModels,
  223. settings.UpstreamModelUpdateIgnoredModels,
  224. normalizeChannelModelMapping(channel),
  225. )
  226. return pendingAddModels, pendingRemoveModels, nil
  227. }
  228. func getUpstreamModelUpdateMinCheckIntervalSeconds() int64 {
  229. interval := int64(common.GetEnvOrDefault(
  230. "CHANNEL_UPSTREAM_MODEL_UPDATE_MIN_CHECK_INTERVAL_SECONDS",
  231. channelUpstreamModelUpdateMinCheckIntervalSeconds,
  232. ))
  233. if interval < 0 {
  234. return channelUpstreamModelUpdateMinCheckIntervalSeconds
  235. }
  236. return interval
  237. }
  // fetchChannelUpstreamModelIDs asks the channel's upstream provider which
  // models it currently serves and returns their normalized IDs.
  //
  // The base URL is the channel's own base URL when set, otherwise the default
  // for its type. Ollama and Gemini channels use their dedicated list APIs.
  // Every other type is queried with GET on an OpenAI-compatible ".../models"
  // endpoint. The endpoint path depends on the channel type and, for some types,
  // on a plan base URL registered in constant.ChannelSpecialBases.
  func fetchChannelUpstreamModelIDs(channel *model.Channel) ([]string, error) {
  	baseURL := constant.ChannelBaseURLs[channel.Type]
  	if channel.GetBaseURL() != "" {
  		baseURL = channel.GetBaseURL()
  	}

  	switch channel.Type {
  	case constant.ChannelTypeOllama:
  		// NOTE(review): unlike the other branches this uses the first line of the
  		// raw key instead of GetNextEnabledKey — confirm that is intended.
  		firstKey, _, _ := strings.Cut(channel.Key, "\n")
  		models, err := ollama.FetchOllamaModels(baseURL, strings.TrimSpace(firstKey))
  		if err != nil {
  			return nil, err
  		}
  		return normalizeModelNames(lo.Map(models, func(item ollama.OllamaModel, _ int) string {
  			return item.Name
  		})), nil
  	case constant.ChannelTypeGemini:
  		key, _, apiErr := channel.GetNextEnabledKey()
  		if apiErr != nil {
  			return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr)
  		}
  		models, err := gemini.FetchGeminiModels(baseURL, strings.TrimSpace(key), channel.GetSetting().Proxy)
  		if err != nil {
  			return nil, err
  		}
  		return normalizeModelNames(models), nil
  	}

  	var url string
  	switch channel.Type {
  	case constant.ChannelTypeAli:
  		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
  	case constant.ChannelTypeZhipu_v4:
  		if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  			url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  		} else {
  			url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
  		}
  	case constant.ChannelTypeVolcEngine:
  		if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  			url = fmt.Sprintf("%s/v1/models", plan.OpenAIBaseURL)
  		} else {
  			url = fmt.Sprintf("%s/v1/models", baseURL)
  		}
  	case constant.ChannelTypeMoonshot:
  		if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  			url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  		} else {
  			url = fmt.Sprintf("%s/v1/models", baseURL)
  		}
  	default:
  		url = fmt.Sprintf("%s/v1/models", baseURL)
  	}

  	key, _, apiErr := channel.GetNextEnabledKey()
  	if apiErr != nil {
  		return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr)
  	}
  	headers, err := buildFetchModelsHeaders(channel, strings.TrimSpace(key))
  	if err != nil {
  		return nil, err
  	}
  	body, err := GetResponseBody(http.MethodGet, url, channel, headers)
  	if err != nil {
  		return nil, err
  	}
  	var result OpenAIModelsResponse
  	if err := common.Unmarshal(body, &result); err != nil {
  		return nil, err
  	}
  	// Gemini channels returned above, so no "models/" prefix stripping is
  	// needed here (the former check for it was unreachable).
  	ids := lo.Map(result.Data, func(item OpenAIModel, _ int) string {
  		return item.ID
  	})
  	return normalizeModelNames(ids), nil
  }
  315. func updateChannelUpstreamModelSettings(channel *model.Channel, settings dto.ChannelOtherSettings, updateModels bool) error {
  316. channel.SetOtherSettings(settings)
  317. updates := map[string]interface{}{
  318. "settings": channel.OtherSettings,
  319. }
  320. if updateModels {
  321. updates["models"] = channel.Models
  322. }
  323. return model.DB.Model(&model.Channel{}).Where("id = ?", channel.Id).Updates(updates).Error
  324. }
  // checkAndPersistChannelUpstreamModelUpdates runs one upstream model check for
  // a channel, records the outcome in settings (mutated in place) and persists it.
  //
  // Throttling: unless force is set, the check is skipped, returning
  // (false, 0, nil), when the previous check was less than the configured
  // minimum interval ago.
  //
  // Fetch failure: the check time is still recorded and only settings are
  // persisted. The fetch error is returned, unless persisting fails, in which
  // case that error is returned instead.
  //
  // Success: when allowAutoApply is true and auto-sync is enabled for the
  // channel, newly detected models are appended to channel.Models, the
  // pending-add list is cleared and abilities are rebuilt. Otherwise the detected
  // models are stored as pending. The pending-remove list is always replaced.
  //
  // Returns whether channel.Models changed, how many models were auto-added,
  // and any error.
  func checkAndPersistChannelUpstreamModelUpdates(
  	channel *model.Channel,
  	settings *dto.ChannelOtherSettings,
  	force bool,
  	allowAutoApply bool,
  ) (modelsChanged bool, autoAdded int, err error) {
  	now := common.GetTimestamp()
  	if !force {
  		minInterval := getUpstreamModelUpdateMinCheckIntervalSeconds()
  		lastCheck := settings.UpstreamModelUpdateLastCheckTime
  		if lastCheck > 0 && now-lastCheck < minInterval {
  			return false, 0, nil
  		}
  	}

  	pendingAdd, pendingRemove, fetchErr := collectPendingUpstreamModelChanges(channel, *settings)
  	settings.UpstreamModelUpdateLastCheckTime = now
  	if fetchErr != nil {
  		if saveErr := updateChannelUpstreamModelSettings(channel, *settings, false); saveErr != nil {
  			return false, 0, saveErr
  		}
  		return false, 0, fetchErr
  	}

  	autoApply := allowAutoApply && settings.UpstreamModelUpdateAutoSyncEnabled && len(pendingAdd) > 0
  	if autoApply {
  		before := normalizeModelNames(channel.GetModels())
  		after := mergeModelNames(before, pendingAdd)
  		if grew := len(after) - len(before); grew > 0 {
  			channel.Models = strings.Join(after, ",")
  			autoAdded = grew
  			modelsChanged = true
  		}
  		// Everything detected was applied, so nothing is left pending.
  		settings.UpstreamModelUpdateLastDetectedModels = []string{}
  	} else {
  		settings.UpstreamModelUpdateLastDetectedModels = pendingAdd
  	}
  	settings.UpstreamModelUpdateLastRemovedModels = pendingRemove

  	if err = updateChannelUpstreamModelSettings(channel, *settings, modelsChanged); err != nil {
  		return false, autoAdded, err
  	}
  	if !modelsChanged {
  		return false, autoAdded, nil
  	}
  	if err = channel.UpdateAbilities(nil); err != nil {
  		return true, autoAdded, err
  	}
  	return true, autoAdded, nil
  }
  // refreshChannelRuntimeCache reloads the in-memory channel cache (when memory
  // caching is enabled) and then resets the proxy client cache, so that changed
  // channel models take effect.
  func refreshChannelRuntimeCache() {
  	if common.MemoryCacheEnabled {
  		initChannelCacheRecovering()
  	}
  	service.ResetProxyClientCache()
  }

  // initChannelCacheRecovering calls model.InitChannelCache and logs, instead of
  // propagating, any panic it raises.
  func initChannelCacheRecovering() {
  	defer func() {
  		if r := recover(); r != nil {
  			common.SysLog(fmt.Sprintf("InitChannelCache panic: %v", r))
  		}
  	}()
  	model.InitChannelCache()
  }
  383. func shouldSendUpstreamModelUpdateNotification(now int64, changedChannels int, failedChannels int) bool {
  384. if changedChannels <= 0 && failedChannels <= 0 {
  385. return true
  386. }
  387. channelUpstreamModelUpdateNotifyState.Lock()
  388. defer channelUpstreamModelUpdateNotifyState.Unlock()
  389. if channelUpstreamModelUpdateNotifyState.lastNotifiedAt > 0 &&
  390. now-channelUpstreamModelUpdateNotifyState.lastNotifiedAt < channelUpstreamModelUpdateNotifySuppressWindowSeconds &&
  391. channelUpstreamModelUpdateNotifyState.lastChangedChannels == changedChannels &&
  392. channelUpstreamModelUpdateNotifyState.lastFailedChannels == failedChannels {
  393. return false
  394. }
  395. channelUpstreamModelUpdateNotifyState.lastNotifiedAt = now
  396. channelUpstreamModelUpdateNotifyState.lastChangedChannels = changedChannels
  397. channelUpstreamModelUpdateNotifyState.lastFailedChannels = failedChannels
  398. return true
  399. }
  // buildUpstreamModelUpdateTaskNotificationContent renders the body of the
  // notification sent after a background task run. It starts with a one-line
  // summary of the counters, followed by optional sections for:
  //   - changed channels,
  //   - sample added model names,
  //   - sample removed model names,
  //   - failed channel IDs.
  // Each section is capped by its channelUpstreamModelUpdateNotifyMax* constant
  // and states how many entries were left out. Model samples are normalized
  // before display.
  func buildUpstreamModelUpdateTaskNotificationContent(
  	checkedChannels int,
  	changedChannels int,
  	detectedAddModels int,
  	detectedRemoveModels int,
  	autoAddedModels int,
  	failedChannelIDs []int,
  	channelSummaries []upstreamModelUpdateChannelSummary,
  	addModelSamples []string,
  	removeModelSamples []string,
  ) string {
  	var out strings.Builder
  	fmt.Fprintf(&out,
  		"上游模型巡检摘要:检测渠道 %d 个,发现变更 %d 个,新增 %d 个,删除 %d 个,自动同步新增 %d 个,失败 %d 个。",
  		checkedChannels,
  		changedChannels,
  		detectedAddModels,
  		detectedRemoveModels,
  		autoAddedModels,
  		len(failedChannelIDs),
  	)

  	if total := len(channelSummaries); total > 0 {
  		shown := min(total, channelUpstreamModelUpdateNotifyMaxChannelDetails)
  		fmt.Fprintf(&out, "\n\n变更渠道明细(展示 %d/%d):", shown, total)
  		for _, summary := range channelSummaries[:shown] {
  			fmt.Fprintf(&out, "\n- %s (+%d / -%d)", summary.ChannelName, summary.AddCount, summary.RemoveCount)
  		}
  		if total > shown {
  			fmt.Fprintf(&out, "\n- 其余 %d 个渠道已省略", total-shown)
  		}
  	}

  	writeCappedNotificationList(&out, "\n\n新增模型示例(展示 %d/%d):%s",
  		normalizeModelNames(addModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails)
  	writeCappedNotificationList(&out, "\n\n删除模型示例(展示 %d/%d):%s",
  		normalizeModelNames(removeModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails)

  	failedIDs := make([]string, 0, len(failedChannelIDs))
  	for _, id := range failedChannelIDs {
  		failedIDs = append(failedIDs, fmt.Sprintf("%d", id))
  	}
  	writeCappedNotificationList(&out, "\n\n失败渠道 ID(展示 %d/%d):%s",
  		failedIDs, channelUpstreamModelUpdateNotifyMaxFailedChannelIDs)

  	return out.String()
  }

  // writeCappedNotificationList appends one capped list section to out; it
  // writes nothing when items is empty. header is a format string that receives
  // the number of items shown, the total, and the shown items joined with ", ".
  // An omission note follows when items exceed limit.
  func writeCappedNotificationList(out *strings.Builder, header string, items []string, limit int) {
  	if len(items) == 0 {
  		return
  	}
  	shown := min(len(items), limit)
  	fmt.Fprintf(out, header, shown, len(items), strings.Join(items[:shown], ", "))
  	if hidden := len(items) - shown; hidden > 0 {
  		fmt.Fprintf(out, "(其余 %d 个已省略)", hidden)
  	}
  }
  // runChannelUpstreamModelUpdateTaskOnce performs one pass of the background
  // upstream model check over every enabled channel that has the check enabled.
  //
  // A call made while another pass is running returns immediately (guarded by
  // channelUpstreamModelUpdateTaskRunning). Channels are paged in ID order with
  // findEnabledChannelsAfterID, the same helper the bulk-apply handler uses.
  //
  // Each channel is checked non-forced with auto-apply allowed, and its pending
  // changes are aggregated. After the pass:
  //   - the runtime cache is refreshed if any channel's models changed,
  //   - a summary is logged,
  //   - watchers are notified, unless a notification with the same
  //     changed/failed counts was sent within the suppression window.
  func runChannelUpstreamModelUpdateTaskOnce() {
  	if !channelUpstreamModelUpdateTaskRunning.CompareAndSwap(false, true) {
  		return
  	}
  	defer channelUpstreamModelUpdateTaskRunning.Store(false)

  	var (
  		checkedChannels      int
  		changedChannels      int
  		detectedAddModels    int
  		detectedRemoveModels int
  		autoAddedModels      int
  		refreshNeeded        bool
  	)
  	failedChannelIDs := make([]int, 0)
  	channelSummaries := make([]upstreamModelUpdateChannelSummary, 0)
  	addModelSamples := make([]string, 0)
  	removeModelSamples := make([]string, 0)

  	lastID := 0
  	for {
  		channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize)
  		if err != nil {
  			common.SysLog(fmt.Sprintf("upstream model update task query failed: %v", err))
  			break
  		}
  		if len(channels) == 0 {
  			break
  		}
  		lastID = channels[len(channels)-1].Id
  		for _, channel := range channels {
  			if channel == nil {
  				continue
  			}
  			settings := channel.GetOtherSettings()
  			if !settings.UpstreamModelUpdateCheckEnabled {
  				continue
  			}
  			checkedChannels++
  			modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, false, true)
  			if err != nil {
  				failedChannelIDs = append(failedChannelIDs, channel.Id)
  				common.SysLog(fmt.Sprintf("upstream model update check failed: channel_id=%d channel_name=%s err=%v", channel.Id, channel.Name, err))
  				continue
  			}
  			currentAddModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  			currentRemoveModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  			// Auto-synced models are no longer pending but still count as detected additions.
  			currentAddCount := len(currentAddModels) + autoAdded
  			currentRemoveCount := len(currentRemoveModels)
  			detectedAddModels += currentAddCount
  			detectedRemoveModels += currentRemoveCount
  			if currentAddCount > 0 || currentRemoveCount > 0 {
  				changedChannels++
  				channelSummaries = append(channelSummaries, upstreamModelUpdateChannelSummary{
  					ChannelName: channel.Name,
  					AddCount:    currentAddCount,
  					RemoveCount: currentRemoveCount,
  				})
  			}
  			addModelSamples = mergeModelNames(addModelSamples, currentAddModels)
  			removeModelSamples = mergeModelNames(removeModelSamples, currentRemoveModels)
  			refreshNeeded = refreshNeeded || modelsChanged
  			autoAddedModels += autoAdded
  			if common.RequestInterval > 0 {
  				time.Sleep(common.RequestInterval)
  			}
  		}
  		if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  			break
  		}
  	}

  	if refreshNeeded {
  		refreshChannelRuntimeCache()
  	}

  	failedChannels := len(failedChannelIDs)
  	if checkedChannels > 0 || common.DebugEnabled {
  		common.SysLog(fmt.Sprintf(
  			"upstream model update task done: checked_channels=%d changed_channels=%d detected_add_models=%d detected_remove_models=%d failed_channels=%d auto_added_models=%d",
  			checkedChannels,
  			changedChannels,
  			detectedAddModels,
  			detectedRemoveModels,
  			failedChannels,
  			autoAddedModels,
  		))
  	}

  	if changedChannels == 0 && failedChannels == 0 {
  		return
  	}
  	if !shouldSendUpstreamModelUpdateNotification(common.GetTimestamp(), changedChannels, failedChannels) {
  		common.SysLog(fmt.Sprintf(
  			"upstream model update notification skipped in 24h window: changed_channels=%d failed_channels=%d",
  			changedChannels,
  			failedChannels,
  		))
  		return
  	}
  	service.NotifyUpstreamModelUpdateWatchers(
  		"上游模型巡检通知",
  		buildUpstreamModelUpdateTaskNotificationContent(
  			checkedChannels,
  			changedChannels,
  			detectedAddModels,
  			detectedRemoveModels,
  			autoAddedModels,
  			failedChannelIDs,
  			channelSummaries,
  			addModelSamples,
  			removeModelSamples,
  		),
  	)
  }
  // StartChannelUpstreamModelUpdateTask starts the periodic upstream model check
  // in a background goroutine. Only the first call has any effect.
  //
  // It does nothing on non-master nodes or when
  // CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED is false. The period comes from
  // CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_INTERVAL_MINUTES; values below 1 fall back
  // to the default. The first pass runs immediately, then once per period. The
  // goroutine has no stop mechanism and runs for the life of the process.
  func StartChannelUpstreamModelUpdateTask() {
  	channelUpstreamModelUpdateTaskOnce.Do(func() {
  		if !common.IsMasterNode {
  			return
  		}
  		enabled := common.GetEnvOrDefaultBool("CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED", true)
  		if !enabled {
  			common.SysLog("upstream model update task disabled by CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED")
  			return
  		}

  		minutes := common.GetEnvOrDefault(
  			"CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_INTERVAL_MINUTES",
  			channelUpstreamModelUpdateTaskDefaultIntervalMinutes,
  		)
  		if minutes < 1 {
  			minutes = channelUpstreamModelUpdateTaskDefaultIntervalMinutes
  		}
  		period := time.Duration(minutes) * time.Minute

  		go func(period time.Duration) {
  			common.SysLog(fmt.Sprintf("upstream model update task started: interval=%s", period))
  			runChannelUpstreamModelUpdateTaskOnce()
  			ticker := time.NewTicker(period)
  			defer ticker.Stop()
  			for range ticker.C {
  				runChannelUpstreamModelUpdateTaskOnce()
  			}
  		}(period)
  	})
  }
  // ApplyChannelUpstreamModelUpdates handles a request to act on one channel's
  // pending upstream model changes:
  //   - selected pending additions are added to the channel,
  //   - selected pending removals are removed,
  //   - selected pending additions can be moved to the channel's ignore list.
  // Selections that are not currently pending are ignored, and a model selected
  // for both add and ignore is added. The response echoes what was applied,
  // what remains pending, and the channel's resulting models and settings.
  func ApplyChannelUpstreamModelUpdates(c *gin.Context) {
  	var req applyChannelUpstreamModelUpdatesRequest
  	if err := c.ShouldBindJSON(&req); err != nil {
  		common.ApiError(c, err)
  		return
  	}
  	if req.ID <= 0 {
  		c.JSON(http.StatusOK, gin.H{
  			"success": false,
  			"message": "invalid channel id",
  		})
  		return
  	}
  	channel, err := model.GetChannelById(req.ID, true)
  	if err != nil {
  		common.ApiError(c, err)
  		return
  	}
  	beforeSettings := channel.GetOtherSettings()
  	ignoredModels := intersectModelNames(req.IgnoreModels, beforeSettings.UpstreamModelUpdateLastDetectedModels)
  	addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates(
  		channel,
  		req.AddModels,
  		req.IgnoreModels,
  		req.RemoveModels,
  	)
  	// modelsChanged can be true together with an error when the models were
  	// saved but rebuilding abilities failed. The runtime cache must still pick
  	// up the saved models, so refresh before handling the error.
  	if modelsChanged {
  		refreshChannelRuntimeCache()
  	}
  	if err != nil {
  		common.ApiError(c, err)
  		return
  	}
  	// Add wins over ignore: a model selected for both was added and was not put
  	// on the ignore list, so do not report it as ignored.
  	ignoredModels = subtractModelNames(ignoredModels, addedModels)
  	c.JSON(http.StatusOK, gin.H{
  		"success": true,
  		"message": "",
  		"data": gin.H{
  			"id":                      channel.Id,
  			"added_models":            addedModels,
  			"removed_models":          removedModels,
  			"ignored_models":          ignoredModels,
  			"remaining_models":        remainingModels,
  			"remaining_remove_models": remainingRemoveModels,
  			"models":                  channel.Models,
  			"settings":                channel.OtherSettings,
  		},
  	})
  }
  // DetectChannelUpstreamModelUpdates handles a request to run an immediate,
  // forced upstream model check for one channel without auto-applying anything.
  // It responds with the pending additions and removals that the check recorded.
  func DetectChannelUpstreamModelUpdates(c *gin.Context) {
  	var req applyChannelUpstreamModelUpdatesRequest
  	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
  		common.ApiError(c, bindErr)
  		return
  	}
  	if req.ID <= 0 {
  		c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid channel id"})
  		return
  	}
  	channel, err := model.GetChannelById(req.ID, true)
  	if err != nil {
  		common.ApiError(c, err)
  		return
  	}

  	settings := channel.GetOtherSettings()
  	// force=true bypasses the minimum check interval; allowAutoApply=false keeps
  	// detected models pending even when auto-sync is enabled.
  	modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, true, false)
  	if err != nil {
  		common.ApiError(c, err)
  		return
  	}
  	if modelsChanged {
  		refreshChannelRuntimeCache()
  	}

  	result := detectChannelUpstreamModelUpdatesResult{
  		ChannelID:       channel.Id,
  		ChannelName:     channel.Name,
  		AddModels:       normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels),
  		RemoveModels:    normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels),
  		LastCheckTime:   settings.UpstreamModelUpdateLastCheckTime,
  		AutoAddedModels: autoAdded,
  	}
  	c.JSON(http.StatusOK, gin.H{
  		"success": true,
  		"message": "",
  		"data":    result,
  	})
  }
  // applyChannelUpstreamModelUpdates applies a selection to a channel's pending
  // upstream model changes and persists the result.
  //
  // Only pending entries count: addModelsInput and ignoreModelsInput are
  // restricted to the pending additions, and removeModelsInput to the pending
  // removals. A model that is both added and removed is added. Ignored models
  // are merged into the channel's ignore list, and added models are taken off
  // it. The pending lists keep whatever was not acted on, and the last-check
  // time is set to now.
  //
  // Returns:
  //   - the models actually added and removed,
  //   - the pending additions and removals that remain,
  //   - whether channel.Models changed,
  //   - any error.
  // If saving fails, the result lists are nil. If abilities fail to rebuild
  // after saving, the results are returned with modelsChanged=true alongside
  // the error.
  func applyChannelUpstreamModelUpdates(
  	channel *model.Channel,
  	addModelsInput []string,
  	ignoreModelsInput []string,
  	removeModelsInput []string,
  ) (
  	addedModels []string,
  	removedModels []string,
  	remainingModels []string,
  	remainingRemoveModels []string,
  	modelsChanged bool,
  	err error,
  ) {
  	settings := channel.GetOtherSettings()
  	pendingAdd := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  	pendingRemove := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)

  	addedModels = intersectModelNames(addModelsInput, pendingAdd)
  	ignored := intersectModelNames(ignoreModelsInput, pendingAdd)
  	removedModels = subtractModelNames(intersectModelNames(removeModelsInput, pendingRemove), addedModels)

  	currentModels := normalizeModelNames(channel.GetModels())
  	updatedModels := applySelectedModelChanges(currentModels, addedModels, removedModels)
  	modelsChanged = !slices.Equal(currentModels, updatedModels)
  	if modelsChanged {
  		channel.Models = strings.Join(updatedModels, ",")
  	}

  	ignoreList := mergeModelNames(settings.UpstreamModelUpdateIgnoredModels, ignored)
  	if len(addedModels) > 0 {
  		ignoreList = subtractModelNames(ignoreList, addedModels)
  	}
  	settings.UpstreamModelUpdateIgnoredModels = ignoreList

  	// mergeModelNames builds a fresh slice, so addedModels' backing array is
  	// never written through.
  	remainingModels = subtractModelNames(pendingAdd, mergeModelNames(addedModels, ignored))
  	remainingRemoveModels = subtractModelNames(pendingRemove, removedModels)
  	settings.UpstreamModelUpdateLastDetectedModels = remainingModels
  	settings.UpstreamModelUpdateLastRemovedModels = remainingRemoveModels
  	settings.UpstreamModelUpdateLastCheckTime = common.GetTimestamp()

  	if saveErr := updateChannelUpstreamModelSettings(channel, settings, modelsChanged); saveErr != nil {
  		return nil, nil, nil, nil, false, saveErr
  	}
  	if !modelsChanged {
  		return addedModels, removedModels, remainingModels, remainingRemoveModels, false, nil
  	}
  	if abilityErr := channel.UpdateAbilities(nil); abilityErr != nil {
  		return addedModels, removedModels, remainingModels, remainingRemoveModels, true, abilityErr
  	}
  	return addedModels, removedModels, remainingModels, remainingRemoveModels, true, nil
  }
  754. func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string) {
  755. return normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels), normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  756. }
  757. func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) {
  758. var channels []*model.Channel
  759. query := model.DB.
  760. Select(channelUpstreamModelUpdateSelectFields).
  761. Where("status = ?", common.ChannelStatusEnabled).
  762. Order("id asc").
  763. Limit(batchSize)
  764. if lastID > 0 {
  765. query = query.Where("id > ?", lastID)
  766. }
  767. return channels, query.Find(&channels).Error
  768. }
  769. func ApplyAllChannelUpstreamModelUpdates(c *gin.Context) {
  770. results := make([]applyAllChannelUpstreamModelUpdatesResult, 0)
  771. failed := make([]int, 0)
  772. refreshNeeded := false
  773. addedModelCount := 0
  774. removedModelCount := 0
  775. lastID := 0
  776. for {
  777. channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize)
  778. if err != nil {
  779. common.ApiError(c, err)
  780. return
  781. }
  782. if len(channels) == 0 {
  783. break
  784. }
  785. lastID = channels[len(channels)-1].Id
  786. for _, channel := range channels {
  787. if channel == nil {
  788. continue
  789. }
  790. settings := channel.GetOtherSettings()
  791. if !settings.UpstreamModelUpdateCheckEnabled {
  792. continue
  793. }
  794. pendingAddModels, pendingRemoveModels := collectPendingApplyUpstreamModelChanges(settings)
  795. if len(pendingAddModels) == 0 && len(pendingRemoveModels) == 0 {
  796. continue
  797. }
  798. addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates(
  799. channel,
  800. pendingAddModels,
  801. nil,
  802. pendingRemoveModels,
  803. )
  804. if err != nil {
  805. failed = append(failed, channel.Id)
  806. continue
  807. }
  808. if modelsChanged {
  809. refreshNeeded = true
  810. }
  811. addedModelCount += len(addedModels)
  812. removedModelCount += len(removedModels)
  813. results = append(results, applyAllChannelUpstreamModelUpdatesResult{
  814. ChannelID: channel.Id,
  815. ChannelName: channel.Name,
  816. AddedModels: addedModels,
  817. RemovedModels: removedModels,
  818. RemainingModels: remainingModels,
  819. RemainingRemoveModels: remainingRemoveModels,
  820. })
  821. }
  822. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  823. break
  824. }
  825. }
  826. if refreshNeeded {
  827. refreshChannelRuntimeCache()
  828. }
  829. c.JSON(http.StatusOK, gin.H{
  830. "success": true,
  831. "message": "",
  832. "data": gin.H{
  833. "processed_channels": len(results),
  834. "added_models": addedModelCount,
  835. "removed_models": removedModelCount,
  836. "failed_channel_ids": failed,
  837. "results": results,
  838. },
  839. })
  840. }
  841. func DetectAllChannelUpstreamModelUpdates(c *gin.Context) {
  842. results := make([]detectChannelUpstreamModelUpdatesResult, 0)
  843. failed := make([]int, 0)
  844. detectedAddCount := 0
  845. detectedRemoveCount := 0
  846. refreshNeeded := false
  847. lastID := 0
  848. for {
  849. channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize)
  850. if err != nil {
  851. common.ApiError(c, err)
  852. return
  853. }
  854. if len(channels) == 0 {
  855. break
  856. }
  857. lastID = channels[len(channels)-1].Id
  858. for _, channel := range channels {
  859. if channel == nil {
  860. continue
  861. }
  862. settings := channel.GetOtherSettings()
  863. if !settings.UpstreamModelUpdateCheckEnabled {
  864. continue
  865. }
  866. modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, true, false)
  867. if err != nil {
  868. failed = append(failed, channel.Id)
  869. continue
  870. }
  871. if modelsChanged {
  872. refreshNeeded = true
  873. }
  874. addModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  875. removeModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  876. detectedAddCount += len(addModels)
  877. detectedRemoveCount += len(removeModels)
  878. results = append(results, detectChannelUpstreamModelUpdatesResult{
  879. ChannelID: channel.Id,
  880. ChannelName: channel.Name,
  881. AddModels: addModels,
  882. RemoveModels: removeModels,
  883. LastCheckTime: settings.UpstreamModelUpdateLastCheckTime,
  884. AutoAddedModels: autoAdded,
  885. })
  886. }
  887. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  888. break
  889. }
  890. }
  891. if refreshNeeded {
  892. refreshChannelRuntimeCache()
  893. }
  894. c.JSON(http.StatusOK, gin.H{
  895. "success": true,
  896. "message": "",
  897. "data": gin.H{
  898. "processed_channels": len(results),
  899. "failed_channel_ids": failed,
  900. "detected_add_models": detectedAddCount,
  901. "detected_remove_models": detectedRemoveCount,
  902. "channel_detected_results": results,
  903. },
  904. })
  905. }