|
|
@@ -3,6 +3,7 @@ package common
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
)
|
|
|
|
|
|
// from songquanpeng/one-api
|
|
|
@@ -182,8 +183,14 @@ var defaultModelPrice = map[string]float64{
|
|
|
"swap_face": 0.05,
|
|
|
}
|
|
|
|
|
|
-var modelPrice map[string]float64 = nil
|
|
|
-var modelRatio map[string]float64 = nil
|
|
|
+var (
|
|
|
+ modelPriceMap = make(map[string]float64)
|
|
|
+ modelPriceMapMutex = sync.RWMutex{}
|
|
|
+)
|
|
|
+var (
|
|
|
+ modelRatioMap map[string]float64 = nil
|
|
|
+ modelRatioMapMutex = sync.RWMutex{}
|
|
|
+)
|
|
|
|
|
|
var CompletionRatio map[string]float64 = nil
|
|
|
var defaultCompletionRatio = map[string]float64{
|
|
|
@@ -191,11 +198,18 @@ var defaultCompletionRatio = map[string]float64{
|
|
|
"gpt-4-all": 2,
|
|
|
}
|
|
|
|
|
|
-func ModelPrice2JSONString() string {
|
|
|
- if modelPrice == nil {
|
|
|
- modelPrice = defaultModelPrice
|
|
|
+func GetModelPriceMap() map[string]float64 {
|
|
|
+ modelPriceMapMutex.Lock()
|
|
|
+ defer modelPriceMapMutex.Unlock()
|
|
|
+ if modelPriceMap == nil {
|
|
|
+ modelPriceMap = defaultModelPrice
|
|
|
}
|
|
|
- jsonBytes, err := json.Marshal(modelPrice)
|
|
|
+ return modelPriceMap
|
|
|
+}
|
|
|
+
|
|
|
+func ModelPrice2JSONString() string {
|
|
|
+ GetModelPriceMap()
|
|
|
+ jsonBytes, err := json.Marshal(modelPriceMap)
|
|
|
if err != nil {
|
|
|
SysError("error marshalling model price: " + err.Error())
|
|
|
}
|
|
|
@@ -203,19 +217,19 @@ func ModelPrice2JSONString() string {
|
|
|
}
|
|
|
|
|
|
func UpdateModelPriceByJSONString(jsonStr string) error {
|
|
|
- modelPrice = make(map[string]float64)
|
|
|
- return json.Unmarshal([]byte(jsonStr), &modelPrice)
|
|
|
+ modelPriceMapMutex.Lock()
|
|
|
+ defer modelPriceMapMutex.Unlock()
|
|
|
+ modelPriceMap = make(map[string]float64)
|
|
|
+ return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
|
|
|
}
|
|
|
|
|
|
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
|
|
func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|
|
- if modelPrice == nil {
|
|
|
- modelPrice = defaultModelPrice
|
|
|
- }
|
|
|
+ GetModelPriceMap()
|
|
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
|
|
name = "gpt-4-gizmo-*"
|
|
|
}
|
|
|
- price, ok := modelPrice[name]
|
|
|
+ price, ok := modelPriceMap[name]
|
|
|
if !ok {
|
|
|
if printErr {
|
|
|
SysError("model price not found: " + name)
|
|
|
@@ -225,18 +239,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|
|
return price, true
|
|
|
}
|
|
|
|
|
|
-func GetModelPriceMap() map[string]float64 {
|
|
|
- if modelPrice == nil {
|
|
|
- modelPrice = defaultModelPrice
|
|
|
+func GetModelRatioMap() map[string]float64 {
|
|
|
+ modelRatioMapMutex.Lock()
|
|
|
+ defer modelRatioMapMutex.Unlock()
|
|
|
+ if modelRatioMap == nil {
|
|
|
+ modelRatioMap = defaultModelRatio
|
|
|
}
|
|
|
- return modelPrice
|
|
|
+ return modelRatioMap
|
|
|
}
|
|
|
|
|
|
func ModelRatio2JSONString() string {
|
|
|
- if modelRatio == nil {
|
|
|
- modelRatio = defaultModelRatio
|
|
|
- }
|
|
|
- jsonBytes, err := json.Marshal(modelRatio)
|
|
|
+ GetModelRatioMap()
|
|
|
+ jsonBytes, err := json.Marshal(modelRatioMap)
|
|
|
if err != nil {
|
|
|
SysError("error marshalling model ratio: " + err.Error())
|
|
|
}
|
|
|
@@ -244,18 +258,18 @@ func ModelRatio2JSONString() string {
|
|
|
}
|
|
|
|
|
|
func UpdateModelRatioByJSONString(jsonStr string) error {
|
|
|
- modelRatio = make(map[string]float64)
|
|
|
- return json.Unmarshal([]byte(jsonStr), &modelRatio)
|
|
|
+ modelRatioMapMutex.Lock()
|
|
|
+ defer modelRatioMapMutex.Unlock()
|
|
|
+ modelRatioMap = make(map[string]float64)
|
|
|
+ return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
|
|
}
|
|
|
|
|
|
func GetModelRatio(name string) float64 {
|
|
|
- if modelRatio == nil {
|
|
|
- modelRatio = defaultModelRatio
|
|
|
- }
|
|
|
+ GetModelRatioMap()
|
|
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
|
|
name = "gpt-4-gizmo-*"
|
|
|
}
|
|
|
- ratio, ok := modelRatio[name]
|
|
|
+ ratio, ok := modelRatioMap[name]
|
|
|
if !ok {
|
|
|
SysError("model ratio not found: " + name)
|
|
|
return 30
|