|
|
@@ -1,31 +1,31 @@
|
|
|
package model
|
|
|
|
|
|
import (
|
|
|
- "encoding/json"
|
|
|
- "fmt"
|
|
|
- "strings"
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
+ "strings"
|
|
|
|
|
|
- "one-api/common"
|
|
|
- "one-api/constant"
|
|
|
- "one-api/setting/ratio_setting"
|
|
|
- "one-api/types"
|
|
|
- "sync"
|
|
|
- "time"
|
|
|
+ "one-api/common"
|
|
|
+ "one-api/constant"
|
|
|
+ "one-api/setting/ratio_setting"
|
|
|
+ "one-api/types"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
type Pricing struct {
|
|
|
- ModelName string `json:"model_name"`
|
|
|
- Description string `json:"description,omitempty"`
|
|
|
- Icon string `json:"icon,omitempty"`
|
|
|
- Tags string `json:"tags,omitempty"`
|
|
|
- VendorID int `json:"vendor_id,omitempty"`
|
|
|
- QuotaType int `json:"quota_type"`
|
|
|
- ModelRatio float64 `json:"model_ratio"`
|
|
|
- ModelPrice float64 `json:"model_price"`
|
|
|
- OwnerBy string `json:"owner_by"`
|
|
|
- CompletionRatio float64 `json:"completion_ratio"`
|
|
|
- EnableGroup []string `json:"enable_groups"`
|
|
|
- SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
|
|
+ ModelName string `json:"model_name"`
|
|
|
+ Description string `json:"description,omitempty"`
|
|
|
+ Icon string `json:"icon,omitempty"`
|
|
|
+ Tags string `json:"tags,omitempty"`
|
|
|
+ VendorID int `json:"vendor_id,omitempty"`
|
|
|
+ QuotaType int `json:"quota_type"`
|
|
|
+ ModelRatio float64 `json:"model_ratio"`
|
|
|
+ ModelPrice float64 `json:"model_price"`
|
|
|
+ OwnerBy string `json:"owner_by"`
|
|
|
+ CompletionRatio float64 `json:"completion_ratio"`
|
|
|
+ EnableGroup []string `json:"enable_groups"`
|
|
|
+ SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
|
|
}
|
|
|
|
|
|
type PricingVendor struct {
|
|
|
@@ -36,11 +36,11 @@ type PricingVendor struct {
|
|
|
}
|
|
|
|
|
|
var (
|
|
|
- pricingMap []Pricing
|
|
|
- vendorsList []PricingVendor
|
|
|
- supportedEndpointMap map[string]common.EndpointInfo
|
|
|
- lastGetPricingTime time.Time
|
|
|
- updatePricingLock sync.Mutex
|
|
|
+ pricingMap []Pricing
|
|
|
+ vendorsList []PricingVendor
|
|
|
+ supportedEndpointMap map[string]common.EndpointInfo
|
|
|
+ lastGetPricingTime time.Time
|
|
|
+ updatePricingLock sync.Mutex
|
|
|
|
|
|
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
|
|
modelEnableGroups = make(map[string][]string)
|
|
|
@@ -122,19 +122,19 @@ func updatePricing() {
|
|
|
for _, m := range prefixList {
|
|
|
for _, pricingModel := range enableAbilities {
|
|
|
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
|
|
- if _, exists := metaMap[pricingModel.Model]; !exists {
|
|
|
- metaMap[pricingModel.Model] = m
|
|
|
- }
|
|
|
- }
|
|
|
+ if _, exists := metaMap[pricingModel.Model]; !exists {
|
|
|
+ metaMap[pricingModel.Model] = m
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
for _, m := range suffixList {
|
|
|
for _, pricingModel := range enableAbilities {
|
|
|
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
|
|
- if _, exists := metaMap[pricingModel.Model]; !exists {
|
|
|
- metaMap[pricingModel.Model] = m
|
|
|
- }
|
|
|
- }
|
|
|
+ if _, exists := metaMap[pricingModel.Model]; !exists {
|
|
|
+ metaMap[pricingModel.Model] = m
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
for _, m := range containsList {
|
|
|
@@ -180,34 +180,34 @@ func updatePricing() {
|
|
|
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
|
|
modelSupportEndpointsStr := make(map[string][]string)
|
|
|
|
|
|
- // 先根据已有能力填充原生端点
|
|
|
- for _, ability := range enableAbilities {
|
|
|
- endpoints := modelSupportEndpointsStr[ability.Model]
|
|
|
- channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
|
|
- for _, channelType := range channelTypes {
|
|
|
- if !common.StringsContains(endpoints, string(channelType)) {
|
|
|
- endpoints = append(endpoints, string(channelType))
|
|
|
- }
|
|
|
- }
|
|
|
- modelSupportEndpointsStr[ability.Model] = endpoints
|
|
|
- }
|
|
|
+ // 先根据已有能力填充原生端点
|
|
|
+ for _, ability := range enableAbilities {
|
|
|
+ endpoints := modelSupportEndpointsStr[ability.Model]
|
|
|
+ channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
|
|
+ for _, channelType := range channelTypes {
|
|
|
+ if !common.StringsContains(endpoints, string(channelType)) {
|
|
|
+ endpoints = append(endpoints, string(channelType))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ modelSupportEndpointsStr[ability.Model] = endpoints
|
|
|
+ }
|
|
|
|
|
|
- // 再补充模型自定义端点
|
|
|
- for modelName, meta := range metaMap {
|
|
|
- if strings.TrimSpace(meta.Endpoints) == "" {
|
|
|
- continue
|
|
|
- }
|
|
|
- var raw map[string]interface{}
|
|
|
- if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
|
|
- endpoints := modelSupportEndpointsStr[modelName]
|
|
|
- for k := range raw {
|
|
|
- if !common.StringsContains(endpoints, k) {
|
|
|
- endpoints = append(endpoints, k)
|
|
|
- }
|
|
|
- }
|
|
|
- modelSupportEndpointsStr[modelName] = endpoints
|
|
|
- }
|
|
|
- }
|
|
|
+ // 再补充模型自定义端点
|
|
|
+ for modelName, meta := range metaMap {
|
|
|
+ if strings.TrimSpace(meta.Endpoints) == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ var raw map[string]interface{}
|
|
|
+ if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
|
|
+ endpoints := modelSupportEndpointsStr[modelName]
|
|
|
+ for k := range raw {
|
|
|
+ if !common.StringsContains(endpoints, k) {
|
|
|
+ endpoints = append(endpoints, k)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ modelSupportEndpointsStr[modelName] = endpoints
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
|
|
for model, endpoints := range modelSupportEndpointsStr {
|
|
|
@@ -217,93 +217,93 @@ func updatePricing() {
|
|
|
supportedEndpoints = append(supportedEndpoints, endpointType)
|
|
|
}
|
|
|
modelSupportEndpointTypes[model] = supportedEndpoints
|
|
|
- }
|
|
|
+ }
|
|
|
|
|
|
- // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
|
|
- supportedEndpointMap = make(map[string]common.EndpointInfo)
|
|
|
- // 1. 默认端点
|
|
|
- for _, endpoints := range modelSupportEndpointTypes {
|
|
|
- for _, et := range endpoints {
|
|
|
- if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
|
|
- if _, exists := supportedEndpointMap[string(et)]; !exists {
|
|
|
- supportedEndpointMap[string(et)] = info
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- // 2. 自定义端点(models 表)覆盖默认
|
|
|
- for _, meta := range metaMap {
|
|
|
- if strings.TrimSpace(meta.Endpoints) == "" {
|
|
|
- continue
|
|
|
- }
|
|
|
- var raw map[string]interface{}
|
|
|
- if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
|
|
- for k, v := range raw {
|
|
|
- switch val := v.(type) {
|
|
|
- case string:
|
|
|
- supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
|
|
- case map[string]interface{}:
|
|
|
- ep := common.EndpointInfo{Method: "POST"}
|
|
|
- if p, ok := val["path"].(string); ok {
|
|
|
- ep.Path = p
|
|
|
- }
|
|
|
- if m, ok := val["method"].(string); ok {
|
|
|
- ep.Method = strings.ToUpper(m)
|
|
|
- }
|
|
|
- supportedEndpointMap[k] = ep
|
|
|
- default:
|
|
|
- // ignore unsupported types
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
+ // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
|
|
+ supportedEndpointMap = make(map[string]common.EndpointInfo)
|
|
|
+ // 1. 默认端点
|
|
|
+ for _, endpoints := range modelSupportEndpointTypes {
|
|
|
+ for _, et := range endpoints {
|
|
|
+ if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
|
|
+ if _, exists := supportedEndpointMap[string(et)]; !exists {
|
|
|
+ supportedEndpointMap[string(et)] = info
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // 2. 自定义端点(models 表)覆盖默认
|
|
|
+ for _, meta := range metaMap {
|
|
|
+ if strings.TrimSpace(meta.Endpoints) == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ var raw map[string]interface{}
|
|
|
+ if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
|
|
+ for k, v := range raw {
|
|
|
+ switch val := v.(type) {
|
|
|
+ case string:
|
|
|
+ supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
|
|
+ case map[string]interface{}:
|
|
|
+ ep := common.EndpointInfo{Method: "POST"}
|
|
|
+ if p, ok := val["path"].(string); ok {
|
|
|
+ ep.Path = p
|
|
|
+ }
|
|
|
+ if m, ok := val["method"].(string); ok {
|
|
|
+ ep.Method = strings.ToUpper(m)
|
|
|
+ }
|
|
|
+ supportedEndpointMap[k] = ep
|
|
|
+ default:
|
|
|
+ // ignore unsupported types
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- pricingMap = make([]Pricing, 0)
|
|
|
- for model, groups := range modelGroupsMap {
|
|
|
- pricing := Pricing{
|
|
|
- ModelName: model,
|
|
|
- EnableGroup: groups.Items(),
|
|
|
- SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
|
|
- }
|
|
|
+ pricingMap = make([]Pricing, 0)
|
|
|
+ for model, groups := range modelGroupsMap {
|
|
|
+ pricing := Pricing{
|
|
|
+ ModelName: model,
|
|
|
+ EnableGroup: groups.Items(),
|
|
|
+ SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
|
|
+ }
|
|
|
|
|
|
- // 补充模型元数据(描述、标签、供应商、状态)
|
|
|
- if meta, ok := metaMap[model]; ok {
|
|
|
- // 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
|
|
- if meta.Status != 1 {
|
|
|
- continue
|
|
|
- }
|
|
|
- pricing.Description = meta.Description
|
|
|
- pricing.Icon = meta.Icon
|
|
|
- pricing.Tags = meta.Tags
|
|
|
- pricing.VendorID = meta.VendorID
|
|
|
- }
|
|
|
- modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
|
|
- if findPrice {
|
|
|
- pricing.ModelPrice = modelPrice
|
|
|
- pricing.QuotaType = 1
|
|
|
- } else {
|
|
|
- modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
|
|
- pricing.ModelRatio = modelRatio
|
|
|
- pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
|
|
- pricing.QuotaType = 0
|
|
|
- }
|
|
|
- pricingMap = append(pricingMap, pricing)
|
|
|
- }
|
|
|
+ // 补充模型元数据(描述、标签、供应商、状态)
|
|
|
+ if meta, ok := metaMap[model]; ok {
|
|
|
+ // 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
|
|
+ if meta.Status != 1 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ pricing.Description = meta.Description
|
|
|
+ pricing.Icon = meta.Icon
|
|
|
+ pricing.Tags = meta.Tags
|
|
|
+ pricing.VendorID = meta.VendorID
|
|
|
+ }
|
|
|
+ modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
|
|
+ if findPrice {
|
|
|
+ pricing.ModelPrice = modelPrice
|
|
|
+ pricing.QuotaType = 1
|
|
|
+ } else {
|
|
|
+ modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
|
|
+ pricing.ModelRatio = modelRatio
|
|
|
+ pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
|
|
+ pricing.QuotaType = 0
|
|
|
+ }
|
|
|
+ pricingMap = append(pricingMap, pricing)
|
|
|
+ }
|
|
|
|
|
|
- // 刷新缓存映射,供高并发快速查询
|
|
|
- modelEnableGroupsLock.Lock()
|
|
|
- modelEnableGroups = make(map[string][]string)
|
|
|
- modelQuotaTypeMap = make(map[string]int)
|
|
|
- for _, p := range pricingMap {
|
|
|
- modelEnableGroups[p.ModelName] = p.EnableGroup
|
|
|
- modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
|
|
- }
|
|
|
- modelEnableGroupsLock.Unlock()
|
|
|
+ // 刷新缓存映射,供高并发快速查询
|
|
|
+ modelEnableGroupsLock.Lock()
|
|
|
+ modelEnableGroups = make(map[string][]string)
|
|
|
+ modelQuotaTypeMap = make(map[string]int)
|
|
|
+ for _, p := range pricingMap {
|
|
|
+ modelEnableGroups[p.ModelName] = p.EnableGroup
|
|
|
+ modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
|
|
+ }
|
|
|
+ modelEnableGroupsLock.Unlock()
|
|
|
|
|
|
- lastGetPricingTime = time.Now()
|
|
|
+ lastGetPricingTime = time.Now()
|
|
|
}
|
|
|
|
|
|
// GetSupportedEndpointMap 返回全局端点到路径的映射
|
|
|
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
|
|
|
- return supportedEndpointMap
|
|
|
+ return supportedEndpointMap
|
|
|
}
|