ability.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "one-api/common"
  6. "one-api/logger"
  7. "strings"
  8. "sync"
  9. "github.com/samber/lo"
  10. "gorm.io/gorm"
  11. "gorm.io/gorm/clause"
  12. )
  // Ability is one row of the abilities table. It records that channel
// ChannelId can serve model Model for the user group Group; the triple
// (Group, Model, ChannelId) forms the composite primary key.
type Ability struct {
	Group     string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
	Model     string `json:"model" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
	// Enabled is set from the owning channel's status when rows are built,
	// and can be toggled in bulk by UpdateAbilityStatus / UpdateAbilityStatusByTag.
	Enabled bool `json:"enabled"`
	// Priority is copied from the channel. The selection queries in this file
	// try the highest priority first (MAX(priority) / ORDER BY priority DESC).
	// NOTE(review): the tag uses a bare "bigint" rather than "type:bigint";
	// confirm whether GORM actually applies it as the column type.
	Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
	// Weight is copied from the channel and biases the random choice among
	// abilities that share the same priority.
	Weight uint `json:"weight" gorm:"default:0;index"`
	// Tag is copied from the channel and is the key for bulk updates by tag.
	Tag *string `json:"tag" gorm:"index"`
}
  // AbilityWithChannel is an Ability row extended with the type of the channel
// it belongs to, as returned by GetAllEnableAbilityWithChannels.
type AbilityWithChannel struct {
	Ability
	// ChannelType is filled from channels.type, aliased as channel_type in the query.
	ChannelType int `json:"channel_type"`
}
  26. func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
  27. var abilities []AbilityWithChannel
  28. err := DB.Table("abilities").
  29. Select("abilities.*, channels.type as channel_type").
  30. Joins("left join channels on abilities.channel_id = channels.id").
  31. Where("abilities.enabled = ?", true).
  32. Scan(&abilities).Error
  33. return abilities, err
  34. }
  35. func GetGroupEnabledModels(group string) []string {
  36. var models []string
  37. // Find distinct models
  38. DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
  39. return models
  40. }
  41. func GetEnabledModels() []string {
  42. var models []string
  43. // Find distinct models
  44. DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
  45. return models
  46. }
  47. func GetAllEnableAbilities() []Ability {
  48. var abilities []Ability
  49. DB.Find(&abilities, "enabled = ?", true)
  50. return abilities
  51. }
  52. func getPriority(group string, model string, retry int) (int, error) {
  53. var priorities []int
  54. err := DB.Model(&Ability{}).
  55. Select("DISTINCT(priority)").
  56. Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
  57. Order("priority DESC"). // 按优先级降序排序
  58. Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
  59. if err != nil {
  60. // 处理错误
  61. return 0, err
  62. }
  63. if len(priorities) == 0 {
  64. // 如果没有查询到优先级,则返回错误
  65. return 0, errors.New("数据库一致性被破坏")
  66. }
  67. // 确定要使用的优先级
  68. var priorityToUse int
  69. if retry >= len(priorities) {
  70. // 如果重试次数大于优先级数,则使用最小的优先级
  71. priorityToUse = priorities[len(priorities)-1]
  72. } else {
  73. priorityToUse = priorities[retry]
  74. }
  75. return priorityToUse, nil
  76. }
  77. func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
  78. maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
  79. channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
  80. if retry != 0 {
  81. priority, err := getPriority(group, model, retry)
  82. if err != nil {
  83. return nil, err
  84. } else {
  85. channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
  86. }
  87. }
  88. return channelQuery, nil
  89. }
  90. func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
  91. var abilities []Ability
  92. var err error = nil
  93. channelQuery, err := getChannelQuery(group, model, retry)
  94. if err != nil {
  95. return nil, err
  96. }
  97. if common.UsingSQLite || common.UsingPostgreSQL {
  98. err = channelQuery.Order("weight DESC").Find(&abilities).Error
  99. } else {
  100. err = channelQuery.Order("weight DESC").Find(&abilities).Error
  101. }
  102. if err != nil {
  103. return nil, err
  104. }
  105. channel := Channel{}
  106. if len(abilities) > 0 {
  107. // Randomly choose one
  108. weightSum := uint(0)
  109. for _, ability_ := range abilities {
  110. weightSum += ability_.Weight + 10
  111. }
  112. // Randomly choose one
  113. weight := common.GetRandomInt(int(weightSum))
  114. for _, ability_ := range abilities {
  115. weight -= int(ability_.Weight) + 10
  116. //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
  117. if weight <= 0 {
  118. channel.Id = ability_.ChannelId
  119. break
  120. }
  121. }
  122. } else {
  123. return nil, nil
  124. }
  125. err = DB.First(&channel, "id = ?", channel.Id).Error
  126. return &channel, err
  127. }
  128. func (channel *Channel) AddAbilities(tx *gorm.DB) error {
  129. models_ := strings.Split(channel.Models, ",")
  130. groups_ := strings.Split(channel.Group, ",")
  131. abilitySet := make(map[string]struct{})
  132. abilities := make([]Ability, 0, len(models_))
  133. for _, model := range models_ {
  134. for _, group := range groups_ {
  135. key := group + "|" + model
  136. if _, exists := abilitySet[key]; exists {
  137. continue
  138. }
  139. abilitySet[key] = struct{}{}
  140. ability := Ability{
  141. Group: group,
  142. Model: model,
  143. ChannelId: channel.Id,
  144. Enabled: channel.Status == common.ChannelStatusEnabled,
  145. Priority: channel.Priority,
  146. Weight: uint(channel.GetWeight()),
  147. Tag: channel.Tag,
  148. }
  149. abilities = append(abilities, ability)
  150. }
  151. }
  152. if len(abilities) == 0 {
  153. return nil
  154. }
  155. // choose DB or provided tx
  156. useDB := DB
  157. if tx != nil {
  158. useDB = tx
  159. }
  160. for _, chunk := range lo.Chunk(abilities, 50) {
  161. err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
  162. if err != nil {
  163. return err
  164. }
  165. }
  166. return nil
  167. }
  168. func (channel *Channel) DeleteAbilities() error {
  169. return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
  170. }
  171. // UpdateAbilities updates abilities of this channel.
  172. // Make sure the channel is completed before calling this function.
  173. func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
  174. isNewTx := false
  175. // 如果没有传入事务,创建新的事务
  176. if tx == nil {
  177. tx = DB.Begin()
  178. if tx.Error != nil {
  179. return tx.Error
  180. }
  181. isNewTx = true
  182. defer func() {
  183. if r := recover(); r != nil {
  184. tx.Rollback()
  185. }
  186. }()
  187. }
  188. // First delete all abilities of this channel
  189. err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
  190. if err != nil {
  191. if isNewTx {
  192. tx.Rollback()
  193. }
  194. return err
  195. }
  196. // Then add new abilities
  197. models_ := strings.Split(channel.Models, ",")
  198. groups_ := strings.Split(channel.Group, ",")
  199. abilitySet := make(map[string]struct{})
  200. abilities := make([]Ability, 0, len(models_))
  201. for _, model := range models_ {
  202. for _, group := range groups_ {
  203. key := group + "|" + model
  204. if _, exists := abilitySet[key]; exists {
  205. continue
  206. }
  207. abilitySet[key] = struct{}{}
  208. ability := Ability{
  209. Group: group,
  210. Model: model,
  211. ChannelId: channel.Id,
  212. Enabled: channel.Status == common.ChannelStatusEnabled,
  213. Priority: channel.Priority,
  214. Weight: uint(channel.GetWeight()),
  215. Tag: channel.Tag,
  216. }
  217. abilities = append(abilities, ability)
  218. }
  219. }
  220. if len(abilities) > 0 {
  221. for _, chunk := range lo.Chunk(abilities, 50) {
  222. err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
  223. if err != nil {
  224. if isNewTx {
  225. tx.Rollback()
  226. }
  227. return err
  228. }
  229. }
  230. }
  231. // 如果是新创建的事务,需要提交
  232. if isNewTx {
  233. return tx.Commit().Error
  234. }
  235. return nil
  236. }
  237. func UpdateAbilityStatus(channelId int, status bool) error {
  238. return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
  239. }
  240. func UpdateAbilityStatusByTag(tag string, status bool) error {
  241. return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
  242. }
  243. func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
  244. ability := Ability{}
  245. if newTag != nil {
  246. ability.Tag = newTag
  247. }
  248. if priority != nil {
  249. ability.Priority = priority
  250. }
  251. if weight != nil {
  252. ability.Weight = *weight
  253. }
  254. return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
  255. }
  256. var fixLock = sync.Mutex{}
  257. func FixAbility() (int, int, error) {
  258. lock := fixLock.TryLock()
  259. if !lock {
  260. return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
  261. }
  262. defer fixLock.Unlock()
  263. // truncate abilities table
  264. if common.UsingSQLite {
  265. err := DB.Exec("DELETE FROM abilities").Error
  266. if err != nil {
  267. logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
  268. return 0, 0, err
  269. }
  270. } else {
  271. err := DB.Exec("TRUNCATE TABLE abilities").Error
  272. if err != nil {
  273. logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
  274. return 0, 0, err
  275. }
  276. }
  277. var channels []*Channel
  278. // Find all channels
  279. err := DB.Model(&Channel{}).Find(&channels).Error
  280. if err != nil {
  281. return 0, 0, err
  282. }
  283. if len(channels) == 0 {
  284. return 0, 0, nil
  285. }
  286. successCount := 0
  287. failCount := 0
  288. for _, chunk := range lo.Chunk(channels, 50) {
  289. ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
  290. // Delete all abilities of this channel
  291. err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
  292. if err != nil {
  293. logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
  294. failCount += len(chunk)
  295. continue
  296. }
  297. // Then add new abilities
  298. for _, channel := range chunk {
  299. err = channel.AddAbilities(nil)
  300. if err != nil {
  301. logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
  302. failCount++
  303. } else {
  304. successCount++
  305. }
  306. }
  307. }
  308. InitChannelCache()
  309. return successCount, failCount, nil
  310. }