pricing.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. package model
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. "sync"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/setting/billing_setting"
  11. "github.com/QuantumNous/new-api/setting/ratio_setting"
  12. "github.com/QuantumNous/new-api/types"
  13. )
  14. type Pricing struct {
  15. ModelName string `json:"model_name"`
  16. Description string `json:"description,omitempty"`
  17. Icon string `json:"icon,omitempty"`
  18. Tags string `json:"tags,omitempty"`
  19. VendorID int `json:"vendor_id,omitempty"`
  20. QuotaType int `json:"quota_type"`
  21. ModelRatio float64 `json:"model_ratio"`
  22. ModelPrice float64 `json:"model_price"`
  23. OwnerBy string `json:"owner_by"`
  24. CompletionRatio float64 `json:"completion_ratio"`
  25. CacheRatio *float64 `json:"cache_ratio,omitempty"`
  26. CreateCacheRatio *float64 `json:"create_cache_ratio,omitempty"`
  27. ImageRatio *float64 `json:"image_ratio,omitempty"`
  28. AudioRatio *float64 `json:"audio_ratio,omitempty"`
  29. AudioCompletionRatio *float64 `json:"audio_completion_ratio,omitempty"`
  30. EnableGroup []string `json:"enable_groups"`
  31. SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
  32. BillingMode string `json:"billing_mode,omitempty"`
  33. BillingExpr string `json:"billing_expr,omitempty"`
  34. PricingVersion string `json:"pricing_version,omitempty"`
  35. }
  // PricingVendor is the vendor summary exposed alongside the pricing list.
  // Description and Icon are left out of the JSON output when empty.
  type PricingVendor struct {
  	ID          int    `json:"id"`
  	Name        string `json:"name"`
  	Description string `json:"description,omitempty"`
  	Icon        string `json:"icon,omitempty"`
  }
  42. var (
  43. pricingMap []Pricing
  44. vendorsList []PricingVendor
  45. supportedEndpointMap map[string]common.EndpointInfo
  46. lastGetPricingTime time.Time
  47. updatePricingLock sync.Mutex
  48. // 缓存映射:模型名 -> 启用分组 / 计费类型
  49. modelEnableGroups = make(map[string][]string)
  50. modelQuotaTypeMap = make(map[string]int)
  51. modelEnableGroupsLock = sync.RWMutex{}
  52. )
  53. var (
  54. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  55. modelSupportEndpointsLock = sync.RWMutex{}
  56. )
  57. func GetPricing() []Pricing {
  58. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  59. updatePricingLock.Lock()
  60. defer updatePricingLock.Unlock()
  61. // Double check after acquiring the lock
  62. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  63. modelSupportEndpointsLock.Lock()
  64. defer modelSupportEndpointsLock.Unlock()
  65. updatePricing()
  66. }
  67. }
  68. return pricingMap
  69. }
  70. func InvalidatePricingCache() {
  71. updatePricingLock.Lock()
  72. defer updatePricingLock.Unlock()
  73. pricingMap = nil
  74. vendorsList = nil
  75. lastGetPricingTime = time.Time{}
  76. }
  77. func HasModelBillingConfig(modelName string) bool {
  78. if _, ok := ratio_setting.GetModelPrice(modelName, false); ok {
  79. return true
  80. }
  81. if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok {
  82. return true
  83. }
  84. if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr {
  85. return false
  86. }
  87. expr, ok := billing_setting.GetBillingExpr(modelName)
  88. return ok && strings.TrimSpace(expr) != ""
  89. }
  90. // GetVendors 返回当前定价接口使用到的供应商信息
  91. func GetVendors() []PricingVendor {
  92. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  93. // 保证先刷新一次
  94. GetPricing()
  95. }
  96. return vendorsList
  97. }
  98. func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
  99. if model == "" {
  100. return make([]constant.EndpointType, 0)
  101. }
  102. modelSupportEndpointsLock.RLock()
  103. defer modelSupportEndpointsLock.RUnlock()
  104. if endpoints, ok := modelSupportEndpointTypes[model]; ok {
  105. return endpoints
  106. }
  107. return make([]constant.EndpointType, 0)
  108. }
  109. func updatePricing() {
  110. //modelRatios := common.GetModelRatios()
  111. enableAbilities, err := GetAllEnableAbilityWithChannels()
  112. if err != nil {
  113. common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
  114. return
  115. }
  116. // 预加载模型元数据与供应商一次,避免循环查询
  117. var allMeta []Model
  118. _ = DB.Find(&allMeta).Error
  119. metaMap := make(map[string]*Model)
  120. prefixList := make([]*Model, 0)
  121. suffixList := make([]*Model, 0)
  122. containsList := make([]*Model, 0)
  123. for i := range allMeta {
  124. m := &allMeta[i]
  125. if m.NameRule == NameRuleExact {
  126. metaMap[m.ModelName] = m
  127. } else {
  128. switch m.NameRule {
  129. case NameRulePrefix:
  130. prefixList = append(prefixList, m)
  131. case NameRuleSuffix:
  132. suffixList = append(suffixList, m)
  133. case NameRuleContains:
  134. containsList = append(containsList, m)
  135. }
  136. }
  137. }
  138. // 将非精确规则模型匹配到 metaMap
  139. for _, m := range prefixList {
  140. for _, pricingModel := range enableAbilities {
  141. if strings.HasPrefix(pricingModel.Model, m.ModelName) {
  142. if _, exists := metaMap[pricingModel.Model]; !exists {
  143. metaMap[pricingModel.Model] = m
  144. }
  145. }
  146. }
  147. }
  148. for _, m := range suffixList {
  149. for _, pricingModel := range enableAbilities {
  150. if strings.HasSuffix(pricingModel.Model, m.ModelName) {
  151. if _, exists := metaMap[pricingModel.Model]; !exists {
  152. metaMap[pricingModel.Model] = m
  153. }
  154. }
  155. }
  156. }
  157. for _, m := range containsList {
  158. for _, pricingModel := range enableAbilities {
  159. if strings.Contains(pricingModel.Model, m.ModelName) {
  160. if _, exists := metaMap[pricingModel.Model]; !exists {
  161. metaMap[pricingModel.Model] = m
  162. }
  163. }
  164. }
  165. }
  166. // 预加载供应商
  167. var vendors []Vendor
  168. _ = DB.Find(&vendors).Error
  169. vendorMap := make(map[int]*Vendor)
  170. for i := range vendors {
  171. vendorMap[vendors[i].Id] = &vendors[i]
  172. }
  173. // 初始化默认供应商映射
  174. initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
  175. // 构建对前端友好的供应商列表
  176. vendorsList = make([]PricingVendor, 0, len(vendorMap))
  177. for _, v := range vendorMap {
  178. vendorsList = append(vendorsList, PricingVendor{
  179. ID: v.Id,
  180. Name: v.Name,
  181. Description: v.Description,
  182. Icon: v.Icon,
  183. })
  184. }
  185. modelGroupsMap := make(map[string]*types.Set[string])
  186. for _, ability := range enableAbilities {
  187. groups, ok := modelGroupsMap[ability.Model]
  188. if !ok {
  189. groups = types.NewSet[string]()
  190. modelGroupsMap[ability.Model] = groups
  191. }
  192. groups.Add(ability.Group)
  193. }
  194. //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
  195. modelSupportEndpointsStr := make(map[string][]string)
  196. // 先根据已有能力填充原生端点
  197. for _, ability := range enableAbilities {
  198. endpoints := modelSupportEndpointsStr[ability.Model]
  199. channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
  200. for _, channelType := range channelTypes {
  201. if !common.StringsContains(endpoints, string(channelType)) {
  202. endpoints = append(endpoints, string(channelType))
  203. }
  204. }
  205. modelSupportEndpointsStr[ability.Model] = endpoints
  206. }
  207. // 再补充模型自定义端点:若配置有效则替换默认端点,不做合并
  208. for modelName, meta := range metaMap {
  209. if strings.TrimSpace(meta.Endpoints) == "" {
  210. continue
  211. }
  212. var raw map[string]interface{}
  213. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  214. endpoints := make([]string, 0, len(raw))
  215. for k, v := range raw {
  216. switch v.(type) {
  217. case string, map[string]interface{}:
  218. if !common.StringsContains(endpoints, k) {
  219. endpoints = append(endpoints, k)
  220. }
  221. }
  222. }
  223. if len(endpoints) > 0 {
  224. modelSupportEndpointsStr[modelName] = endpoints
  225. }
  226. }
  227. }
  228. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  229. for model, endpoints := range modelSupportEndpointsStr {
  230. supportedEndpoints := make([]constant.EndpointType, 0)
  231. for _, endpointStr := range endpoints {
  232. endpointType := constant.EndpointType(endpointStr)
  233. supportedEndpoints = append(supportedEndpoints, endpointType)
  234. }
  235. modelSupportEndpointTypes[model] = supportedEndpoints
  236. }
  237. // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
  238. supportedEndpointMap = make(map[string]common.EndpointInfo)
  239. // 1. 默认端点
  240. for _, endpoints := range modelSupportEndpointTypes {
  241. for _, et := range endpoints {
  242. if info, ok := common.GetDefaultEndpointInfo(et); ok {
  243. if _, exists := supportedEndpointMap[string(et)]; !exists {
  244. supportedEndpointMap[string(et)] = info
  245. }
  246. }
  247. }
  248. }
  249. // 2. 自定义端点(models 表)覆盖默认
  250. for _, meta := range metaMap {
  251. if strings.TrimSpace(meta.Endpoints) == "" {
  252. continue
  253. }
  254. var raw map[string]interface{}
  255. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  256. for k, v := range raw {
  257. switch val := v.(type) {
  258. case string:
  259. supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
  260. case map[string]interface{}:
  261. ep := common.EndpointInfo{Method: "POST"}
  262. if p, ok := val["path"].(string); ok {
  263. ep.Path = p
  264. }
  265. if m, ok := val["method"].(string); ok {
  266. ep.Method = strings.ToUpper(m)
  267. }
  268. supportedEndpointMap[k] = ep
  269. default:
  270. // ignore unsupported types
  271. }
  272. }
  273. }
  274. }
  275. pricingMap = make([]Pricing, 0)
  276. for model, groups := range modelGroupsMap {
  277. pricing := Pricing{
  278. ModelName: model,
  279. EnableGroup: groups.Items(),
  280. SupportedEndpointTypes: modelSupportEndpointTypes[model],
  281. }
  282. // 补充模型元数据(描述、标签、供应商、状态)
  283. if meta, ok := metaMap[model]; ok {
  284. // 若模型被禁用(status!=1),则直接跳过,不返回给前端
  285. if meta.Status != 1 {
  286. continue
  287. }
  288. pricing.Description = meta.Description
  289. pricing.Icon = meta.Icon
  290. pricing.Tags = meta.Tags
  291. pricing.VendorID = meta.VendorID
  292. }
  293. modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
  294. if findPrice {
  295. pricing.ModelPrice = modelPrice
  296. pricing.QuotaType = 1
  297. } else {
  298. modelRatio, _, _ := ratio_setting.GetModelRatio(model)
  299. pricing.ModelRatio = modelRatio
  300. pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
  301. pricing.QuotaType = 0
  302. }
  303. if cacheRatio, ok := ratio_setting.GetCacheRatio(model); ok {
  304. pricing.CacheRatio = &cacheRatio
  305. }
  306. if createCacheRatio, ok := ratio_setting.GetCreateCacheRatio(model); ok {
  307. pricing.CreateCacheRatio = &createCacheRatio
  308. }
  309. if imageRatio, ok := ratio_setting.GetImageRatio(model); ok {
  310. pricing.ImageRatio = &imageRatio
  311. }
  312. if ratio_setting.ContainsAudioRatio(model) {
  313. audioRatio := ratio_setting.GetAudioRatio(model)
  314. pricing.AudioRatio = &audioRatio
  315. }
  316. if ratio_setting.ContainsAudioCompletionRatio(model) {
  317. audioCompletionRatio := ratio_setting.GetAudioCompletionRatio(model)
  318. pricing.AudioCompletionRatio = &audioCompletionRatio
  319. }
  320. if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" {
  321. if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" {
  322. pricing.BillingMode = billingMode
  323. pricing.BillingExpr = expr
  324. }
  325. }
  326. pricingMap = append(pricingMap, pricing)
  327. }
  328. // 防止大更新后数据不通用
  329. if len(pricingMap) > 0 {
  330. pricingMap[0].PricingVersion = "5a90f2b86c08bd983a9a2e6d66c255f4eaef9c4bc934386d2b6ae84ef0ff1f1f"
  331. }
  332. // 刷新缓存映射,供高并发快速查询
  333. modelEnableGroupsLock.Lock()
  334. modelEnableGroups = make(map[string][]string)
  335. modelQuotaTypeMap = make(map[string]int)
  336. for _, p := range pricingMap {
  337. modelEnableGroups[p.ModelName] = p.EnableGroup
  338. modelQuotaTypeMap[p.ModelName] = p.QuotaType
  339. }
  340. modelEnableGroupsLock.Unlock()
  341. lastGetPricingTime = time.Now()
  342. }
  343. // GetSupportedEndpointMap 返回全局端点到路径的映射
  344. func GetSupportedEndpointMap() map[string]common.EndpointInfo {
  345. return supportedEndpointMap
  346. }