model_meta.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. package controller
  2. import (
  3. "encoding/json"
  4. "strconv"
  5. "strings"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/model"
  9. "github.com/gin-gonic/gin"
  10. )
  11. // GetAllModelsMeta 获取模型列表(分页)
  12. func GetAllModelsMeta(c *gin.Context) {
  13. pageInfo := common.GetPageQuery(c)
  14. modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  15. if err != nil {
  16. common.ApiError(c, err)
  17. return
  18. }
  19. // 填充附加字段
  20. for _, m := range modelsMeta {
  21. fillModelExtra(m)
  22. }
  23. var total int64
  24. model.DB.Model(&model.Model{}).Count(&total)
  25. // 统计供应商计数(全部数据,不受分页影响)
  26. vendorCounts, _ := model.GetVendorModelCounts()
  27. pageInfo.SetTotal(int(total))
  28. pageInfo.SetItems(modelsMeta)
  29. common.ApiSuccess(c, gin.H{
  30. "items": modelsMeta,
  31. "total": total,
  32. "page": pageInfo.GetPage(),
  33. "page_size": pageInfo.GetPageSize(),
  34. "vendor_counts": vendorCounts,
  35. })
  36. }
  37. // SearchModelsMeta 搜索模型列表
  38. func SearchModelsMeta(c *gin.Context) {
  39. keyword := c.Query("keyword")
  40. vendor := c.Query("vendor")
  41. pageInfo := common.GetPageQuery(c)
  42. modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  43. if err != nil {
  44. common.ApiError(c, err)
  45. return
  46. }
  47. for _, m := range modelsMeta {
  48. fillModelExtra(m)
  49. }
  50. pageInfo.SetTotal(int(total))
  51. pageInfo.SetItems(modelsMeta)
  52. common.ApiSuccess(c, pageInfo)
  53. }
  54. // GetModelMeta 根据 ID 获取单条模型信息
  55. func GetModelMeta(c *gin.Context) {
  56. idStr := c.Param("id")
  57. id, err := strconv.Atoi(idStr)
  58. if err != nil {
  59. common.ApiError(c, err)
  60. return
  61. }
  62. var m model.Model
  63. if err := model.DB.First(&m, id).Error; err != nil {
  64. common.ApiError(c, err)
  65. return
  66. }
  67. fillModelExtra(&m)
  68. common.ApiSuccess(c, &m)
  69. }
  70. // CreateModelMeta 新建模型
  71. func CreateModelMeta(c *gin.Context) {
  72. var m model.Model
  73. if err := c.ShouldBindJSON(&m); err != nil {
  74. common.ApiError(c, err)
  75. return
  76. }
  77. if m.ModelName == "" {
  78. common.ApiErrorMsg(c, "模型名称不能为空")
  79. return
  80. }
  81. // 名称冲突检查
  82. if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
  83. common.ApiError(c, err)
  84. return
  85. } else if dup {
  86. common.ApiErrorMsg(c, "模型名称已存在")
  87. return
  88. }
  89. if err := m.Insert(); err != nil {
  90. common.ApiError(c, err)
  91. return
  92. }
  93. model.RefreshPricing()
  94. common.ApiSuccess(c, &m)
  95. }
  96. // UpdateModelMeta 更新模型
  97. func UpdateModelMeta(c *gin.Context) {
  98. statusOnly := c.Query("status_only") == "true"
  99. var m model.Model
  100. if err := c.ShouldBindJSON(&m); err != nil {
  101. common.ApiError(c, err)
  102. return
  103. }
  104. if m.Id == 0 {
  105. common.ApiErrorMsg(c, "缺少模型 ID")
  106. return
  107. }
  108. if statusOnly {
  109. // 只更新状态,防止误清空其他字段
  110. if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
  111. common.ApiError(c, err)
  112. return
  113. }
  114. } else {
  115. // 名称冲突检查
  116. if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
  117. common.ApiError(c, err)
  118. return
  119. } else if dup {
  120. common.ApiErrorMsg(c, "模型名称已存在")
  121. return
  122. }
  123. if err := m.Update(); err != nil {
  124. common.ApiError(c, err)
  125. return
  126. }
  127. }
  128. model.RefreshPricing()
  129. common.ApiSuccess(c, &m)
  130. }
  131. // DeleteModelMeta 删除模型
  132. func DeleteModelMeta(c *gin.Context) {
  133. idStr := c.Param("id")
  134. id, err := strconv.Atoi(idStr)
  135. if err != nil {
  136. common.ApiError(c, err)
  137. return
  138. }
  139. if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
  140. common.ApiError(c, err)
  141. return
  142. }
  143. model.RefreshPricing()
  144. common.ApiSuccess(c, nil)
  145. }
  146. // 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
  147. func fillModelExtra(m *model.Model) {
  148. // 若为精确匹配,保持原有逻辑
  149. if m.NameRule == model.NameRuleExact {
  150. if m.Endpoints == "" {
  151. eps := model.GetModelSupportEndpointTypes(m.ModelName)
  152. if b, err := json.Marshal(eps); err == nil {
  153. m.Endpoints = string(b)
  154. }
  155. }
  156. if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
  157. m.BoundChannels = channels
  158. }
  159. m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
  160. m.QuotaType = model.GetModelQuotaType(m.ModelName)
  161. return
  162. }
  163. // 非精确匹配:计算并集
  164. pricings := model.GetPricing()
  165. // 匹配到的模型名称集合
  166. matchedNames := make([]string, 0)
  167. // 端点去重集合
  168. endpointSet := make(map[constant.EndpointType]struct{})
  169. // 已绑定渠道去重集合
  170. channelSet := make(map[string]model.BoundChannel)
  171. // 分组去重集合
  172. groupSet := make(map[string]struct{})
  173. // 计费类型(若有任意模型为 1,则返回 1)
  174. quotaTypeSet := make(map[int]struct{})
  175. for _, p := range pricings {
  176. var matched bool
  177. switch m.NameRule {
  178. case model.NameRulePrefix:
  179. matched = strings.HasPrefix(p.ModelName, m.ModelName)
  180. case model.NameRuleSuffix:
  181. matched = strings.HasSuffix(p.ModelName, m.ModelName)
  182. case model.NameRuleContains:
  183. matched = strings.Contains(p.ModelName, m.ModelName)
  184. }
  185. if !matched {
  186. continue
  187. }
  188. // 记录匹配到的模型名称
  189. matchedNames = append(matchedNames, p.ModelName)
  190. // 收集端点
  191. for _, et := range p.SupportedEndpointTypes {
  192. endpointSet[et] = struct{}{}
  193. }
  194. // 收集分组
  195. for _, g := range p.EnableGroup {
  196. groupSet[g] = struct{}{}
  197. }
  198. // 收集计费类型
  199. quotaTypeSet[p.QuotaType] = struct{}{}
  200. // 收集渠道
  201. if channels, err := model.GetBoundChannels(p.ModelName); err == nil {
  202. for _, ch := range channels {
  203. key := ch.Name + "_" + strconv.Itoa(ch.Type)
  204. channelSet[key] = ch
  205. }
  206. }
  207. }
  208. // 序列化端点
  209. if len(endpointSet) > 0 && m.Endpoints == "" {
  210. eps := make([]constant.EndpointType, 0, len(endpointSet))
  211. for et := range endpointSet {
  212. eps = append(eps, et)
  213. }
  214. if b, err := json.Marshal(eps); err == nil {
  215. m.Endpoints = string(b)
  216. }
  217. }
  218. // 序列化渠道
  219. if len(channelSet) > 0 {
  220. channels := make([]model.BoundChannel, 0, len(channelSet))
  221. for _, ch := range channelSet {
  222. channels = append(channels, ch)
  223. }
  224. m.BoundChannels = channels
  225. }
  226. // 序列化分组
  227. if len(groupSet) > 0 {
  228. groups := make([]string, 0, len(groupSet))
  229. for g := range groupSet {
  230. groups = append(groups, g)
  231. }
  232. m.EnableGroups = groups
  233. }
  234. // 确定计费类型:仅当所有匹配模型计费类型一致时才返回该类型,否则返回 -1 表示未知/不确定
  235. if len(quotaTypeSet) == 1 {
  236. for k := range quotaTypeSet {
  237. m.QuotaType = k
  238. }
  239. } else {
  240. m.QuotaType = -1
  241. }
  242. // 设置匹配信息
  243. m.MatchedModels = matchedNames
  244. m.MatchedCount = len(matchedNames)
  245. }