pricing.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package model
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/setting/ratio_setting"
  9. "one-api/types"
  10. "sync"
  11. "time"
  12. )
  13. type Pricing struct {
  14. ModelName string `json:"model_name"`
  15. Description string `json:"description,omitempty"`
  16. Tags string `json:"tags,omitempty"`
  17. VendorID int `json:"vendor_id,omitempty"`
  18. QuotaType int `json:"quota_type"`
  19. ModelRatio float64 `json:"model_ratio"`
  20. ModelPrice float64 `json:"model_price"`
  21. OwnerBy string `json:"owner_by"`
  22. CompletionRatio float64 `json:"completion_ratio"`
  23. EnableGroup []string `json:"enable_groups"`
  24. SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
  25. }
  26. type PricingVendor struct {
  27. ID int `json:"id"`
  28. Name string `json:"name"`
  29. Description string `json:"description,omitempty"`
  30. Icon string `json:"icon,omitempty"`
  31. }
  32. var (
  33. pricingMap []Pricing
  34. vendorsList []PricingVendor
  35. supportedEndpointMap map[string]common.EndpointInfo
  36. lastGetPricingTime time.Time
  37. updatePricingLock sync.Mutex
  38. // 缓存映射:模型名 -> 启用分组 / 计费类型
  39. modelEnableGroups = make(map[string][]string)
  40. modelQuotaTypeMap = make(map[string]int)
  41. modelEnableGroupsLock = sync.RWMutex{}
  42. )
  43. var (
  44. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  45. modelSupportEndpointsLock = sync.RWMutex{}
  46. )
  47. func GetPricing() []Pricing {
  48. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  49. updatePricingLock.Lock()
  50. defer updatePricingLock.Unlock()
  51. // Double check after acquiring the lock
  52. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  53. modelSupportEndpointsLock.Lock()
  54. defer modelSupportEndpointsLock.Unlock()
  55. updatePricing()
  56. }
  57. }
  58. return pricingMap
  59. }
  60. // GetVendors 返回当前定价接口使用到的供应商信息
  61. func GetVendors() []PricingVendor {
  62. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  63. // 保证先刷新一次
  64. GetPricing()
  65. }
  66. return vendorsList
  67. }
  68. func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
  69. if model == "" {
  70. return make([]constant.EndpointType, 0)
  71. }
  72. modelSupportEndpointsLock.RLock()
  73. defer modelSupportEndpointsLock.RUnlock()
  74. if endpoints, ok := modelSupportEndpointTypes[model]; ok {
  75. return endpoints
  76. }
  77. return make([]constant.EndpointType, 0)
  78. }
  79. func updatePricing() {
  80. //modelRatios := common.GetModelRatios()
  81. enableAbilities, err := GetAllEnableAbilityWithChannels()
  82. if err != nil {
  83. common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
  84. return
  85. }
  86. // 预加载模型元数据与供应商一次,避免循环查询
  87. var allMeta []Model
  88. _ = DB.Find(&allMeta).Error
  89. metaMap := make(map[string]*Model)
  90. prefixList := make([]*Model, 0)
  91. suffixList := make([]*Model, 0)
  92. containsList := make([]*Model, 0)
  93. for i := range allMeta {
  94. m := &allMeta[i]
  95. if m.NameRule == NameRuleExact {
  96. metaMap[m.ModelName] = m
  97. } else {
  98. switch m.NameRule {
  99. case NameRulePrefix:
  100. prefixList = append(prefixList, m)
  101. case NameRuleSuffix:
  102. suffixList = append(suffixList, m)
  103. case NameRuleContains:
  104. containsList = append(containsList, m)
  105. }
  106. }
  107. }
  108. // 将非精确规则模型匹配到 metaMap
  109. for _, m := range prefixList {
  110. for _, pricingModel := range enableAbilities {
  111. if strings.HasPrefix(pricingModel.Model, m.ModelName) {
  112. if _, exists := metaMap[pricingModel.Model]; !exists {
  113. metaMap[pricingModel.Model] = m
  114. }
  115. }
  116. }
  117. }
  118. for _, m := range suffixList {
  119. for _, pricingModel := range enableAbilities {
  120. if strings.HasSuffix(pricingModel.Model, m.ModelName) {
  121. if _, exists := metaMap[pricingModel.Model]; !exists {
  122. metaMap[pricingModel.Model] = m
  123. }
  124. }
  125. }
  126. }
  127. for _, m := range containsList {
  128. for _, pricingModel := range enableAbilities {
  129. if strings.Contains(pricingModel.Model, m.ModelName) {
  130. if _, exists := metaMap[pricingModel.Model]; !exists {
  131. metaMap[pricingModel.Model] = m
  132. }
  133. }
  134. }
  135. }
  136. // 预加载供应商
  137. var vendors []Vendor
  138. _ = DB.Find(&vendors).Error
  139. vendorMap := make(map[int]*Vendor)
  140. for i := range vendors {
  141. vendorMap[vendors[i].Id] = &vendors[i]
  142. }
  143. // 构建对前端友好的供应商列表
  144. vendorsList = make([]PricingVendor, 0, len(vendors))
  145. for _, v := range vendors {
  146. vendorsList = append(vendorsList, PricingVendor{
  147. ID: v.Id,
  148. Name: v.Name,
  149. Description: v.Description,
  150. Icon: v.Icon,
  151. })
  152. }
  153. modelGroupsMap := make(map[string]*types.Set[string])
  154. for _, ability := range enableAbilities {
  155. groups, ok := modelGroupsMap[ability.Model]
  156. if !ok {
  157. groups = types.NewSet[string]()
  158. modelGroupsMap[ability.Model] = groups
  159. }
  160. groups.Add(ability.Group)
  161. }
  162. //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
  163. modelSupportEndpointsStr := make(map[string][]string)
  164. // 先根据已有能力填充原生端点
  165. for _, ability := range enableAbilities {
  166. endpoints := modelSupportEndpointsStr[ability.Model]
  167. channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
  168. for _, channelType := range channelTypes {
  169. if !common.StringsContains(endpoints, string(channelType)) {
  170. endpoints = append(endpoints, string(channelType))
  171. }
  172. }
  173. modelSupportEndpointsStr[ability.Model] = endpoints
  174. }
  175. // 再补充模型自定义端点
  176. for modelName, meta := range metaMap {
  177. if strings.TrimSpace(meta.Endpoints) == "" {
  178. continue
  179. }
  180. var raw map[string]interface{}
  181. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  182. endpoints := modelSupportEndpointsStr[modelName]
  183. for k := range raw {
  184. if !common.StringsContains(endpoints, k) {
  185. endpoints = append(endpoints, k)
  186. }
  187. }
  188. modelSupportEndpointsStr[modelName] = endpoints
  189. }
  190. }
  191. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  192. for model, endpoints := range modelSupportEndpointsStr {
  193. supportedEndpoints := make([]constant.EndpointType, 0)
  194. for _, endpointStr := range endpoints {
  195. endpointType := constant.EndpointType(endpointStr)
  196. supportedEndpoints = append(supportedEndpoints, endpointType)
  197. }
  198. modelSupportEndpointTypes[model] = supportedEndpoints
  199. }
  200. // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
  201. supportedEndpointMap = make(map[string]common.EndpointInfo)
  202. // 1. 默认端点
  203. for _, endpoints := range modelSupportEndpointTypes {
  204. for _, et := range endpoints {
  205. if info, ok := common.GetDefaultEndpointInfo(et); ok {
  206. if _, exists := supportedEndpointMap[string(et)]; !exists {
  207. supportedEndpointMap[string(et)] = info
  208. }
  209. }
  210. }
  211. }
  212. // 2. 自定义端点(models 表)覆盖默认
  213. for _, meta := range metaMap {
  214. if strings.TrimSpace(meta.Endpoints) == "" {
  215. continue
  216. }
  217. var raw map[string]interface{}
  218. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  219. for k, v := range raw {
  220. switch val := v.(type) {
  221. case string:
  222. supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
  223. case map[string]interface{}:
  224. ep := common.EndpointInfo{Method: "POST"}
  225. if p, ok := val["path"].(string); ok {
  226. ep.Path = p
  227. }
  228. if m, ok := val["method"].(string); ok {
  229. ep.Method = strings.ToUpper(m)
  230. }
  231. supportedEndpointMap[k] = ep
  232. default:
  233. // ignore unsupported types
  234. }
  235. }
  236. }
  237. }
  238. pricingMap = make([]Pricing, 0)
  239. for model, groups := range modelGroupsMap {
  240. pricing := Pricing{
  241. ModelName: model,
  242. EnableGroup: groups.Items(),
  243. SupportedEndpointTypes: modelSupportEndpointTypes[model],
  244. }
  245. // 补充模型元数据(描述、标签、供应商、状态)
  246. if meta, ok := metaMap[model]; ok {
  247. // 若模型被禁用(status!=1),则直接跳过,不返回给前端
  248. if meta.Status != 1 {
  249. continue
  250. }
  251. pricing.Description = meta.Description
  252. pricing.Tags = meta.Tags
  253. pricing.VendorID = meta.VendorID
  254. }
  255. modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
  256. if findPrice {
  257. pricing.ModelPrice = modelPrice
  258. pricing.QuotaType = 1
  259. } else {
  260. modelRatio, _, _ := ratio_setting.GetModelRatio(model)
  261. pricing.ModelRatio = modelRatio
  262. pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
  263. pricing.QuotaType = 0
  264. }
  265. pricingMap = append(pricingMap, pricing)
  266. }
  267. // 刷新缓存映射,供高并发快速查询
  268. modelEnableGroupsLock.Lock()
  269. modelEnableGroups = make(map[string][]string)
  270. modelQuotaTypeMap = make(map[string]int)
  271. for _, p := range pricingMap {
  272. modelEnableGroups[p.ModelName] = p.EnableGroup
  273. modelQuotaTypeMap[p.ModelName] = p.QuotaType
  274. }
  275. modelEnableGroupsLock.Unlock()
  276. lastGetPricingTime = time.Now()
  277. }
  278. // GetSupportedEndpointMap 返回全局端点到路径的映射
  279. func GetSupportedEndpointMap() map[string]common.EndpointInfo {
  280. return supportedEndpointMap
  281. }