model_meta.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package model
import (
	"errors"
	"one-api/common"
	"strconv"
	"strings"

	"gorm.io/gorm"
)
  8. // Model 用于存储模型的元数据,例如描述、标签等
  9. // ModelName 字段具有唯一性约束,确保每个模型只会出现一次
  10. // Tags 字段使用逗号分隔的字符串保存标签集合,后期可根据需要扩展为 JSON 类型
  11. // Status: 1 表示启用,0 表示禁用,保留以便后续功能扩展
  12. // CreatedTime 和 UpdatedTime 使用 Unix 时间戳(秒)保存方便跨数据库移植
  13. // DeletedAt 采用 GORM 的软删除特性,便于后续数据恢复
  14. //
  15. // 该表设计遵循第三范式(3NF):
  16. // 1. 每一列都与主键(Id 或 ModelName)直接相关
  17. // 2. 不存在部分依赖(ModelName 是唯一键)
  18. // 3. 不存在传递依赖(描述、标签等都依赖于 ModelName,而非依赖于其他非主键列)
  19. // 这样既保证了数据一致性,也方便后期扩展
  20. // 模型名称匹配规则
  21. const (
  22. NameRuleExact = iota // 0 精确匹配
  23. NameRulePrefix // 1 前缀匹配
  24. NameRuleContains // 2 包含匹配
  25. NameRuleSuffix // 3 后缀匹配
  26. )
// BoundChannel is the (name, type) pair of a channel, as selected from the
// channels table by GetBoundChannels and GetBoundChannelsForModels.
type BoundChannel struct {
	Name string `json:"name"` // channels.name
	Type int    `json:"type"` // channels.type
}
// Model stores the metadata of a model (description, tags, vendor, ...).
//
// Fields tagged gorm:"-" are not persisted; this file never assigns them, so
// they are presumably filled in by callers before serialization.
type Model struct {
	Id int `json:"id"`
	// ModelName and DeletedAt together form the composite unique index
	// uk_model_name (priority 1 and 2 respectively).
	ModelName   string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
	Description string `json:"description,omitempty" gorm:"type:text"`
	Icon        string `json:"icon,omitempty" gorm:"type:varchar(128)"`
	// Tags holds the tag set as a single comma-separated string.
	Tags     string `json:"tags,omitempty" gorm:"type:varchar(255)"`
	VendorID int    `json:"vendor_id,omitempty" gorm:"index"`
	// Endpoints is stored as free text; its format is not visible in this file.
	Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
	// Status: 1 means enabled, 0 means disabled.
	Status int `json:"status" gorm:"default:1"`
	// CreatedTime and UpdatedTime are Unix timestamps in seconds.
	// NOTE(review): the tag reads `bigint`, not `type:bigint` — presumably it
	// has no effect and the column type comes from the int64 Go type; confirm.
	CreatedTime int64 `json:"created_time" gorm:"bigint"`
	UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
	// DeletedAt enables GORM soft deletion for this table.
	DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`

	// Not persisted (gorm:"-").
	BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
	EnableGroups  []string       `json:"enable_groups,omitempty" gorm:"-"`
	QuotaType     int            `json:"quota_type" gorm:"-"`

	// NameRule is one of the NameRule* constants and defaults to 0
	// (NameRuleExact).
	NameRule int `json:"name_rule" gorm:"default:0"`

	// Not persisted (gorm:"-").
	MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
	MatchedCount  int      `json:"matched_count,omitempty" gorm:"-"`
}
  50. // Insert 创建新的模型元数据记录
  51. func (mi *Model) Insert() error {
  52. now := common.GetTimestamp()
  53. mi.CreatedTime = now
  54. mi.UpdatedTime = now
  55. return DB.Create(mi).Error
  56. }
  57. // IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID)
  58. func IsModelNameDuplicated(id int, name string) (bool, error) {
  59. if name == "" {
  60. return false, nil
  61. }
  62. var cnt int64
  63. err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
  64. return cnt > 0, err
  65. }
  66. // Update 更新现有模型记录
  67. func (mi *Model) Update() error {
  68. mi.UpdatedTime = common.GetTimestamp()
  69. // 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新
  70. return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
  71. Model(&Model{}).
  72. Where("id = ?", mi.Id).
  73. Omit("created_time").
  74. Select("*").
  75. Updates(mi).Error
  76. }
  77. // Delete 软删除模型记录
  78. func (mi *Model) Delete() error {
  79. return DB.Delete(mi).Error
  80. }
  81. // GetModelByName 根据模型名称查询元数据
  82. func GetModelByName(name string) (*Model, error) {
  83. var mi Model
  84. err := DB.Where("model_name = ?", name).First(&mi).Error
  85. if err != nil {
  86. return nil, err
  87. }
  88. return &mi, nil
  89. }
  90. // GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响)
  91. func GetVendorModelCounts() (map[int64]int64, error) {
  92. var stats []struct {
  93. VendorID int64
  94. Count int64
  95. }
  96. if err := DB.Model(&Model{}).
  97. Select("vendor_id as vendor_id, count(*) as count").
  98. Group("vendor_id").
  99. Scan(&stats).Error; err != nil {
  100. return nil, err
  101. }
  102. m := make(map[int64]int64, len(stats))
  103. for _, s := range stats {
  104. m[s.VendorID] = s.Count
  105. }
  106. return m, nil
  107. }
  108. // GetAllModels 分页获取所有模型元数据
  109. func GetAllModels(offset int, limit int) ([]*Model, error) {
  110. var models []*Model
  111. err := DB.Offset(offset).Limit(limit).Find(&models).Error
  112. return models, err
  113. }
  114. // GetBoundChannels 查询支持该模型的渠道(名称+类型)
  115. func GetBoundChannels(modelName string) ([]BoundChannel, error) {
  116. var channels []BoundChannel
  117. err := DB.Table("channels").
  118. Select("channels.name, channels.type").
  119. Joins("join abilities on abilities.channel_id = channels.id").
  120. Where("abilities.model = ? AND abilities.enabled = ?", modelName, true).
  121. Group("channels.id").
  122. Scan(&channels).Error
  123. return channels, err
  124. }
  125. // GetBoundChannelsForModels 批量查询多模型的绑定渠道并去重返回
  126. func GetBoundChannelsForModels(modelNames []string) ([]BoundChannel, error) {
  127. if len(modelNames) == 0 {
  128. return make([]BoundChannel, 0), nil
  129. }
  130. var channels []BoundChannel
  131. err := DB.Table("channels").
  132. Select("channels.name, channels.type").
  133. Joins("join abilities on abilities.channel_id = channels.id").
  134. Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
  135. Group("channels.id").
  136. Scan(&channels).Error
  137. return channels, err
  138. }
  139. // FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
  140. func FindModelByNameWithRule(name string) (*Model, error) {
  141. // 1. 精确匹配
  142. if m, err := GetModelByName(name); err == nil {
  143. return m, nil
  144. }
  145. // 2. 规则匹配
  146. var models []*Model
  147. if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil {
  148. return nil, err
  149. }
  150. var prefixMatch, suffixMatch, containsMatch *Model
  151. for _, m := range models {
  152. switch m.NameRule {
  153. case NameRulePrefix:
  154. if strings.HasPrefix(name, m.ModelName) {
  155. if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) {
  156. prefixMatch = m
  157. }
  158. }
  159. case NameRuleSuffix:
  160. if strings.HasSuffix(name, m.ModelName) {
  161. if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) {
  162. suffixMatch = m
  163. }
  164. }
  165. case NameRuleContains:
  166. if strings.Contains(name, m.ModelName) {
  167. if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) {
  168. containsMatch = m
  169. }
  170. }
  171. }
  172. }
  173. if prefixMatch != nil {
  174. return prefixMatch, nil
  175. }
  176. if suffixMatch != nil {
  177. return suffixMatch, nil
  178. }
  179. if containsMatch != nil {
  180. return containsMatch, nil
  181. }
  182. return nil, gorm.ErrRecordNotFound
  183. }
  184. // SearchModels 根据关键词和供应商搜索模型,支持分页
  185. func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
  186. var models []*Model
  187. db := DB.Model(&Model{})
  188. if keyword != "" {
  189. like := "%" + keyword + "%"
  190. db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
  191. }
  192. if vendor != "" {
  193. // 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配
  194. if vid, err := strconv.Atoi(vendor); err == nil {
  195. db = db.Where("models.vendor_id = ?", vid)
  196. } else {
  197. db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
  198. }
  199. }
  200. var total int64
  201. err := db.Count(&total).Error
  202. if err != nil {
  203. return nil, 0, err
  204. }
  205. err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error
  206. return models, total, err
  207. }