pricing.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package model
  2. import (
  3. "fmt"
  4. "strings"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/setting/ratio_setting"
  8. "one-api/types"
  9. "sync"
  10. "time"
  11. )
  12. type Pricing struct {
  13. ModelName string `json:"model_name"`
  14. Description string `json:"description,omitempty"`
  15. Tags string `json:"tags,omitempty"`
  16. VendorID int `json:"vendor_id,omitempty"`
  17. QuotaType int `json:"quota_type"`
  18. ModelRatio float64 `json:"model_ratio"`
  19. ModelPrice float64 `json:"model_price"`
  20. OwnerBy string `json:"owner_by"`
  21. CompletionRatio float64 `json:"completion_ratio"`
  22. EnableGroup []string `json:"enable_groups"`
  23. SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
  24. }
  25. type PricingVendor struct {
  26. ID int `json:"id"`
  27. Name string `json:"name"`
  28. Description string `json:"description,omitempty"`
  29. Icon string `json:"icon,omitempty"`
  30. }
  31. var (
  32. pricingMap []Pricing
  33. vendorsList []PricingVendor
  34. lastGetPricingTime time.Time
  35. updatePricingLock sync.Mutex
  36. // 缓存映射:模型名 -> 启用分组 / 计费类型
  37. modelEnableGroups = make(map[string][]string)
  38. modelQuotaTypeMap = make(map[string]int)
  39. modelEnableGroupsLock = sync.RWMutex{}
  40. )
  41. var (
  42. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  43. modelSupportEndpointsLock = sync.RWMutex{}
  44. )
  45. func GetPricing() []Pricing {
  46. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  47. updatePricingLock.Lock()
  48. defer updatePricingLock.Unlock()
  49. // Double check after acquiring the lock
  50. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  51. modelSupportEndpointsLock.Lock()
  52. defer modelSupportEndpointsLock.Unlock()
  53. updatePricing()
  54. }
  55. }
  56. return pricingMap
  57. }
  58. // GetVendors 返回当前定价接口使用到的供应商信息
  59. func GetVendors() []PricingVendor {
  60. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  61. // 保证先刷新一次
  62. GetPricing()
  63. }
  64. return vendorsList
  65. }
  66. func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
  67. if model == "" {
  68. return make([]constant.EndpointType, 0)
  69. }
  70. modelSupportEndpointsLock.RLock()
  71. defer modelSupportEndpointsLock.RUnlock()
  72. if endpoints, ok := modelSupportEndpointTypes[model]; ok {
  73. return endpoints
  74. }
  75. return make([]constant.EndpointType, 0)
  76. }
  77. func updatePricing() {
  78. //modelRatios := common.GetModelRatios()
  79. enableAbilities, err := GetAllEnableAbilityWithChannels()
  80. if err != nil {
  81. common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
  82. return
  83. }
  84. // 预加载模型元数据与供应商一次,避免循环查询
  85. var allMeta []Model
  86. _ = DB.Find(&allMeta).Error
  87. metaMap := make(map[string]*Model)
  88. prefixList := make([]*Model, 0)
  89. suffixList := make([]*Model, 0)
  90. containsList := make([]*Model, 0)
  91. for i := range allMeta {
  92. m := &allMeta[i]
  93. if m.NameRule == NameRuleExact {
  94. metaMap[m.ModelName] = m
  95. } else {
  96. switch m.NameRule {
  97. case NameRulePrefix:
  98. prefixList = append(prefixList, m)
  99. case NameRuleSuffix:
  100. suffixList = append(suffixList, m)
  101. case NameRuleContains:
  102. containsList = append(containsList, m)
  103. }
  104. }
  105. }
  106. // 将非精确规则模型匹配到 metaMap
  107. for _, m := range prefixList {
  108. for _, pricingModel := range enableAbilities {
  109. if strings.HasPrefix(pricingModel.Model, m.ModelName) {
  110. if _, exists := metaMap[pricingModel.Model]; !exists {
  111. metaMap[pricingModel.Model] = m
  112. }
  113. }
  114. }
  115. }
  116. for _, m := range suffixList {
  117. for _, pricingModel := range enableAbilities {
  118. if strings.HasSuffix(pricingModel.Model, m.ModelName) {
  119. if _, exists := metaMap[pricingModel.Model]; !exists {
  120. metaMap[pricingModel.Model] = m
  121. }
  122. }
  123. }
  124. }
  125. for _, m := range containsList {
  126. for _, pricingModel := range enableAbilities {
  127. if strings.Contains(pricingModel.Model, m.ModelName) {
  128. if _, exists := metaMap[pricingModel.Model]; !exists {
  129. metaMap[pricingModel.Model] = m
  130. }
  131. }
  132. }
  133. }
  134. // 预加载供应商
  135. var vendors []Vendor
  136. _ = DB.Find(&vendors).Error
  137. vendorMap := make(map[int]*Vendor)
  138. for i := range vendors {
  139. vendorMap[vendors[i].Id] = &vendors[i]
  140. }
  141. // 构建对前端友好的供应商列表
  142. vendorsList = make([]PricingVendor, 0, len(vendors))
  143. for _, v := range vendors {
  144. vendorsList = append(vendorsList, PricingVendor{
  145. ID: v.Id,
  146. Name: v.Name,
  147. Description: v.Description,
  148. Icon: v.Icon,
  149. })
  150. }
  151. modelGroupsMap := make(map[string]*types.Set[string])
  152. for _, ability := range enableAbilities {
  153. groups, ok := modelGroupsMap[ability.Model]
  154. if !ok {
  155. groups = types.NewSet[string]()
  156. modelGroupsMap[ability.Model] = groups
  157. }
  158. groups.Add(ability.Group)
  159. }
  160. //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
  161. modelSupportEndpointsStr := make(map[string][]string)
  162. for _, ability := range enableAbilities {
  163. endpoints, ok := modelSupportEndpointsStr[ability.Model]
  164. if !ok {
  165. endpoints = make([]string, 0)
  166. modelSupportEndpointsStr[ability.Model] = endpoints
  167. }
  168. channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
  169. for _, channelType := range channelTypes {
  170. if !common.StringsContains(endpoints, string(channelType)) {
  171. endpoints = append(endpoints, string(channelType))
  172. }
  173. }
  174. modelSupportEndpointsStr[ability.Model] = endpoints
  175. }
  176. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  177. for model, endpoints := range modelSupportEndpointsStr {
  178. supportedEndpoints := make([]constant.EndpointType, 0)
  179. for _, endpointStr := range endpoints {
  180. endpointType := constant.EndpointType(endpointStr)
  181. supportedEndpoints = append(supportedEndpoints, endpointType)
  182. }
  183. modelSupportEndpointTypes[model] = supportedEndpoints
  184. }
  185. pricingMap = make([]Pricing, 0)
  186. for model, groups := range modelGroupsMap {
  187. pricing := Pricing{
  188. ModelName: model,
  189. EnableGroup: groups.Items(),
  190. SupportedEndpointTypes: modelSupportEndpointTypes[model],
  191. }
  192. // 补充模型元数据(描述、标签、供应商、状态)
  193. if meta, ok := metaMap[model]; ok {
  194. // 若模型被禁用(status!=1),则直接跳过,不返回给前端
  195. if meta.Status != 1 {
  196. continue
  197. }
  198. pricing.Description = meta.Description
  199. pricing.Tags = meta.Tags
  200. pricing.VendorID = meta.VendorID
  201. }
  202. modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
  203. if findPrice {
  204. pricing.ModelPrice = modelPrice
  205. pricing.QuotaType = 1
  206. } else {
  207. modelRatio, _, _ := ratio_setting.GetModelRatio(model)
  208. pricing.ModelRatio = modelRatio
  209. pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
  210. pricing.QuotaType = 0
  211. }
  212. pricingMap = append(pricingMap, pricing)
  213. }
  214. // 刷新缓存映射,供高并发快速查询
  215. modelEnableGroupsLock.Lock()
  216. modelEnableGroups = make(map[string][]string)
  217. modelQuotaTypeMap = make(map[string]int)
  218. for _, p := range pricingMap {
  219. modelEnableGroups[p.ModelName] = p.EnableGroup
  220. modelQuotaTypeMap[p.ModelName] = p.QuotaType
  221. }
  222. modelEnableGroupsLock.Unlock()
  223. lastGetPricingTime = time.Now()
  224. }