| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- package controller
- import (
- "encoding/json"
- "strconv"
- "strings"
- "one-api/common"
- "one-api/constant"
- "one-api/model"
- "github.com/gin-gonic/gin"
- )
- // GetAllModelsMeta 获取模型列表(分页)
- func GetAllModelsMeta(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // 填充附加字段
- for _, m := range modelsMeta {
- fillModelExtra(m)
- }
- var total int64
- model.DB.Model(&model.Model{}).Count(&total)
- // 统计供应商计数(全部数据,不受分页影响)
- vendorCounts, _ := model.GetVendorModelCounts()
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(modelsMeta)
- common.ApiSuccess(c, gin.H{
- "items": modelsMeta,
- "total": total,
- "page": pageInfo.GetPage(),
- "page_size": pageInfo.GetPageSize(),
- "vendor_counts": vendorCounts,
- })
- }
- // SearchModelsMeta 搜索模型列表
- func SearchModelsMeta(c *gin.Context) {
- keyword := c.Query("keyword")
- vendor := c.Query("vendor")
- pageInfo := common.GetPageQuery(c)
- modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- for _, m := range modelsMeta {
- fillModelExtra(m)
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(modelsMeta)
- common.ApiSuccess(c, pageInfo)
- }
- // GetModelMeta 根据 ID 获取单条模型信息
- func GetModelMeta(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- var m model.Model
- if err := model.DB.First(&m, id).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- fillModelExtra(&m)
- common.ApiSuccess(c, &m)
- }
- // CreateModelMeta 新建模型
- func CreateModelMeta(c *gin.Context) {
- var m model.Model
- if err := c.ShouldBindJSON(&m); err != nil {
- common.ApiError(c, err)
- return
- }
- if m.ModelName == "" {
- common.ApiErrorMsg(c, "模型名称不能为空")
- return
- }
- // 名称冲突检查
- if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "模型名称已存在")
- return
- }
- if err := m.Insert(); err != nil {
- common.ApiError(c, err)
- return
- }
- model.RefreshPricing()
- common.ApiSuccess(c, &m)
- }
- // UpdateModelMeta 更新模型
- func UpdateModelMeta(c *gin.Context) {
- statusOnly := c.Query("status_only") == "true"
- var m model.Model
- if err := c.ShouldBindJSON(&m); err != nil {
- common.ApiError(c, err)
- return
- }
- if m.Id == 0 {
- common.ApiErrorMsg(c, "缺少模型 ID")
- return
- }
- if statusOnly {
- // 只更新状态,防止误清空其他字段
- if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- } else {
- // 名称冲突检查
- if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "模型名称已存在")
- return
- }
- if err := m.Update(); err != nil {
- common.ApiError(c, err)
- return
- }
- }
- model.RefreshPricing()
- common.ApiSuccess(c, &m)
- }
- // DeleteModelMeta 删除模型
- func DeleteModelMeta(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- model.RefreshPricing()
- common.ApiSuccess(c, nil)
- }
- // 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
- func fillModelExtra(m *model.Model) {
- // 若为精确匹配,保持原有逻辑
- if m.NameRule == model.NameRuleExact {
- if m.Endpoints == "" {
- eps := model.GetModelSupportEndpointTypes(m.ModelName)
- if b, err := json.Marshal(eps); err == nil {
- m.Endpoints = string(b)
- }
- }
- if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
- m.BoundChannels = channels
- }
- m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
- m.QuotaType = model.GetModelQuotaType(m.ModelName)
- return
- }
- // 非精确匹配:计算并集
- pricings := model.GetPricing()
- // 匹配到的模型名称集合
- matchedNames := make([]string, 0)
- // 端点去重集合
- endpointSet := make(map[constant.EndpointType]struct{})
- // 已绑定渠道去重集合
- channelSet := make(map[string]model.BoundChannel)
- // 分组去重集合
- groupSet := make(map[string]struct{})
- // 计费类型(若有任意模型为 1,则返回 1)
- quotaTypeSet := make(map[int]struct{})
- for _, p := range pricings {
- var matched bool
- switch m.NameRule {
- case model.NameRulePrefix:
- matched = strings.HasPrefix(p.ModelName, m.ModelName)
- case model.NameRuleSuffix:
- matched = strings.HasSuffix(p.ModelName, m.ModelName)
- case model.NameRuleContains:
- matched = strings.Contains(p.ModelName, m.ModelName)
- }
- if !matched {
- continue
- }
- // 记录匹配到的模型名称
- matchedNames = append(matchedNames, p.ModelName)
- // 收集端点
- for _, et := range p.SupportedEndpointTypes {
- endpointSet[et] = struct{}{}
- }
- // 收集分组
- for _, g := range p.EnableGroup {
- groupSet[g] = struct{}{}
- }
- // 收集计费类型
- quotaTypeSet[p.QuotaType] = struct{}{}
- // 收集渠道
- if channels, err := model.GetBoundChannels(p.ModelName); err == nil {
- for _, ch := range channels {
- key := ch.Name + "_" + strconv.Itoa(ch.Type)
- channelSet[key] = ch
- }
- }
- }
- // 序列化端点
- if len(endpointSet) > 0 && m.Endpoints == "" {
- eps := make([]constant.EndpointType, 0, len(endpointSet))
- for et := range endpointSet {
- eps = append(eps, et)
- }
- if b, err := json.Marshal(eps); err == nil {
- m.Endpoints = string(b)
- }
- }
- // 序列化渠道
- if len(channelSet) > 0 {
- channels := make([]model.BoundChannel, 0, len(channelSet))
- for _, ch := range channelSet {
- channels = append(channels, ch)
- }
- m.BoundChannels = channels
- }
- // 序列化分组
- if len(groupSet) > 0 {
- groups := make([]string, 0, len(groupSet))
- for g := range groupSet {
- groups = append(groups, g)
- }
- m.EnableGroups = groups
- }
- // 确定计费类型:仅当所有匹配模型计费类型一致时才返回该类型,否则返回 -1 表示未知/不确定
- if len(quotaTypeSet) == 1 {
- for k := range quotaTypeSet {
- m.QuotaType = k
- }
- } else {
- m.QuotaType = -1
- }
- // 设置匹配信息
- m.MatchedModels = matchedNames
- m.MatchedCount = len(matchedNames)
- }
|