channel.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060
  1. package model
  2. import (
  3. "database/sql/driver"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "math/rand"
  8. "strings"
  9. "sync"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/types"
  14. "github.com/samber/lo"
  15. "gorm.io/gorm"
  16. "gorm.io/gorm/clause"
  17. )
// Channel is the gorm model for one channel row: its credentials (Key),
// routing settings (models, group, priority, weight, tag), per-channel
// overrides and runtime statistics (response time, balance, used quota).
type Channel struct {
	Id                 int     `json:"id"`
	Type               int     `json:"type" gorm:"default:0"` // also indexes constant.ChannelBaseURLs in GetBaseURL
	Key                string  `json:"key" gorm:"not null"`   // one key, or several (newline-separated or a JSON array); see GetKeys
	OpenAIOrganization *string `json:"openai_organization"`
	TestModel          *string `json:"test_model"`
	Status             int     `json:"status" gorm:"default:1"` // compared against common.ChannelStatus* values
	Name               string  `json:"name" gorm:"index"`
	Weight             *uint   `json:"weight" gorm:"default:0"`
	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
	TestTime           int64   `json:"test_time" gorm:"bigint"`
	ResponseTime       int     `json:"response_time"` // in milliseconds
	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` // empty string falls back to the type's default in GetBaseURL
	Other              string  `json:"other"`
	Balance            float64 `json:"balance"` // in USD
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
	Models             string  `json:"models"` // comma-separated model names; see GetModels
	Group              string  `json:"group" gorm:"type:varchar(64);default:'default'"` // comma-separated group names; see GetGroups
	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
	ModelMapping       *string `json:"model_mapping" gorm:"type:text"`
	//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
	StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
	Priority          *int64  `json:"priority" gorm:"bigint;default:0"`
	AutoBan           *int    `json:"auto_ban" gorm:"default:1"` // 1 means auto-ban is on; see GetAutoBan
	OtherInfo         string  `json:"other_info"`                // JSON object; see GetOtherInfo / SetOtherInfo
	Tag               *string `json:"tag" gorm:"index"`
	Setting           *string `json:"setting" gorm:"type:text"` // extra channel settings (JSON decoded into dto.ChannelSettings)
	ParamOverride     *string `json:"param_override" gorm:"type:text"`
	HeaderOverride    *string `json:"header_override" gorm:"type:text"`
	Remark            *string `json:"remark" gorm:"type:varchar(255)" validate:"max=255"`
	// add after v0.8.5
	ChannelInfo   ChannelInfo `json:"channel_info" gorm:"type:json"`
	OtherSettings string      `json:"settings" gorm:"column:settings"` // other settings holding non-searchable info such as the Azure API version; see dto.ChannelOtherSettings
	// cache info: not persisted (gorm:"-") nor serialized (json:"-"); GetKeys returns it when non-empty
	Keys []string `json:"-" gorm:"-"`
}
// ChannelInfo holds the multi-key bookkeeping of a channel. It is stored as
// JSON in the channel_info column through its Value and Scan methods.
type ChannelInfo struct {
	IsMultiKey             bool                  `json:"is_multi_key"`                        // whether the channel is in multi-key mode
	MultiKeySize           int                   `json:"multi_key_size"`                      // number of keys in multi-key mode
	MultiKeyStatusList     map[int]int           `json:"multi_key_status_list"`               // key index -> status; a missing index is treated as enabled
	MultiKeyDisabledReason map[int]string        `json:"multi_key_disabled_reason,omitempty"` // key index -> reason the key was disabled
	MultiKeyDisabledTime   map[int]int64         `json:"multi_key_disabled_time,omitempty"`   // key index -> timestamp when the key was disabled
	MultiKeyPollingIndex   int                   `json:"multi_key_polling_index"`             // next key index to try in polling mode
	MultiKeyMode           constant.MultiKeyMode `json:"multi_key_mode"`                      // key selection strategy used by GetNextEnabledKey
}
  63. type ChannelSortOptions struct {
  64. SortBy string
  65. SortOrder string
  66. IDSort bool
  67. }
  68. var channelSortColumns = map[string]string{
  69. "id": "id",
  70. "name": "name",
  71. "priority": "priority",
  72. "balance": "balance",
  73. "response_time": "response_time",
  74. "test_time": "test_time",
  75. }
  76. func NewChannelSortOptions(sortBy string, sortOrder string, idSort bool) ChannelSortOptions {
  77. normalizedSortBy := strings.ToLower(strings.TrimSpace(sortBy))
  78. normalizedSortOrder := strings.ToLower(strings.TrimSpace(sortOrder))
  79. if _, ok := channelSortColumns[normalizedSortBy]; !ok {
  80. normalizedSortBy = ""
  81. normalizedSortOrder = ""
  82. } else if normalizedSortOrder != "asc" {
  83. normalizedSortOrder = "desc"
  84. }
  85. return ChannelSortOptions{
  86. SortBy: normalizedSortBy,
  87. SortOrder: normalizedSortOrder,
  88. IDSort: idSort,
  89. }
  90. }
  91. func (options ChannelSortOptions) Apply(query *gorm.DB) *gorm.DB {
  92. if columnName, ok := channelSortColumns[options.SortBy]; ok {
  93. return query.Order(clause.OrderByColumn{
  94. Column: clause.Column{Name: columnName},
  95. Desc: options.SortOrder != "asc",
  96. })
  97. }
  98. if options.IDSort {
  99. return query.Order(clause.OrderByColumn{
  100. Column: clause.Column{Name: "id"},
  101. Desc: true,
  102. })
  103. }
  104. return query.Order(clause.OrderByColumn{
  105. Column: clause.Column{Name: "priority"},
  106. Desc: true,
  107. })
  108. }
  109. func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) ChannelSortOptions {
  110. if len(sortOptions) == 0 {
  111. return NewChannelSortOptions("", "", idSort)
  112. }
  113. options := sortOptions[0]
  114. options.IDSort = options.IDSort || idSort
  115. return options
  116. }
  117. // Value implements driver.Valuer interface
  118. func (c ChannelInfo) Value() (driver.Value, error) {
  119. return common.Marshal(&c)
  120. }
  121. // Scan implements sql.Scanner interface
  122. func (c *ChannelInfo) Scan(value interface{}) error {
  123. bytesValue, _ := value.([]byte)
  124. return common.Unmarshal(bytesValue, c)
  125. }
  126. func (channel *Channel) GetKeys() []string {
  127. if channel.Key == "" {
  128. return []string{}
  129. }
  130. if len(channel.Keys) > 0 {
  131. return channel.Keys
  132. }
  133. trimmed := strings.TrimSpace(channel.Key)
  134. // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
  135. if strings.HasPrefix(trimmed, "[") {
  136. var arr []json.RawMessage
  137. if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
  138. res := make([]string, len(arr))
  139. for i, v := range arr {
  140. res[i] = string(v)
  141. }
  142. return res
  143. }
  144. }
  145. // Otherwise, fall back to splitting by newline
  146. keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
  147. return keys
  148. }
  149. func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
  150. // If not in multi-key mode, return the original key string directly.
  151. if !channel.ChannelInfo.IsMultiKey {
  152. return channel.Key, 0, nil
  153. }
  154. // Obtain all keys (split by \n)
  155. keys := channel.GetKeys()
  156. if len(keys) == 0 {
  157. // No keys available, return error, should disable the channel
  158. return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
  159. }
  160. lock := GetChannelPollingLock(channel.Id)
  161. lock.Lock()
  162. defer lock.Unlock()
  163. statusList := channel.ChannelInfo.MultiKeyStatusList
  164. // helper to get key status, default to enabled when missing
  165. getStatus := func(idx int) int {
  166. if statusList == nil {
  167. return common.ChannelStatusEnabled
  168. }
  169. if status, ok := statusList[idx]; ok {
  170. return status
  171. }
  172. return common.ChannelStatusEnabled
  173. }
  174. // Collect indexes of enabled keys
  175. enabledIdx := make([]int, 0, len(keys))
  176. for i := range keys {
  177. if getStatus(i) == common.ChannelStatusEnabled {
  178. enabledIdx = append(enabledIdx, i)
  179. }
  180. }
  181. // If no specific status list or none enabled, return an explicit error so caller can
  182. // properly handle a channel with no available keys (e.g. mark channel disabled).
  183. // Returning the first key here caused requests to keep using an already-disabled key.
  184. if len(enabledIdx) == 0 {
  185. return "", 0, types.NewError(errors.New("no enabled keys"), types.ErrorCodeChannelNoAvailableKey)
  186. }
  187. switch channel.ChannelInfo.MultiKeyMode {
  188. case constant.MultiKeyModeRandom:
  189. // Randomly pick one enabled key
  190. selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
  191. return keys[selectedIdx], selectedIdx, nil
  192. case constant.MultiKeyModePolling:
  193. // Use channel-specific lock to ensure thread-safe polling
  194. channelInfo, err := CacheGetChannelInfo(channel.Id)
  195. if err != nil {
  196. return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
  197. }
  198. //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
  199. defer func() {
  200. if common.DebugEnabled {
  201. println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
  202. }
  203. if !common.MemoryCacheEnabled {
  204. _ = channel.SaveChannelInfo()
  205. } else {
  206. // CacheUpdateChannel(channel)
  207. }
  208. }()
  209. // Start from the saved polling index and look for the next enabled key
  210. start := channelInfo.MultiKeyPollingIndex
  211. if start < 0 || start >= len(keys) {
  212. start = 0
  213. }
  214. for i := 0; i < len(keys); i++ {
  215. idx := (start + i) % len(keys)
  216. if getStatus(idx) == common.ChannelStatusEnabled {
  217. // update polling index for next call (point to the next position)
  218. channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
  219. return keys[idx], idx, nil
  220. }
  221. }
  222. // Fallback – should not happen, but return first enabled key
  223. return keys[enabledIdx[0]], enabledIdx[0], nil
  224. default:
  225. // Unknown mode, default to first enabled key (or original key string)
  226. return keys[enabledIdx[0]], enabledIdx[0], nil
  227. }
  228. }
  229. func (channel *Channel) SaveChannelInfo() error {
  230. return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
  231. }
  232. func (channel *Channel) GetModels() []string {
  233. if channel.Models == "" {
  234. return []string{}
  235. }
  236. return strings.Split(strings.Trim(channel.Models, ","), ",")
  237. }
  238. func (channel *Channel) GetGroups() []string {
  239. if channel.Group == "" {
  240. return []string{}
  241. }
  242. groups := strings.Split(strings.Trim(channel.Group, ","), ",")
  243. for i, group := range groups {
  244. groups[i] = strings.TrimSpace(group)
  245. }
  246. return groups
  247. }
  248. func (channel *Channel) GetOtherInfo() map[string]interface{} {
  249. otherInfo := make(map[string]interface{})
  250. if channel.OtherInfo != "" {
  251. err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
  252. if err != nil {
  253. common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
  254. }
  255. }
  256. return otherInfo
  257. }
  258. func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
  259. otherInfoBytes, err := json.Marshal(otherInfo)
  260. if err != nil {
  261. common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
  262. return
  263. }
  264. channel.OtherInfo = string(otherInfoBytes)
  265. }
  266. func (channel *Channel) GetTag() string {
  267. if channel.Tag == nil {
  268. return ""
  269. }
  270. return *channel.Tag
  271. }
  272. func (channel *Channel) SetTag(tag string) {
  273. channel.Tag = &tag
  274. }
  275. func (channel *Channel) GetAutoBan() bool {
  276. if channel.AutoBan == nil {
  277. return false
  278. }
  279. return *channel.AutoBan == 1
  280. }
  281. func (channel *Channel) Save() error {
  282. return DB.Save(channel).Error
  283. }
  284. func (channel *Channel) SaveWithoutKey() error {
  285. if channel.Id == 0 {
  286. return errors.New("channel ID is 0")
  287. }
  288. return DB.Omit("key").Save(channel).Error
  289. }
  290. func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
  291. var channels []*Channel
  292. var err error
  293. order := resolveChannelSortOptions(idSort, sortOptions)
  294. if selectAll {
  295. err = order.Apply(DB).Find(&channels).Error
  296. } else {
  297. err = order.Apply(DB).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
  298. }
  299. return channels, err
  300. }
  301. func GetChannelsByTag(tag string, idSort bool, selectAll bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
  302. var channels []*Channel
  303. order := resolveChannelSortOptions(idSort, sortOptions)
  304. query := order.Apply(DB.Where("tag = ?", tag))
  305. if !selectAll {
  306. query = query.Omit("key")
  307. }
  308. err := query.Find(&channels).Error
  309. return channels, err
  310. }
  311. func SearchChannels(keyword string, group string, model string, idSort bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
  312. var channels []*Channel
  313. modelsCol := "`models`"
  314. // 如果是 PostgreSQL,使用双引号
  315. if common.UsingPostgreSQL {
  316. modelsCol = `"models"`
  317. }
  318. baseURLCol := "`base_url`"
  319. // 如果是 PostgreSQL,使用双引号
  320. if common.UsingPostgreSQL {
  321. baseURLCol = `"base_url"`
  322. }
  323. order := resolveChannelSortOptions(idSort, sortOptions)
  324. // 构造基础查询
  325. baseQuery := DB.Model(&Channel{}).Omit("key")
  326. // 构造WHERE子句
  327. var whereClause string
  328. var args []interface{}
  329. if group != "" && group != "null" {
  330. var groupCondition string
  331. if common.UsingMySQL {
  332. groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
  333. } else {
  334. // sqlite, PostgreSQL
  335. groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
  336. }
  337. whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
  338. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
  339. } else {
  340. whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
  341. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
  342. }
  343. // 执行查询
  344. err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error
  345. if err != nil {
  346. return nil, err
  347. }
  348. return channels, nil
  349. }
  350. func GetChannelById(id int, selectAll bool) (*Channel, error) {
  351. channel := &Channel{Id: id}
  352. var err error = nil
  353. if selectAll {
  354. err = DB.First(channel, "id = ?", id).Error
  355. } else {
  356. err = DB.Omit("key").First(channel, "id = ?", id).Error
  357. }
  358. if err != nil {
  359. return nil, err
  360. }
  361. if channel == nil {
  362. return nil, errors.New("channel not found")
  363. }
  364. return channel, nil
  365. }
  366. func BatchInsertChannels(channels []Channel) error {
  367. if len(channels) == 0 {
  368. return nil
  369. }
  370. tx := DB.Begin()
  371. if tx.Error != nil {
  372. return tx.Error
  373. }
  374. defer func() {
  375. if r := recover(); r != nil {
  376. tx.Rollback()
  377. }
  378. }()
  379. for _, chunk := range lo.Chunk(channels, 50) {
  380. if err := tx.Create(&chunk).Error; err != nil {
  381. tx.Rollback()
  382. return err
  383. }
  384. for _, channel_ := range chunk {
  385. if err := channel_.AddAbilities(tx); err != nil {
  386. tx.Rollback()
  387. return err
  388. }
  389. }
  390. }
  391. return tx.Commit().Error
  392. }
  393. func BatchDeleteChannels(ids []int) error {
  394. if len(ids) == 0 {
  395. return nil
  396. }
  397. // 使用事务 分批删除channel表和abilities表
  398. tx := DB.Begin()
  399. if tx.Error != nil {
  400. return tx.Error
  401. }
  402. for _, chunk := range lo.Chunk(ids, 200) {
  403. if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
  404. tx.Rollback()
  405. return err
  406. }
  407. if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
  408. tx.Rollback()
  409. return err
  410. }
  411. }
  412. return tx.Commit().Error
  413. }
  414. func (channel *Channel) GetPriority() int64 {
  415. if channel.Priority == nil {
  416. return 0
  417. }
  418. return *channel.Priority
  419. }
  420. func (channel *Channel) GetWeight() int {
  421. if channel.Weight == nil {
  422. return 0
  423. }
  424. return int(*channel.Weight)
  425. }
  426. func (channel *Channel) GetBaseURL() string {
  427. if channel.BaseURL == nil {
  428. return ""
  429. }
  430. url := *channel.BaseURL
  431. if url == "" {
  432. url = constant.ChannelBaseURLs[channel.Type]
  433. }
  434. return url
  435. }
  436. func (channel *Channel) GetModelMapping() string {
  437. if channel.ModelMapping == nil {
  438. return ""
  439. }
  440. return *channel.ModelMapping
  441. }
  442. func (channel *Channel) GetStatusCodeMapping() string {
  443. if channel.StatusCodeMapping == nil {
  444. return ""
  445. }
  446. return *channel.StatusCodeMapping
  447. }
  448. func (channel *Channel) Insert() error {
  449. var err error
  450. err = DB.Create(channel).Error
  451. if err != nil {
  452. return err
  453. }
  454. err = channel.AddAbilities(nil)
  455. return err
  456. }
  457. func (channel *Channel) Update() error {
  458. // If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
  459. if channel.ChannelInfo.IsMultiKey {
  460. var keyStr string
  461. if channel.Key != "" {
  462. keyStr = channel.Key
  463. } else {
  464. // If key is not provided, read the existing key from the database
  465. if existing, err := GetChannelById(channel.Id, true); err == nil {
  466. keyStr = existing.Key
  467. }
  468. }
  469. // Parse the key list (supports newline separation or JSON array)
  470. keys := []string{}
  471. if keyStr != "" {
  472. trimmed := strings.TrimSpace(keyStr)
  473. if strings.HasPrefix(trimmed, "[") {
  474. var arr []json.RawMessage
  475. if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
  476. keys = make([]string, len(arr))
  477. for i, v := range arr {
  478. keys[i] = string(v)
  479. }
  480. }
  481. }
  482. if len(keys) == 0 { // fallback to newline split
  483. keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
  484. }
  485. }
  486. channel.ChannelInfo.MultiKeySize = len(keys)
  487. // Clean up status data that exceeds the new key count to prevent index out of range
  488. if channel.ChannelInfo.MultiKeyStatusList != nil {
  489. for idx := range channel.ChannelInfo.MultiKeyStatusList {
  490. if idx >= channel.ChannelInfo.MultiKeySize {
  491. delete(channel.ChannelInfo.MultiKeyStatusList, idx)
  492. }
  493. }
  494. }
  495. }
  496. var err error
  497. err = DB.Model(channel).Updates(channel).Error
  498. if err != nil {
  499. return err
  500. }
  501. DB.Model(channel).First(channel, "id = ?", channel.Id)
  502. err = channel.UpdateAbilities(nil)
  503. return err
  504. }
  505. func (channel *Channel) UpdateResponseTime(responseTime int64) {
  506. err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
  507. TestTime: common.GetTimestamp(),
  508. ResponseTime: int(responseTime),
  509. }).Error
  510. if err != nil {
  511. common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err))
  512. }
  513. }
  514. func (channel *Channel) UpdateBalance(balance float64) {
  515. err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
  516. BalanceUpdatedTime: common.GetTimestamp(),
  517. Balance: balance,
  518. }).Error
  519. if err != nil {
  520. common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err))
  521. }
  522. }
  523. func (channel *Channel) Delete() error {
  524. var err error
  525. err = DB.Delete(channel).Error
  526. if err != nil {
  527. return err
  528. }
  529. err = channel.DeleteAbilities()
  530. return err
  531. }
  532. var channelStatusLock sync.Mutex
  533. // channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
  534. var channelPollingLocks sync.Map
  535. // GetChannelPollingLock returns or creates a mutex for the given channel ID
  536. func GetChannelPollingLock(channelId int) *sync.Mutex {
  537. if lock, exists := channelPollingLocks.Load(channelId); exists {
  538. return lock.(*sync.Mutex)
  539. }
  540. // Create new lock for this channel
  541. newLock := &sync.Mutex{}
  542. actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
  543. return actual.(*sync.Mutex)
  544. }
  545. // CleanupChannelPollingLocks removes locks for channels that no longer exist
  546. // This is optional and can be called periodically to prevent memory leaks
  547. func CleanupChannelPollingLocks() {
  548. var activeChannelIds []int
  549. DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
  550. activeChannelSet := make(map[int]bool)
  551. for _, id := range activeChannelIds {
  552. activeChannelSet[id] = true
  553. }
  554. channelPollingLocks.Range(func(key, value interface{}) bool {
  555. channelId := key.(int)
  556. if !activeChannelSet[channelId] {
  557. channelPollingLocks.Delete(channelId)
  558. }
  559. return true
  560. })
  561. }
  562. func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
  563. keys := channel.GetKeys()
  564. if len(keys) == 0 {
  565. channel.Status = status
  566. } else {
  567. var keyIndex int
  568. for i, key := range keys {
  569. if key == usingKey {
  570. keyIndex = i
  571. break
  572. }
  573. }
  574. if channel.ChannelInfo.MultiKeyStatusList == nil {
  575. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  576. }
  577. if status == common.ChannelStatusEnabled {
  578. delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
  579. } else {
  580. channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
  581. if channel.ChannelInfo.MultiKeyDisabledReason == nil {
  582. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  583. }
  584. if channel.ChannelInfo.MultiKeyDisabledTime == nil {
  585. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  586. }
  587. channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
  588. channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
  589. }
  590. if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
  591. channel.Status = common.ChannelStatusAutoDisabled
  592. info := channel.GetOtherInfo()
  593. info["status_reason"] = "All keys are disabled"
  594. info["status_time"] = common.GetTimestamp()
  595. channel.SetOtherInfo(info)
  596. }
  597. }
  598. }
// UpdateChannelStatus changes the status of a channel, or of one key of a
// multi-key channel, and reports whether the stored channel was updated.
//
// usingKey selects the key for multi-key channels (see handlerMultiKeyUpdate),
// and reason is recorded as the status reason. When the memory cache is
// enabled, channelStatusLock is held for the whole call and the cached channel
// is updated first. The channel row is then reloaded from the DB, modified and
// saved without its key column. The deferred hook updates the ability status
// whenever shouldUpdateAbilities was set.
//
// It returns false when:
//   - the cache is enabled and the channel is not cached;
//   - a single-key cached channel already has the target status;
//   - the DB lookup fails;
//   - the stored status already equals status;
//   - saving fails.
//
// NOTE(review): the `channel.Status == status` early return means enabling one
// key of a multi-key channel whose channel status is already enabled never
// reaches handlerMultiKeyUpdate for the DB copy. The re-enabled key is
// therefore only cleared in the cache, not persisted. Also, shouldUpdateAbilities is
// set before SaveWithoutKey, so abilities are updated even if the save fails.
// Confirm both are intended.
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
	if common.MemoryCacheEnabled {
		channelStatusLock.Lock()
		defer channelStatusLock.Unlock()
		channelCache, _ := CacheGetChannel(channelId)
		if channelCache == nil {
			return false
		}
		if channelCache.ChannelInfo.IsMultiKey {
			// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
			pollingLock := GetChannelPollingLock(channelId)
			pollingLock.Lock()
			// In multi-key mode, update the key status on the cached channel.
			handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
			pollingLock.Unlock()
			//CacheUpdateChannel(channelCache)
			//return true
		} else {
			// If the cached channel already has the target status, return immediately.
			if channelCache.Status == status {
				return false
			}
			CacheUpdateChannelStatus(channelId, status)
		}
	}
	shouldUpdateAbilities := false
	// Runs on every return below; only acts once a channel-level status change was decided.
	defer func() {
		if shouldUpdateAbilities {
			err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
			if err != nil {
				common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err))
			}
		}
	}()
	channel, err := GetChannelById(channelId, true)
	if err != nil {
		return false
	} else {
		if channel.Status == status {
			return false
		}
		if channel.ChannelInfo.IsMultiKey {
			beforeStatus := channel.Status
			// Protect map writes with the same per-channel lock used by readers
			pollingLock := GetChannelPollingLock(channelId)
			pollingLock.Lock()
			handlerMultiKeyUpdate(channel, usingKey, status, reason)
			pollingLock.Unlock()
			// Abilities only change when the channel-level status changed (e.g. all keys disabled).
			if beforeStatus != channel.Status {
				shouldUpdateAbilities = true
			}
		} else {
			info := channel.GetOtherInfo()
			info["status_reason"] = reason
			info["status_time"] = common.GetTimestamp()
			channel.SetOtherInfo(info)
			channel.Status = status
			shouldUpdateAbilities = true
		}
		err = channel.SaveWithoutKey()
		if err != nil {
			common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
			return false
		}
	}
	return true
}
  666. func EnableChannelByTag(tag string) error {
  667. err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
  668. if err != nil {
  669. return err
  670. }
  671. err = UpdateAbilityStatusByTag(tag, true)
  672. return err
  673. }
  674. func DisableChannelByTag(tag string) error {
  675. err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
  676. if err != nil {
  677. return err
  678. }
  679. err = UpdateAbilityStatusByTag(tag, false)
  680. return err
  681. }
  682. func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint, paramOverride *string, headerOverride *string) error {
  683. updateData := Channel{}
  684. shouldReCreateAbilities := false
  685. updatedTag := tag
  686. // 如果 newTag 不为空且不等于 tag,则更新 tag
  687. if newTag != nil && *newTag != tag {
  688. updateData.Tag = newTag
  689. updatedTag = *newTag
  690. }
  691. if modelMapping != nil && *modelMapping != "" {
  692. updateData.ModelMapping = modelMapping
  693. }
  694. if models != nil && *models != "" {
  695. shouldReCreateAbilities = true
  696. updateData.Models = *models
  697. }
  698. if group != nil && *group != "" {
  699. shouldReCreateAbilities = true
  700. updateData.Group = *group
  701. }
  702. if priority != nil {
  703. updateData.Priority = priority
  704. }
  705. if weight != nil {
  706. updateData.Weight = weight
  707. }
  708. if paramOverride != nil {
  709. updateData.ParamOverride = paramOverride
  710. }
  711. if headerOverride != nil {
  712. updateData.HeaderOverride = headerOverride
  713. }
  714. err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
  715. if err != nil {
  716. return err
  717. }
  718. if shouldReCreateAbilities {
  719. channels, err := GetChannelsByTag(updatedTag, false, false)
  720. if err == nil {
  721. for _, channel := range channels {
  722. err = channel.UpdateAbilities(nil)
  723. if err != nil {
  724. common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err))
  725. }
  726. }
  727. }
  728. } else {
  729. err := UpdateAbilityByTag(tag, newTag, priority, weight)
  730. if err != nil {
  731. return err
  732. }
  733. }
  734. return nil
  735. }
  736. func UpdateChannelUsedQuota(id int, quota int) {
  737. if common.BatchUpdateEnabled {
  738. addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
  739. return
  740. }
  741. updateChannelUsedQuota(id, quota)
  742. }
  743. func updateChannelUsedQuota(id int, quota int) {
  744. err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
  745. if err != nil {
  746. common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err))
  747. }
  748. }
  749. func DeleteChannelByStatus(status int64) (int64, error) {
  750. result := DB.Where("status = ?", status).Delete(&Channel{})
  751. return result.RowsAffected, result.Error
  752. }
  753. func DeleteDisabledChannel() (int64, error) {
  754. result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
  755. return result.RowsAffected, result.Error
  756. }
  757. func GetPaginatedTags(offset int, limit int) ([]*string, error) {
  758. var tags []*string
  759. err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
  760. return tags, err
  761. }
  762. func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) {
  763. var tags []*string
  764. modelsCol := "`models`"
  765. // 如果是 PostgreSQL,使用双引号
  766. if common.UsingPostgreSQL {
  767. modelsCol = `"models"`
  768. }
  769. baseURLCol := "`base_url`"
  770. // 如果是 PostgreSQL,使用双引号
  771. if common.UsingPostgreSQL {
  772. baseURLCol = `"base_url"`
  773. }
  774. order := "priority desc"
  775. if idSort {
  776. order = "id desc"
  777. }
  778. // 构造基础查询
  779. baseQuery := DB.Model(&Channel{}).Omit("key")
  780. // 构造WHERE子句
  781. var whereClause string
  782. var args []interface{}
  783. if group != "" && group != "null" {
  784. var groupCondition string
  785. if common.UsingMySQL {
  786. groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
  787. } else {
  788. // sqlite, PostgreSQL
  789. groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
  790. }
  791. whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
  792. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
  793. } else {
  794. whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
  795. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
  796. }
  797. subQuery := baseQuery.Where(whereClause, args...).
  798. Select("tag").
  799. Where("tag != ''").
  800. Order(order)
  801. err := DB.Table("(?) as sub", subQuery).
  802. Select("DISTINCT tag").
  803. Find(&tags).Error
  804. if err != nil {
  805. return nil, err
  806. }
  807. return tags, nil
  808. }
  809. func (channel *Channel) ValidateSettings() error {
  810. channelParams := &dto.ChannelSettings{}
  811. if channel.Setting != nil && *channel.Setting != "" {
  812. err := common.Unmarshal([]byte(*channel.Setting), channelParams)
  813. if err != nil {
  814. return err
  815. }
  816. }
  817. return nil
  818. }
  819. func (channel *Channel) GetSetting() dto.ChannelSettings {
  820. setting := dto.ChannelSettings{}
  821. if channel.Setting != nil && *channel.Setting != "" {
  822. err := common.Unmarshal([]byte(*channel.Setting), &setting)
  823. if err != nil {
  824. common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
  825. channel.Setting = nil // 清空设置以避免后续错误
  826. _ = channel.Save() // 保存修改
  827. }
  828. }
  829. return setting
  830. }
  831. func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
  832. settingBytes, err := common.Marshal(setting)
  833. if err != nil {
  834. common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
  835. return
  836. }
  837. channel.Setting = common.GetPointer[string](string(settingBytes))
  838. }
  839. func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
  840. setting := dto.ChannelOtherSettings{}
  841. if channel.OtherSettings != "" {
  842. err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
  843. if err != nil {
  844. common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
  845. channel.OtherSettings = "{}" // 清空设置以避免后续错误
  846. _ = channel.Save() // 保存修改
  847. }
  848. }
  849. return setting
  850. }
  851. func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
  852. settingBytes, err := common.Marshal(setting)
  853. if err != nil {
  854. common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
  855. return
  856. }
  857. channel.OtherSettings = string(settingBytes)
  858. }
  859. func (channel *Channel) GetParamOverride() map[string]interface{} {
  860. paramOverride := make(map[string]interface{})
  861. if channel.ParamOverride != nil && *channel.ParamOverride != "" {
  862. err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
  863. if err != nil {
  864. common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err))
  865. }
  866. }
  867. return paramOverride
  868. }
  869. func (channel *Channel) GetHeaderOverride() map[string]interface{} {
  870. headerOverride := make(map[string]interface{})
  871. if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
  872. err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
  873. if err != nil {
  874. common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err))
  875. }
  876. }
  877. return headerOverride
  878. }
  879. func GetChannelsByIds(ids []int) ([]*Channel, error) {
  880. var channels []*Channel
  881. err := DB.Where("id in (?)", ids).Find(&channels).Error
  882. return channels, err
  883. }
  884. func BatchSetChannelTag(ids []int, tag *string) error {
  885. // 开启事务
  886. tx := DB.Begin()
  887. if tx.Error != nil {
  888. return tx.Error
  889. }
  890. // 更新标签
  891. err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
  892. if err != nil {
  893. tx.Rollback()
  894. return err
  895. }
  896. // update ability status
  897. channels, err := GetChannelsByIds(ids)
  898. if err != nil {
  899. tx.Rollback()
  900. return err
  901. }
  902. for _, channel := range channels {
  903. err = channel.UpdateAbilities(tx)
  904. if err != nil {
  905. tx.Rollback()
  906. return err
  907. }
  908. }
  909. // 提交事务
  910. return tx.Commit().Error
  911. }
  912. // CountAllChannels returns total channels in DB
  913. func CountAllChannels() (int64, error) {
  914. var total int64
  915. err := DB.Model(&Channel{}).Count(&total).Error
  916. return total, err
  917. }
  918. // CountAllTags returns number of non-empty distinct tags
  919. func CountAllTags() (int64, error) {
  920. var total int64
  921. err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
  922. return total, err
  923. }
  924. // Get channels of specified type with pagination
  925. func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
  926. var channels []*Channel
  927. order := "priority desc"
  928. if idSort {
  929. order = "id desc"
  930. }
  931. err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
  932. return channels, err
  933. }
  934. // Count channels of specific type
  935. func CountChannelsByType(channelType int) (int64, error) {
  936. var count int64
  937. err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
  938. return count, err
  939. }
  940. // Return map[type]count for all channels
  941. func CountChannelsGroupByType() (map[int64]int64, error) {
  942. type result struct {
  943. Type int64 `gorm:"column:type"`
  944. Count int64 `gorm:"column:count"`
  945. }
  946. var results []result
  947. err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
  948. if err != nil {
  949. return nil, err
  950. }
  951. counts := make(map[int64]int64)
  952. for _, r := range results {
  953. counts[r.Type] = r.Count
  954. }
  955. return counts, nil
  956. }