| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- package model
- import (
- "errors"
- "fmt"
- "one-api/common"
- "strings"
- )
// Ability is one row of the abilities table. It records that the channel
// identified by ChannelId can serve the model named Model for the user group
// named Group. Group, Model and ChannelId together form the composite primary
// key.
//
// Enabled, Priority and Weight are copied from the owning channel when the row
// is created (see AddAbilities); Priority and Weight are what
// GetRandomSatisfiedChannel uses to choose between candidate channels.
//
// NOTE(review): the Priority gorm tag says "bigint" without the "type:" key,
// so gorm most likely ignores that part and uses its default column type for
// int64 — confirm before relying on the declared column type.
type Ability struct {
	Group     string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
	Model     string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
	Enabled   bool   `json:"enabled"`
	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"`
	Weight    uint   `json:"weight" gorm:"default:0;index"`
}
- func GetGroupModels(group string) []string {
- var models []string
- // Find distinct models
- groupCol := "`group`"
- if common.UsingPostgreSQL {
- groupCol = `"group"`
- }
- DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
- return models
- }
- func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
- var abilities []Ability
- groupCol := "`group`"
- trueVal := "1"
- if common.UsingPostgreSQL {
- groupCol = `"group"`
- trueVal = "true"
- }
- var err error = nil
- maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
- channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
- if common.UsingSQLite || common.UsingPostgreSQL {
- err = channelQuery.Order("weight DESC").Find(&abilities).Error
- } else {
- err = channelQuery.Order("weight DESC").Find(&abilities).Error
- }
- if err != nil {
- return nil, err
- }
- channel := Channel{}
- if len(abilities) > 0 {
- // Randomly choose one
- weightSum := uint(0)
- for _, ability_ := range abilities {
- weightSum += ability_.Weight + 10
- }
- // Randomly choose one
- weight := common.GetRandomInt(int(weightSum))
- for _, ability_ := range abilities {
- weight -= int(ability_.Weight)
- //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
- if weight <= 0 {
- channel.Id = ability_.ChannelId
- break
- }
- }
- } else {
- return nil, errors.New("channel not found")
- }
- err = DB.First(&channel, "id = ?", channel.Id).Error
- return &channel, err
- }
- func (channel *Channel) AddAbilities() error {
- models_ := strings.Split(channel.Models, ",")
- groups_ := strings.Split(channel.Group, ",")
- abilities := make([]Ability, 0, len(models_))
- for _, model := range models_ {
- for _, group := range groups_ {
- ability := Ability{
- Group: group,
- Model: model,
- ChannelId: channel.Id,
- Enabled: channel.Status == common.ChannelStatusEnabled,
- Priority: channel.Priority,
- Weight: uint(channel.GetWeight()),
- }
- abilities = append(abilities, ability)
- }
- }
- return DB.Create(&abilities).Error
- }
- func (channel *Channel) DeleteAbilities() error {
- return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
- }
- // UpdateAbilities updates abilities of this channel.
- // Make sure the channel is completed before calling this function.
- func (channel *Channel) UpdateAbilities() error {
- // A quick and dirty way to update abilities
- // First delete all abilities of this channel
- err := channel.DeleteAbilities()
- if err != nil {
- return err
- }
- // Then add new abilities
- err = channel.AddAbilities()
- if err != nil {
- return err
- }
- return nil
- }
- func UpdateAbilityStatus(channelId int, status bool) error {
- return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
- }
- func FixAbility() (int, error) {
- var channelIds []int
- count := 0
- // Find all channel ids from channel table
- err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
- return 0, err
- }
- // Delete abilities of channels that are not in channel table
- err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
- return 0, err
- }
- common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
- count += len(channelIds)
- // Use channelIds to find channel not in abilities table
- var abilityChannelIds []int
- err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
- return 0, err
- }
- var channels []Channel
- if len(abilityChannelIds) == 0 {
- err = DB.Find(&channels).Error
- } else {
- err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
- }
- if err != nil {
- return 0, err
- }
- for _, channel := range channels {
- err := channel.UpdateAbilities()
- if err != nil {
- common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
- } else {
- common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
- count++
- }
- }
- InitChannelCache()
- return count, nil
- }
|