pricing.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. package model
  import (
  	"encoding/json"
  	"fmt"
  	"sort"
  	"strings"
  	"sync"
  	"time"

  	"github.com/QuantumNous/new-api/common"
  	"github.com/QuantumNous/new-api/constant"
  	"github.com/QuantumNous/new-api/setting/ratio_setting"
  	"github.com/QuantumNous/new-api/types"
  )
  13. type Pricing struct {
  14. ModelName string `json:"model_name"`
  15. Description string `json:"description,omitempty"`
  16. Icon string `json:"icon,omitempty"`
  17. Tags string `json:"tags,omitempty"`
  18. VendorID int `json:"vendor_id,omitempty"`
  19. QuotaType int `json:"quota_type"`
  20. ModelRatio float64 `json:"model_ratio"`
  21. ModelPrice float64 `json:"model_price"`
  22. OwnerBy string `json:"owner_by"`
  23. CompletionRatio float64 `json:"completion_ratio"`
  24. EnableGroup []string `json:"enable_groups"`
  25. SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
  26. PricingVersion string `json:"pricing_version,omitempty"`
  27. }
  // PricingVendor is the frontend-facing view of a vendor row, returned by
  // GetVendors alongside the pricing list.
  type PricingVendor struct {
  	ID   int    `json:"id"`
  	Name string `json:"name"`

  	// Optional presentation fields; omitted from JSON when empty.
  	Description string `json:"description,omitempty"`
  	Icon        string `json:"icon,omitempty"`
  }
  34. var (
  35. pricingMap []Pricing
  36. vendorsList []PricingVendor
  37. supportedEndpointMap map[string]common.EndpointInfo
  38. lastGetPricingTime time.Time
  39. updatePricingLock sync.Mutex
  40. // 缓存映射:模型名 -> 启用分组 / 计费类型
  41. modelEnableGroups = make(map[string][]string)
  42. modelQuotaTypeMap = make(map[string]int)
  43. modelEnableGroupsLock = sync.RWMutex{}
  44. )
  45. var (
  46. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  47. modelSupportEndpointsLock = sync.RWMutex{}
  48. )
  49. func GetPricing() []Pricing {
  50. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  51. updatePricingLock.Lock()
  52. defer updatePricingLock.Unlock()
  53. // Double check after acquiring the lock
  54. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  55. modelSupportEndpointsLock.Lock()
  56. defer modelSupportEndpointsLock.Unlock()
  57. updatePricing()
  58. }
  59. }
  60. return pricingMap
  61. }
  62. // GetVendors 返回当前定价接口使用到的供应商信息
  63. func GetVendors() []PricingVendor {
  64. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  65. // 保证先刷新一次
  66. GetPricing()
  67. }
  68. return vendorsList
  69. }
  70. func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
  71. if model == "" {
  72. return make([]constant.EndpointType, 0)
  73. }
  74. modelSupportEndpointsLock.RLock()
  75. defer modelSupportEndpointsLock.RUnlock()
  76. if endpoints, ok := modelSupportEndpointTypes[model]; ok {
  77. return endpoints
  78. }
  79. return make([]constant.EndpointType, 0)
  80. }
  81. func updatePricing() {
  82. //modelRatios := common.GetModelRatios()
  83. enableAbilities, err := GetAllEnableAbilityWithChannels()
  84. if err != nil {
  85. common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
  86. return
  87. }
  88. // 预加载模型元数据与供应商一次,避免循环查询
  89. var allMeta []Model
  90. _ = DB.Find(&allMeta).Error
  91. metaMap := make(map[string]*Model)
  92. prefixList := make([]*Model, 0)
  93. suffixList := make([]*Model, 0)
  94. containsList := make([]*Model, 0)
  95. for i := range allMeta {
  96. m := &allMeta[i]
  97. if m.NameRule == NameRuleExact {
  98. metaMap[m.ModelName] = m
  99. } else {
  100. switch m.NameRule {
  101. case NameRulePrefix:
  102. prefixList = append(prefixList, m)
  103. case NameRuleSuffix:
  104. suffixList = append(suffixList, m)
  105. case NameRuleContains:
  106. containsList = append(containsList, m)
  107. }
  108. }
  109. }
  110. // 将非精确规则模型匹配到 metaMap
  111. for _, m := range prefixList {
  112. for _, pricingModel := range enableAbilities {
  113. if strings.HasPrefix(pricingModel.Model, m.ModelName) {
  114. if _, exists := metaMap[pricingModel.Model]; !exists {
  115. metaMap[pricingModel.Model] = m
  116. }
  117. }
  118. }
  119. }
  120. for _, m := range suffixList {
  121. for _, pricingModel := range enableAbilities {
  122. if strings.HasSuffix(pricingModel.Model, m.ModelName) {
  123. if _, exists := metaMap[pricingModel.Model]; !exists {
  124. metaMap[pricingModel.Model] = m
  125. }
  126. }
  127. }
  128. }
  129. for _, m := range containsList {
  130. for _, pricingModel := range enableAbilities {
  131. if strings.Contains(pricingModel.Model, m.ModelName) {
  132. if _, exists := metaMap[pricingModel.Model]; !exists {
  133. metaMap[pricingModel.Model] = m
  134. }
  135. }
  136. }
  137. }
  138. // 预加载供应商
  139. var vendors []Vendor
  140. _ = DB.Find(&vendors).Error
  141. vendorMap := make(map[int]*Vendor)
  142. for i := range vendors {
  143. vendorMap[vendors[i].Id] = &vendors[i]
  144. }
  145. // 初始化默认供应商映射
  146. initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
  147. // 构建对前端友好的供应商列表
  148. vendorsList = make([]PricingVendor, 0, len(vendorMap))
  149. for _, v := range vendorMap {
  150. vendorsList = append(vendorsList, PricingVendor{
  151. ID: v.Id,
  152. Name: v.Name,
  153. Description: v.Description,
  154. Icon: v.Icon,
  155. })
  156. }
  157. modelGroupsMap := make(map[string]*types.Set[string])
  158. for _, ability := range enableAbilities {
  159. groups, ok := modelGroupsMap[ability.Model]
  160. if !ok {
  161. groups = types.NewSet[string]()
  162. modelGroupsMap[ability.Model] = groups
  163. }
  164. groups.Add(ability.Group)
  165. }
  166. //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
  167. modelSupportEndpointsStr := make(map[string][]string)
  168. // 先根据已有能力填充原生端点
  169. for _, ability := range enableAbilities {
  170. endpoints := modelSupportEndpointsStr[ability.Model]
  171. channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
  172. for _, channelType := range channelTypes {
  173. if !common.StringsContains(endpoints, string(channelType)) {
  174. endpoints = append(endpoints, string(channelType))
  175. }
  176. }
  177. modelSupportEndpointsStr[ability.Model] = endpoints
  178. }
  179. // 再补充模型自定义端点:若配置有效则替换默认端点,不做合并
  180. for modelName, meta := range metaMap {
  181. if strings.TrimSpace(meta.Endpoints) == "" {
  182. continue
  183. }
  184. var raw map[string]interface{}
  185. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  186. endpoints := make([]string, 0, len(raw))
  187. for k, v := range raw {
  188. switch v.(type) {
  189. case string, map[string]interface{}:
  190. if !common.StringsContains(endpoints, k) {
  191. endpoints = append(endpoints, k)
  192. }
  193. }
  194. }
  195. if len(endpoints) > 0 {
  196. modelSupportEndpointsStr[modelName] = endpoints
  197. }
  198. }
  199. }
  200. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  201. for model, endpoints := range modelSupportEndpointsStr {
  202. supportedEndpoints := make([]constant.EndpointType, 0)
  203. for _, endpointStr := range endpoints {
  204. endpointType := constant.EndpointType(endpointStr)
  205. supportedEndpoints = append(supportedEndpoints, endpointType)
  206. }
  207. modelSupportEndpointTypes[model] = supportedEndpoints
  208. }
  209. // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
  210. supportedEndpointMap = make(map[string]common.EndpointInfo)
  211. // 1. 默认端点
  212. for _, endpoints := range modelSupportEndpointTypes {
  213. for _, et := range endpoints {
  214. if info, ok := common.GetDefaultEndpointInfo(et); ok {
  215. if _, exists := supportedEndpointMap[string(et)]; !exists {
  216. supportedEndpointMap[string(et)] = info
  217. }
  218. }
  219. }
  220. }
  221. // 2. 自定义端点(models 表)覆盖默认
  222. for _, meta := range metaMap {
  223. if strings.TrimSpace(meta.Endpoints) == "" {
  224. continue
  225. }
  226. var raw map[string]interface{}
  227. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  228. for k, v := range raw {
  229. switch val := v.(type) {
  230. case string:
  231. supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
  232. case map[string]interface{}:
  233. ep := common.EndpointInfo{Method: "POST"}
  234. if p, ok := val["path"].(string); ok {
  235. ep.Path = p
  236. }
  237. if m, ok := val["method"].(string); ok {
  238. ep.Method = strings.ToUpper(m)
  239. }
  240. supportedEndpointMap[k] = ep
  241. default:
  242. // ignore unsupported types
  243. }
  244. }
  245. }
  246. }
  247. pricingMap = make([]Pricing, 0)
  248. for model, groups := range modelGroupsMap {
  249. pricing := Pricing{
  250. ModelName: model,
  251. EnableGroup: groups.Items(),
  252. SupportedEndpointTypes: modelSupportEndpointTypes[model],
  253. }
  254. // 补充模型元数据(描述、标签、供应商、状态)
  255. if meta, ok := metaMap[model]; ok {
  256. // 若模型被禁用(status!=1),则直接跳过,不返回给前端
  257. if meta.Status != 1 {
  258. continue
  259. }
  260. pricing.Description = meta.Description
  261. pricing.Icon = meta.Icon
  262. pricing.Tags = meta.Tags
  263. pricing.VendorID = meta.VendorID
  264. }
  265. modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
  266. if findPrice {
  267. pricing.ModelPrice = modelPrice
  268. pricing.QuotaType = 1
  269. } else {
  270. modelRatio, _, _ := ratio_setting.GetModelRatio(model)
  271. pricing.ModelRatio = modelRatio
  272. pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
  273. pricing.QuotaType = 0
  274. }
  275. pricingMap = append(pricingMap, pricing)
  276. }
  277. // 防止大更新后数据不通用
  278. if len(pricingMap) > 0 {
  279. pricingMap[0].PricingVersion = "82c4a357505fff6fee8462c3f7ec8a645bb95532669cb73b2cabee6a416ec24f"
  280. }
  281. // 刷新缓存映射,供高并发快速查询
  282. modelEnableGroupsLock.Lock()
  283. modelEnableGroups = make(map[string][]string)
  284. modelQuotaTypeMap = make(map[string]int)
  285. for _, p := range pricingMap {
  286. modelEnableGroups[p.ModelName] = p.EnableGroup
  287. modelQuotaTypeMap[p.ModelName] = p.QuotaType
  288. }
  289. modelEnableGroupsLock.Unlock()
  290. lastGetPricingTime = time.Now()
  291. }
  292. // GetSupportedEndpointMap 返回全局端点到路径的映射
  293. func GetSupportedEndpointMap() map[string]common.EndpointInfo {
  294. return supportedEndpointMap
  295. }