Преглед изворни кода

fix: panic when get model ratio (close #392)

CalciumIon пре 1 година
родитељ
комит
b16e6bf423
1 измењених фајлова са 40 додато и 26 уклоњено
  1. 40 26
      common/model-ratio.go

+ 40 - 26
common/model-ratio.go

@@ -3,6 +3,7 @@ package common
 import (
 	"encoding/json"
 	"strings"
+	"sync"
 )
 
 // from songquanpeng/one-api
@@ -182,8 +183,14 @@ var defaultModelPrice = map[string]float64{
 	"swap_face":         0.05,
 }
 
-var modelPrice map[string]float64 = nil
-var modelRatio map[string]float64 = nil
+var (
+	modelPriceMap      = make(map[string]float64)
+	modelPriceMapMutex = sync.RWMutex{}
+)
+var (
+	modelRatioMap      map[string]float64 = nil
+	modelRatioMapMutex                    = sync.RWMutex{}
+)
 
 var CompletionRatio map[string]float64 = nil
 var defaultCompletionRatio = map[string]float64{
@@ -191,11 +198,18 @@ var defaultCompletionRatio = map[string]float64{
 	"gpt-4-all":     2,
 }
 
-func ModelPrice2JSONString() string {
-	if modelPrice == nil {
-		modelPrice = defaultModelPrice
+func GetModelPriceMap() map[string]float64 {
+	modelPriceMapMutex.Lock()
+	defer modelPriceMapMutex.Unlock()
+	if modelPriceMap == nil {
+		modelPriceMap = defaultModelPrice
 	}
-	jsonBytes, err := json.Marshal(modelPrice)
+	return modelPriceMap
+}
+
+func ModelPrice2JSONString() string {
+	GetModelPriceMap()
+	jsonBytes, err := json.Marshal(modelPriceMap)
 	if err != nil {
 		SysError("error marshalling model price: " + err.Error())
 	}
@@ -203,19 +217,19 @@ func ModelPrice2JSONString() string {
 }
 
 func UpdateModelPriceByJSONString(jsonStr string) error {
-	modelPrice = make(map[string]float64)
-	return json.Unmarshal([]byte(jsonStr), &modelPrice)
+	modelPriceMapMutex.Lock()
+	defer modelPriceMapMutex.Unlock()
+	modelPriceMap = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
 }
 
 // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
 func GetModelPrice(name string, printErr bool) (float64, bool) {
-	if modelPrice == nil {
-		modelPrice = defaultModelPrice
-	}
+	GetModelPriceMap()
 	if strings.HasPrefix(name, "gpt-4-gizmo") {
 		name = "gpt-4-gizmo-*"
 	}
-	price, ok := modelPrice[name]
+	price, ok := modelPriceMap[name]
 	if !ok {
 		if printErr {
 			SysError("model price not found: " + name)
@@ -225,18 +239,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
 	return price, true
 }
 
-func GetModelPriceMap() map[string]float64 {
-	if modelPrice == nil {
-		modelPrice = defaultModelPrice
+func GetModelRatioMap() map[string]float64 {
+	modelRatioMapMutex.Lock()
+	defer modelRatioMapMutex.Unlock()
+	if modelRatioMap == nil {
+		modelRatioMap = defaultModelRatio
 	}
-	return modelPrice
+	return modelRatioMap
 }
 
 func ModelRatio2JSONString() string {
-	if modelRatio == nil {
-		modelRatio = defaultModelRatio
-	}
-	jsonBytes, err := json.Marshal(modelRatio)
+	GetModelRatioMap()
+	jsonBytes, err := json.Marshal(modelRatioMap)
 	if err != nil {
 		SysError("error marshalling model ratio: " + err.Error())
 	}
@@ -244,18 +258,18 @@ func ModelRatio2JSONString() string {
 }
 
 func UpdateModelRatioByJSONString(jsonStr string) error {
-	modelRatio = make(map[string]float64)
-	return json.Unmarshal([]byte(jsonStr), &modelRatio)
+	modelRatioMapMutex.Lock()
+	defer modelRatioMapMutex.Unlock()
+	modelRatioMap = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
 }
 
 func GetModelRatio(name string) float64 {
-	if modelRatio == nil {
-		modelRatio = defaultModelRatio
-	}
+	GetModelRatioMap()
 	if strings.HasPrefix(name, "gpt-4-gizmo") {
 		name = "gpt-4-gizmo-*"
 	}
-	ratio, ok := modelRatio[name]
+	ratio, ok := modelRatioMap[name]
 	if !ok {
 		SysError("model ratio not found: " + name)
 		return 30