channel_upstream_update.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "slices"
  6. "strings"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel/gemini"
  15. "github.com/QuantumNous/new-api/relay/channel/ollama"
  16. "github.com/QuantumNous/new-api/service"
  17. "github.com/gin-gonic/gin"
  18. "github.com/samber/lo"
  19. )
  // Tunables for the periodic upstream model update task. The task interval and
  // the per-channel minimum check interval can be overridden through environment
  // variables (see StartChannelUpstreamModelUpdateTask and
  // getUpstreamModelUpdateMinCheckIntervalSeconds).
  const (
  	// Scheduling of the background task.
  	channelUpstreamModelUpdateTaskDefaultIntervalMinutes = 30
  	channelUpstreamModelUpdateTaskBatchSize              = 100

  	// Throttling windows, in seconds: 5 minutes between checks of one channel,
  	// 24 hours during which an identical notification summary is suppressed.
  	channelUpstreamModelUpdateMinCheckIntervalSeconds     = 5 * 60
  	channelUpstreamModelUpdateNotifySuppressWindowSeconds = 24 * 60 * 60

  	// Upper bounds on how many entries each notification section lists.
  	channelUpstreamModelUpdateNotifyMaxChannelDetails   = 8
  	channelUpstreamModelUpdateNotifyMaxModelDetails     = 12
  	channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10
  )
  // channelUpstreamModelUpdateTaskOnce guards StartChannelUpstreamModelUpdateTask so
  // its setup runs at most once per process.
  var channelUpstreamModelUpdateTaskOnce sync.Once

  // channelUpstreamModelUpdateTaskRunning is true while a task pass is executing;
  // a pass that finds it already set returns without doing any work.
  var channelUpstreamModelUpdateTaskRunning atomic.Bool

  // channelUpstreamModelUpdateNotifyState remembers the most recently sent
  // notification summary so identical summaries can be suppressed for a while.
  // The fields are accessed while holding the embedded mutex.
  var channelUpstreamModelUpdateNotifyState struct {
  	sync.Mutex
  	lastNotifiedAt      int64
  	lastChangedChannels int
  	lastFailedChannels  int
  }
  type (
  	// applyChannelUpstreamModelUpdatesRequest is the JSON body bound by both the
  	// apply and the detect handlers; the detect handler only reads ID.
  	applyChannelUpstreamModelUpdatesRequest struct {
  		ID           int      `json:"id"`
  		AddModels    []string `json:"add_models"`
  		RemoveModels []string `json:"remove_models"`
  		IgnoreModels []string `json:"ignore_models"`
  	}

  	// applyAllChannelUpstreamModelUpdatesResult describes what the bulk apply
  	// handler did for one channel.
  	applyAllChannelUpstreamModelUpdatesResult struct {
  		ChannelID             int      `json:"channel_id"`
  		ChannelName           string   `json:"channel_name"`
  		AddedModels           []string `json:"added_models"`
  		RemovedModels         []string `json:"removed_models"`
  		RemainingModels       []string `json:"remaining_models"`
  		RemainingRemoveModels []string `json:"remaining_remove_models"`
  	}

  	// detectChannelUpstreamModelUpdatesResult is the response payload of
  	// DetectChannelUpstreamModelUpdates.
  	detectChannelUpstreamModelUpdatesResult struct {
  		ChannelID       int      `json:"channel_id"`
  		ChannelName     string   `json:"channel_name"`
  		AddModels       []string `json:"add_models"`
  		RemoveModels    []string `json:"remove_models"`
  		LastCheckTime   int64    `json:"last_check_time"`
  		AutoAddedModels int      `json:"auto_added_models"`
  	}

  	// upstreamModelUpdateChannelSummary is one per-channel line of the task
  	// notification: how many models were added and removed.
  	upstreamModelUpdateChannelSummary struct {
  		ChannelName string
  		AddCount    int
  		RemoveCount int
  	}
  )
  // normalizeModelNames trims surrounding whitespace from each entry, drops
  // entries that end up empty, and removes duplicates while keeping the order in
  // which each name first appears. The returned slice is freshly allocated and
  // never nil, so it encodes as a JSON array even when empty.
  func normalizeModelNames(models []string) []string {
  	result := make([]string, 0, len(models))
  	seen := make(map[string]struct{}, len(models))
  	for _, raw := range models {
  		name := strings.TrimSpace(raw)
  		if name == "" {
  			continue
  		}
  		if _, duplicate := seen[name]; duplicate {
  			continue
  		}
  		seen[name] = struct{}{}
  		result = append(result, name)
  	}
  	return result
  }
  72. func mergeModelNames(base []string, appended []string) []string {
  73. merged := normalizeModelNames(base)
  74. seen := make(map[string]struct{}, len(merged))
  75. for _, model := range merged {
  76. seen[model] = struct{}{}
  77. }
  78. for _, model := range normalizeModelNames(appended) {
  79. if _, ok := seen[model]; ok {
  80. continue
  81. }
  82. seen[model] = struct{}{}
  83. merged = append(merged, model)
  84. }
  85. return merged
  86. }
  87. func subtractModelNames(base []string, removed []string) []string {
  88. removeSet := make(map[string]struct{}, len(removed))
  89. for _, model := range normalizeModelNames(removed) {
  90. removeSet[model] = struct{}{}
  91. }
  92. return lo.Filter(normalizeModelNames(base), func(model string, _ int) bool {
  93. _, ok := removeSet[model]
  94. return !ok
  95. })
  96. }
  97. func intersectModelNames(base []string, allowed []string) []string {
  98. allowedSet := make(map[string]struct{}, len(allowed))
  99. for _, model := range normalizeModelNames(allowed) {
  100. allowedSet[model] = struct{}{}
  101. }
  102. return lo.Filter(normalizeModelNames(base), func(model string, _ int) bool {
  103. _, ok := allowedSet[model]
  104. return ok
  105. })
  106. }
  107. func applySelectedModelChanges(originModels []string, addModels []string, removeModels []string) []string {
  108. // Add wins when the same model appears in both selected lists.
  109. normalizedAdd := normalizeModelNames(addModels)
  110. normalizedRemove := subtractModelNames(normalizeModelNames(removeModels), normalizedAdd)
  111. return subtractModelNames(mergeModelNames(originModels, normalizedAdd), normalizedRemove)
  112. }
  113. func normalizeChannelModelMapping(channel *model.Channel) map[string]string {
  114. if channel == nil || channel.ModelMapping == nil {
  115. return nil
  116. }
  117. rawMapping := strings.TrimSpace(*channel.ModelMapping)
  118. if rawMapping == "" || rawMapping == "{}" {
  119. return nil
  120. }
  121. parsed := make(map[string]string)
  122. if err := common.UnmarshalJsonStr(rawMapping, &parsed); err != nil {
  123. return nil
  124. }
  125. normalized := make(map[string]string, len(parsed))
  126. for source, target := range parsed {
  127. normalizedSource := strings.TrimSpace(source)
  128. normalizedTarget := strings.TrimSpace(target)
  129. if normalizedSource == "" || normalizedTarget == "" {
  130. continue
  131. }
  132. normalized[normalizedSource] = normalizedTarget
  133. }
  134. if len(normalized) == 0 {
  135. return nil
  136. }
  137. return normalized
  138. }
  139. func collectPendingUpstreamModelChangesFromModels(
  140. localModels []string,
  141. upstreamModels []string,
  142. ignoredModels []string,
  143. modelMapping map[string]string,
  144. ) (pendingAddModels []string, pendingRemoveModels []string) {
  145. localSet := make(map[string]struct{})
  146. localModels = normalizeModelNames(localModels)
  147. upstreamModels = normalizeModelNames(upstreamModels)
  148. for _, modelName := range localModels {
  149. localSet[modelName] = struct{}{}
  150. }
  151. upstreamSet := make(map[string]struct{}, len(upstreamModels))
  152. for _, modelName := range upstreamModels {
  153. upstreamSet[modelName] = struct{}{}
  154. }
  155. ignoredSet := make(map[string]struct{})
  156. for _, modelName := range normalizeModelNames(ignoredModels) {
  157. ignoredSet[modelName] = struct{}{}
  158. }
  159. redirectSourceSet := make(map[string]struct{}, len(modelMapping))
  160. redirectTargetSet := make(map[string]struct{}, len(modelMapping))
  161. for source, target := range modelMapping {
  162. redirectSourceSet[source] = struct{}{}
  163. redirectTargetSet[target] = struct{}{}
  164. }
  165. coveredUpstreamSet := make(map[string]struct{}, len(localSet)+len(redirectTargetSet))
  166. for modelName := range localSet {
  167. coveredUpstreamSet[modelName] = struct{}{}
  168. }
  169. for modelName := range redirectTargetSet {
  170. coveredUpstreamSet[modelName] = struct{}{}
  171. }
  172. pendingAdd := lo.Filter(upstreamModels, func(modelName string, _ int) bool {
  173. if _, ok := coveredUpstreamSet[modelName]; ok {
  174. return false
  175. }
  176. if _, ok := ignoredSet[modelName]; ok {
  177. return false
  178. }
  179. return true
  180. })
  181. pendingRemove := lo.Filter(localModels, func(modelName string, _ int) bool {
  182. // Redirect source models are virtual aliases and should not be removed
  183. // only because they are absent from upstream model list.
  184. if _, ok := redirectSourceSet[modelName]; ok {
  185. return false
  186. }
  187. _, ok := upstreamSet[modelName]
  188. return !ok
  189. })
  190. return normalizeModelNames(pendingAdd), normalizeModelNames(pendingRemove)
  191. }
  192. func collectPendingUpstreamModelChanges(channel *model.Channel, settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string, err error) {
  193. upstreamModels, err := fetchChannelUpstreamModelIDs(channel)
  194. if err != nil {
  195. return nil, nil, err
  196. }
  197. pendingAddModels, pendingRemoveModels = collectPendingUpstreamModelChangesFromModels(
  198. channel.GetModels(),
  199. upstreamModels,
  200. settings.UpstreamModelUpdateIgnoredModels,
  201. normalizeChannelModelMapping(channel),
  202. )
  203. return pendingAddModels, pendingRemoveModels, nil
  204. }
  205. func getUpstreamModelUpdateMinCheckIntervalSeconds() int64 {
  206. interval := int64(common.GetEnvOrDefault(
  207. "CHANNEL_UPSTREAM_MODEL_UPDATE_MIN_CHECK_INTERVAL_SECONDS",
  208. channelUpstreamModelUpdateMinCheckIntervalSeconds,
  209. ))
  210. if interval < 0 {
  211. return channelUpstreamModelUpdateMinCheckIntervalSeconds
  212. }
  213. return interval
  214. }
  215. func fetchChannelUpstreamModelIDs(channel *model.Channel) ([]string, error) {
  216. baseURL := constant.ChannelBaseURLs[channel.Type]
  217. if channel.GetBaseURL() != "" {
  218. baseURL = channel.GetBaseURL()
  219. }
  220. if channel.Type == constant.ChannelTypeOllama {
  221. key := strings.TrimSpace(strings.Split(channel.Key, "\n")[0])
  222. models, err := ollama.FetchOllamaModels(baseURL, key)
  223. if err != nil {
  224. return nil, err
  225. }
  226. return normalizeModelNames(lo.Map(models, func(item ollama.OllamaModel, _ int) string {
  227. return item.Name
  228. })), nil
  229. }
  230. if channel.Type == constant.ChannelTypeGemini {
  231. key, _, apiErr := channel.GetNextEnabledKey()
  232. if apiErr != nil {
  233. return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr)
  234. }
  235. key = strings.TrimSpace(key)
  236. models, err := gemini.FetchGeminiModels(baseURL, key, channel.GetSetting().Proxy)
  237. if err != nil {
  238. return nil, err
  239. }
  240. return normalizeModelNames(models), nil
  241. }
  242. var url string
  243. switch channel.Type {
  244. case constant.ChannelTypeAli:
  245. url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
  246. case constant.ChannelTypeZhipu_v4:
  247. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  248. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  249. } else {
  250. url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
  251. }
  252. case constant.ChannelTypeVolcEngine:
  253. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  254. url = fmt.Sprintf("%s/v1/models", plan.OpenAIBaseURL)
  255. } else {
  256. url = fmt.Sprintf("%s/v1/models", baseURL)
  257. }
  258. case constant.ChannelTypeMoonshot:
  259. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  260. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  261. } else {
  262. url = fmt.Sprintf("%s/v1/models", baseURL)
  263. }
  264. default:
  265. url = fmt.Sprintf("%s/v1/models", baseURL)
  266. }
  267. key, _, apiErr := channel.GetNextEnabledKey()
  268. if apiErr != nil {
  269. return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr)
  270. }
  271. key = strings.TrimSpace(key)
  272. headers, err := buildFetchModelsHeaders(channel, key)
  273. if err != nil {
  274. return nil, err
  275. }
  276. body, err := GetResponseBody(http.MethodGet, url, channel, headers)
  277. if err != nil {
  278. return nil, err
  279. }
  280. var result OpenAIModelsResponse
  281. if err := common.Unmarshal(body, &result); err != nil {
  282. return nil, err
  283. }
  284. ids := lo.Map(result.Data, func(item OpenAIModel, _ int) string {
  285. if channel.Type == constant.ChannelTypeGemini {
  286. return strings.TrimPrefix(item.ID, "models/")
  287. }
  288. return item.ID
  289. })
  290. return normalizeModelNames(ids), nil
  291. }
  292. func updateChannelUpstreamModelSettings(channel *model.Channel, settings dto.ChannelOtherSettings, updateModels bool) error {
  293. channel.SetOtherSettings(settings)
  294. updates := map[string]interface{}{
  295. "settings": channel.OtherSettings,
  296. }
  297. if updateModels {
  298. updates["models"] = channel.Models
  299. }
  300. return model.DB.Model(&model.Channel{}).Where("id = ?", channel.Id).Updates(updates).Error
  301. }
  302. func checkAndPersistChannelUpstreamModelUpdates(
  303. channel *model.Channel,
  304. settings *dto.ChannelOtherSettings,
  305. force bool,
  306. allowAutoApply bool,
  307. ) (modelsChanged bool, autoAdded int, err error) {
  308. now := common.GetTimestamp()
  309. if !force {
  310. minInterval := getUpstreamModelUpdateMinCheckIntervalSeconds()
  311. if settings.UpstreamModelUpdateLastCheckTime > 0 &&
  312. now-settings.UpstreamModelUpdateLastCheckTime < minInterval {
  313. return false, 0, nil
  314. }
  315. }
  316. pendingAddModels, pendingRemoveModels, fetchErr := collectPendingUpstreamModelChanges(channel, *settings)
  317. settings.UpstreamModelUpdateLastCheckTime = now
  318. if fetchErr != nil {
  319. if err = updateChannelUpstreamModelSettings(channel, *settings, false); err != nil {
  320. return false, 0, err
  321. }
  322. return false, 0, fetchErr
  323. }
  324. if allowAutoApply && settings.UpstreamModelUpdateAutoSyncEnabled && len(pendingAddModels) > 0 {
  325. originModels := normalizeModelNames(channel.GetModels())
  326. mergedModels := mergeModelNames(originModels, pendingAddModels)
  327. if len(mergedModels) > len(originModels) {
  328. channel.Models = strings.Join(mergedModels, ",")
  329. autoAdded = len(mergedModels) - len(originModels)
  330. modelsChanged = true
  331. }
  332. settings.UpstreamModelUpdateLastDetectedModels = []string{}
  333. } else {
  334. settings.UpstreamModelUpdateLastDetectedModels = pendingAddModels
  335. }
  336. settings.UpstreamModelUpdateLastRemovedModels = pendingRemoveModels
  337. if err = updateChannelUpstreamModelSettings(channel, *settings, modelsChanged); err != nil {
  338. return false, autoAdded, err
  339. }
  340. if modelsChanged {
  341. if err = channel.UpdateAbilities(nil); err != nil {
  342. return true, autoAdded, err
  343. }
  344. }
  345. return modelsChanged, autoAdded, nil
  346. }
  347. func refreshChannelRuntimeCache() {
  348. if common.MemoryCacheEnabled {
  349. func() {
  350. defer func() {
  351. if r := recover(); r != nil {
  352. common.SysLog(fmt.Sprintf("InitChannelCache panic: %v", r))
  353. }
  354. }()
  355. model.InitChannelCache()
  356. }()
  357. }
  358. service.ResetProxyClientCache()
  359. }
  360. func shouldSendUpstreamModelUpdateNotification(now int64, changedChannels int, failedChannels int) bool {
  361. if changedChannels <= 0 && failedChannels <= 0 {
  362. return true
  363. }
  364. channelUpstreamModelUpdateNotifyState.Lock()
  365. defer channelUpstreamModelUpdateNotifyState.Unlock()
  366. if channelUpstreamModelUpdateNotifyState.lastNotifiedAt > 0 &&
  367. now-channelUpstreamModelUpdateNotifyState.lastNotifiedAt < channelUpstreamModelUpdateNotifySuppressWindowSeconds &&
  368. channelUpstreamModelUpdateNotifyState.lastChangedChannels == changedChannels &&
  369. channelUpstreamModelUpdateNotifyState.lastFailedChannels == failedChannels {
  370. return false
  371. }
  372. channelUpstreamModelUpdateNotifyState.lastNotifiedAt = now
  373. channelUpstreamModelUpdateNotifyState.lastChangedChannels = changedChannels
  374. channelUpstreamModelUpdateNotifyState.lastFailedChannels = failedChannels
  375. return true
  376. }
  377. func buildUpstreamModelUpdateTaskNotificationContent(
  378. checkedChannels int,
  379. changedChannels int,
  380. detectedAddModels int,
  381. detectedRemoveModels int,
  382. autoAddedModels int,
  383. failedChannelIDs []int,
  384. channelSummaries []upstreamModelUpdateChannelSummary,
  385. addModelSamples []string,
  386. removeModelSamples []string,
  387. ) string {
  388. var builder strings.Builder
  389. failedChannels := len(failedChannelIDs)
  390. builder.WriteString(fmt.Sprintf(
  391. "上游模型巡检摘要:检测渠道 %d 个,发现变更 %d 个,新增 %d 个,删除 %d 个,自动同步新增 %d 个,失败 %d 个。",
  392. checkedChannels,
  393. changedChannels,
  394. detectedAddModels,
  395. detectedRemoveModels,
  396. autoAddedModels,
  397. failedChannels,
  398. ))
  399. if len(channelSummaries) > 0 {
  400. displayCount := min(len(channelSummaries), channelUpstreamModelUpdateNotifyMaxChannelDetails)
  401. builder.WriteString(fmt.Sprintf("\n\n变更渠道明细(展示 %d/%d):", displayCount, len(channelSummaries)))
  402. for _, summary := range channelSummaries[:displayCount] {
  403. builder.WriteString(fmt.Sprintf("\n- %s (+%d / -%d)", summary.ChannelName, summary.AddCount, summary.RemoveCount))
  404. }
  405. if len(channelSummaries) > displayCount {
  406. builder.WriteString(fmt.Sprintf("\n- 其余 %d 个渠道已省略", len(channelSummaries)-displayCount))
  407. }
  408. }
  409. normalizedAddModelSamples := normalizeModelNames(addModelSamples)
  410. if len(normalizedAddModelSamples) > 0 {
  411. displayCount := min(len(normalizedAddModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails)
  412. builder.WriteString(fmt.Sprintf("\n\n新增模型示例(展示 %d/%d):%s",
  413. displayCount,
  414. len(normalizedAddModelSamples),
  415. strings.Join(normalizedAddModelSamples[:displayCount], ", "),
  416. ))
  417. if len(normalizedAddModelSamples) > displayCount {
  418. builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", len(normalizedAddModelSamples)-displayCount))
  419. }
  420. }
  421. normalizedRemoveModelSamples := normalizeModelNames(removeModelSamples)
  422. if len(normalizedRemoveModelSamples) > 0 {
  423. displayCount := min(len(normalizedRemoveModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails)
  424. builder.WriteString(fmt.Sprintf("\n\n删除模型示例(展示 %d/%d):%s",
  425. displayCount,
  426. len(normalizedRemoveModelSamples),
  427. strings.Join(normalizedRemoveModelSamples[:displayCount], ", "),
  428. ))
  429. if len(normalizedRemoveModelSamples) > displayCount {
  430. builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", len(normalizedRemoveModelSamples)-displayCount))
  431. }
  432. }
  433. if failedChannels > 0 {
  434. displayCount := min(failedChannels, channelUpstreamModelUpdateNotifyMaxFailedChannelIDs)
  435. displayIDs := lo.Map(failedChannelIDs[:displayCount], func(channelID int, _ int) string {
  436. return fmt.Sprintf("%d", channelID)
  437. })
  438. builder.WriteString(fmt.Sprintf(
  439. "\n\n失败渠道 ID(展示 %d/%d):%s",
  440. displayCount,
  441. failedChannels,
  442. strings.Join(displayIDs, ", "),
  443. ))
  444. if failedChannels > displayCount {
  445. builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", failedChannels-displayCount))
  446. }
  447. }
  448. return builder.String()
  449. }
  450. func runChannelUpstreamModelUpdateTaskOnce() {
  451. if !channelUpstreamModelUpdateTaskRunning.CompareAndSwap(false, true) {
  452. return
  453. }
  454. defer channelUpstreamModelUpdateTaskRunning.Store(false)
  455. checkedChannels := 0
  456. failedChannels := 0
  457. failedChannelIDs := make([]int, 0)
  458. changedChannels := 0
  459. detectedAddModels := 0
  460. detectedRemoveModels := 0
  461. autoAddedModels := 0
  462. channelSummaries := make([]upstreamModelUpdateChannelSummary, 0)
  463. addModelSamples := make([]string, 0)
  464. removeModelSamples := make([]string, 0)
  465. refreshNeeded := false
  466. lastID := 0
  467. for {
  468. var channels []*model.Channel
  469. query := model.DB.
  470. Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
  471. Where("status = ?", common.ChannelStatusEnabled).
  472. Order("id asc").
  473. Limit(channelUpstreamModelUpdateTaskBatchSize)
  474. if lastID > 0 {
  475. query = query.Where("id > ?", lastID)
  476. }
  477. err := query.Find(&channels).Error
  478. if err != nil {
  479. common.SysLog(fmt.Sprintf("upstream model update task query failed: %v", err))
  480. break
  481. }
  482. if len(channels) == 0 {
  483. break
  484. }
  485. lastID = channels[len(channels)-1].Id
  486. for _, channel := range channels {
  487. if channel == nil {
  488. continue
  489. }
  490. settings := channel.GetOtherSettings()
  491. if !settings.UpstreamModelUpdateCheckEnabled {
  492. continue
  493. }
  494. checkedChannels++
  495. modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, false, true)
  496. if err != nil {
  497. failedChannels++
  498. failedChannelIDs = append(failedChannelIDs, channel.Id)
  499. common.SysLog(fmt.Sprintf("upstream model update check failed: channel_id=%d channel_name=%s err=%v", channel.Id, channel.Name, err))
  500. continue
  501. }
  502. currentAddModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  503. currentRemoveModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  504. currentAddCount := len(currentAddModels) + autoAdded
  505. currentRemoveCount := len(currentRemoveModels)
  506. detectedAddModels += currentAddCount
  507. detectedRemoveModels += currentRemoveCount
  508. if currentAddCount > 0 || currentRemoveCount > 0 {
  509. changedChannels++
  510. channelSummaries = append(channelSummaries, upstreamModelUpdateChannelSummary{
  511. ChannelName: channel.Name,
  512. AddCount: currentAddCount,
  513. RemoveCount: currentRemoveCount,
  514. })
  515. }
  516. addModelSamples = mergeModelNames(addModelSamples, currentAddModels)
  517. removeModelSamples = mergeModelNames(removeModelSamples, currentRemoveModels)
  518. if modelsChanged {
  519. refreshNeeded = true
  520. }
  521. autoAddedModels += autoAdded
  522. if common.RequestInterval > 0 {
  523. time.Sleep(common.RequestInterval)
  524. }
  525. }
  526. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  527. break
  528. }
  529. }
  530. if refreshNeeded {
  531. refreshChannelRuntimeCache()
  532. }
  533. if checkedChannels > 0 || common.DebugEnabled {
  534. common.SysLog(fmt.Sprintf(
  535. "upstream model update task done: checked_channels=%d changed_channels=%d detected_add_models=%d detected_remove_models=%d failed_channels=%d auto_added_models=%d",
  536. checkedChannels,
  537. changedChannels,
  538. detectedAddModels,
  539. detectedRemoveModels,
  540. failedChannels,
  541. autoAddedModels,
  542. ))
  543. }
  544. if changedChannels > 0 || failedChannels > 0 {
  545. now := common.GetTimestamp()
  546. if !shouldSendUpstreamModelUpdateNotification(now, changedChannels, failedChannels) {
  547. common.SysLog(fmt.Sprintf(
  548. "upstream model update notification skipped in 24h window: changed_channels=%d failed_channels=%d",
  549. changedChannels,
  550. failedChannels,
  551. ))
  552. return
  553. }
  554. service.NotifyUpstreamModelUpdateWatchers(
  555. "上游模型巡检通知",
  556. buildUpstreamModelUpdateTaskNotificationContent(
  557. checkedChannels,
  558. changedChannels,
  559. detectedAddModels,
  560. detectedRemoveModels,
  561. autoAddedModels,
  562. failedChannelIDs,
  563. channelSummaries,
  564. addModelSamples,
  565. removeModelSamples,
  566. ),
  567. )
  568. }
  569. }
  570. func StartChannelUpstreamModelUpdateTask() {
  571. channelUpstreamModelUpdateTaskOnce.Do(func() {
  572. if !common.IsMasterNode {
  573. return
  574. }
  575. if !common.GetEnvOrDefaultBool("CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED", true) {
  576. common.SysLog("upstream model update task disabled by CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED")
  577. return
  578. }
  579. intervalMinutes := common.GetEnvOrDefault(
  580. "CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_INTERVAL_MINUTES",
  581. channelUpstreamModelUpdateTaskDefaultIntervalMinutes,
  582. )
  583. if intervalMinutes < 1 {
  584. intervalMinutes = channelUpstreamModelUpdateTaskDefaultIntervalMinutes
  585. }
  586. interval := time.Duration(intervalMinutes) * time.Minute
  587. go func() {
  588. common.SysLog(fmt.Sprintf("upstream model update task started: interval=%s", interval))
  589. runChannelUpstreamModelUpdateTaskOnce()
  590. ticker := time.NewTicker(interval)
  591. defer ticker.Stop()
  592. for range ticker.C {
  593. runChannelUpstreamModelUpdateTaskOnce()
  594. }
  595. }()
  596. })
  597. }
  598. func ApplyChannelUpstreamModelUpdates(c *gin.Context) {
  599. var req applyChannelUpstreamModelUpdatesRequest
  600. if err := c.ShouldBindJSON(&req); err != nil {
  601. common.ApiError(c, err)
  602. return
  603. }
  604. if req.ID <= 0 {
  605. c.JSON(http.StatusOK, gin.H{
  606. "success": false,
  607. "message": "invalid channel id",
  608. })
  609. return
  610. }
  611. channel, err := model.GetChannelById(req.ID, true)
  612. if err != nil {
  613. common.ApiError(c, err)
  614. return
  615. }
  616. beforeSettings := channel.GetOtherSettings()
  617. ignoredModels := intersectModelNames(req.IgnoreModels, beforeSettings.UpstreamModelUpdateLastDetectedModels)
  618. addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates(
  619. channel,
  620. req.AddModels,
  621. req.IgnoreModels,
  622. req.RemoveModels,
  623. )
  624. if err != nil {
  625. common.ApiError(c, err)
  626. return
  627. }
  628. if modelsChanged {
  629. refreshChannelRuntimeCache()
  630. }
  631. c.JSON(http.StatusOK, gin.H{
  632. "success": true,
  633. "message": "",
  634. "data": gin.H{
  635. "id": channel.Id,
  636. "added_models": addedModels,
  637. "removed_models": removedModels,
  638. "ignored_models": ignoredModels,
  639. "remaining_models": remainingModels,
  640. "remaining_remove_models": remainingRemoveModels,
  641. "models": channel.Models,
  642. "settings": channel.OtherSettings,
  643. },
  644. })
  645. }
  646. func DetectChannelUpstreamModelUpdates(c *gin.Context) {
  647. var req applyChannelUpstreamModelUpdatesRequest
  648. if err := c.ShouldBindJSON(&req); err != nil {
  649. common.ApiError(c, err)
  650. return
  651. }
  652. if req.ID <= 0 {
  653. c.JSON(http.StatusOK, gin.H{
  654. "success": false,
  655. "message": "invalid channel id",
  656. })
  657. return
  658. }
  659. channel, err := model.GetChannelById(req.ID, true)
  660. if err != nil {
  661. common.ApiError(c, err)
  662. return
  663. }
  664. settings := channel.GetOtherSettings()
  665. modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, true, false)
  666. if err != nil {
  667. common.ApiError(c, err)
  668. return
  669. }
  670. if modelsChanged {
  671. refreshChannelRuntimeCache()
  672. }
  673. c.JSON(http.StatusOK, gin.H{
  674. "success": true,
  675. "message": "",
  676. "data": detectChannelUpstreamModelUpdatesResult{
  677. ChannelID: channel.Id,
  678. ChannelName: channel.Name,
  679. AddModels: normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels),
  680. RemoveModels: normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels),
  681. LastCheckTime: settings.UpstreamModelUpdateLastCheckTime,
  682. AutoAddedModels: autoAdded,
  683. },
  684. })
  685. }
  686. func applyChannelUpstreamModelUpdates(
  687. channel *model.Channel,
  688. addModelsInput []string,
  689. ignoreModelsInput []string,
  690. removeModelsInput []string,
  691. ) (
  692. addedModels []string,
  693. removedModels []string,
  694. remainingModels []string,
  695. remainingRemoveModels []string,
  696. modelsChanged bool,
  697. err error,
  698. ) {
  699. settings := channel.GetOtherSettings()
  700. pendingAddModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  701. pendingRemoveModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  702. addModels := intersectModelNames(addModelsInput, pendingAddModels)
  703. ignoreModels := intersectModelNames(ignoreModelsInput, pendingAddModels)
  704. removeModels := intersectModelNames(removeModelsInput, pendingRemoveModels)
  705. removeModels = subtractModelNames(removeModels, addModels)
  706. originModels := normalizeModelNames(channel.GetModels())
  707. nextModels := applySelectedModelChanges(originModels, addModels, removeModels)
  708. modelsChanged = !slices.Equal(originModels, nextModels)
  709. if modelsChanged {
  710. channel.Models = strings.Join(nextModels, ",")
  711. }
  712. settings.UpstreamModelUpdateIgnoredModels = mergeModelNames(settings.UpstreamModelUpdateIgnoredModels, ignoreModels)
  713. if len(addModels) > 0 {
  714. settings.UpstreamModelUpdateIgnoredModels = subtractModelNames(settings.UpstreamModelUpdateIgnoredModels, addModels)
  715. }
  716. remainingModels = subtractModelNames(pendingAddModels, append(addModels, ignoreModels...))
  717. remainingRemoveModels = subtractModelNames(pendingRemoveModels, removeModels)
  718. settings.UpstreamModelUpdateLastDetectedModels = remainingModels
  719. settings.UpstreamModelUpdateLastRemovedModels = remainingRemoveModels
  720. settings.UpstreamModelUpdateLastCheckTime = common.GetTimestamp()
  721. if err := updateChannelUpstreamModelSettings(channel, settings, modelsChanged); err != nil {
  722. return nil, nil, nil, nil, false, err
  723. }
  724. if modelsChanged {
  725. if err := channel.UpdateAbilities(nil); err != nil {
  726. return addModels, removeModels, remainingModels, remainingRemoveModels, true, err
  727. }
  728. }
  729. return addModels, removeModels, remainingModels, remainingRemoveModels, modelsChanged, nil
  730. }
  731. func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string) {
  732. return normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels), normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  733. }
  734. func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) {
  735. var channels []*model.Channel
  736. query := model.DB.
  737. Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
  738. Where("status = ?", common.ChannelStatusEnabled).
  739. Order("id asc").
  740. Limit(batchSize)
  741. if lastID > 0 {
  742. query = query.Where("id > ?", lastID)
  743. }
  744. return channels, query.Find(&channels).Error
  745. }
  746. func ApplyAllChannelUpstreamModelUpdates(c *gin.Context) {
  747. results := make([]applyAllChannelUpstreamModelUpdatesResult, 0)
  748. failed := make([]int, 0)
  749. refreshNeeded := false
  750. addedModelCount := 0
  751. removedModelCount := 0
  752. lastID := 0
  753. for {
  754. channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize)
  755. if err != nil {
  756. common.ApiError(c, err)
  757. return
  758. }
  759. if len(channels) == 0 {
  760. break
  761. }
  762. lastID = channels[len(channels)-1].Id
  763. for _, channel := range channels {
  764. if channel == nil {
  765. continue
  766. }
  767. settings := channel.GetOtherSettings()
  768. if !settings.UpstreamModelUpdateCheckEnabled {
  769. continue
  770. }
  771. pendingAddModels, pendingRemoveModels := collectPendingApplyUpstreamModelChanges(settings)
  772. if len(pendingAddModels) == 0 && len(pendingRemoveModels) == 0 {
  773. continue
  774. }
  775. addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates(
  776. channel,
  777. pendingAddModels,
  778. nil,
  779. pendingRemoveModels,
  780. )
  781. if err != nil {
  782. failed = append(failed, channel.Id)
  783. continue
  784. }
  785. if modelsChanged {
  786. refreshNeeded = true
  787. }
  788. addedModelCount += len(addedModels)
  789. removedModelCount += len(removedModels)
  790. results = append(results, applyAllChannelUpstreamModelUpdatesResult{
  791. ChannelID: channel.Id,
  792. ChannelName: channel.Name,
  793. AddedModels: addedModels,
  794. RemovedModels: removedModels,
  795. RemainingModels: remainingModels,
  796. RemainingRemoveModels: remainingRemoveModels,
  797. })
  798. }
  799. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  800. break
  801. }
  802. }
  803. if refreshNeeded {
  804. refreshChannelRuntimeCache()
  805. }
  806. c.JSON(http.StatusOK, gin.H{
  807. "success": true,
  808. "message": "",
  809. "data": gin.H{
  810. "processed_channels": len(results),
  811. "added_models": addedModelCount,
  812. "removed_models": removedModelCount,
  813. "failed_channel_ids": failed,
  814. "results": results,
  815. },
  816. })
  817. }
  818. func DetectAllChannelUpstreamModelUpdates(c *gin.Context) {
  819. results := make([]detectChannelUpstreamModelUpdatesResult, 0)
  820. failed := make([]int, 0)
  821. detectedAddCount := 0
  822. detectedRemoveCount := 0
  823. refreshNeeded := false
  824. lastID := 0
  825. for {
  826. channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize)
  827. if err != nil {
  828. common.ApiError(c, err)
  829. return
  830. }
  831. if len(channels) == 0 {
  832. break
  833. }
  834. lastID = channels[len(channels)-1].Id
  835. for _, channel := range channels {
  836. if channel == nil {
  837. continue
  838. }
  839. settings := channel.GetOtherSettings()
  840. if !settings.UpstreamModelUpdateCheckEnabled {
  841. continue
  842. }
  843. modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, true, false)
  844. if err != nil {
  845. failed = append(failed, channel.Id)
  846. continue
  847. }
  848. if modelsChanged {
  849. refreshNeeded = true
  850. }
  851. addModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels)
  852. removeModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels)
  853. detectedAddCount += len(addModels)
  854. detectedRemoveCount += len(removeModels)
  855. results = append(results, detectChannelUpstreamModelUpdatesResult{
  856. ChannelID: channel.Id,
  857. ChannelName: channel.Name,
  858. AddModels: addModels,
  859. RemoveModels: removeModels,
  860. LastCheckTime: settings.UpstreamModelUpdateLastCheckTime,
  861. AutoAddedModels: autoAdded,
  862. })
  863. }
  864. if len(channels) < channelUpstreamModelUpdateTaskBatchSize {
  865. break
  866. }
  867. }
  868. if refreshNeeded {
  869. refreshChannelRuntimeCache()
  870. }
  871. c.JSON(http.StatusOK, gin.H{
  872. "success": true,
  873. "message": "",
  874. "data": gin.H{
  875. "processed_channels": len(results),
  876. "failed_channel_ids": failed,
  877. "detected_add_models": detectedAddCount,
  878. "detected_remove_models": detectedRemoveCount,
  879. "channel_detected_results": results,
  880. },
  881. })
  882. }