Просмотр исходного кода

feat: Initialize model settings and improve concurrency control in operation settings

CaIon 11 месяцев назад
Родитель
Сommit
aa34c3035a
3 измененных файлов с 65 добавлено и 41 удалено
  1. 4 0
      main.go
  2. 7 8
      setting/operation_setting/cache_ratio.go
  3. 54 33
      setting/operation_setting/model-ratio.go

+ 4 - 0
main.go

@@ -12,6 +12,7 @@ import (
 	"one-api/model"
 	"one-api/model"
 	"one-api/router"
 	"one-api/router"
 	"one-api/service"
 	"one-api/service"
+	"one-api/setting/operation_setting"
 	"os"
 	"os"
 	"strconv"
 	"strconv"
 
 
@@ -73,6 +74,9 @@ func main() {
 	constant.InitEnv()
 	constant.InitEnv()
 	// Initialize options
 	// Initialize options
 	model.InitOptionMap()
 	model.InitOptionMap()
+	// Initialize model settings
+	operation_setting.InitModelSettings()
+	
 	if common.RedisEnabled {
 	if common.RedisEnabled {
 		// for compatibility with old versions
 		// for compatibility with old versions
 		common.MemoryCacheEnabled = true
 		common.MemoryCacheEnabled = true

+ 7 - 8
setting/operation_setting/cache_ratio.go

@@ -56,17 +56,15 @@ var cacheRatioMapMutex sync.RWMutex
 
 
 // GetCacheRatioMap returns the cache ratio map
 // GetCacheRatioMap returns the cache ratio map
 func GetCacheRatioMap() map[string]float64 {
 func GetCacheRatioMap() map[string]float64 {
-	cacheRatioMapMutex.Lock()
-	defer cacheRatioMapMutex.Unlock()
-	if cacheRatioMap == nil {
-		cacheRatioMap = defaultCacheRatio
-	}
+	cacheRatioMapMutex.RLock()
+	defer cacheRatioMapMutex.RUnlock()
 	return cacheRatioMap
 	return cacheRatioMap
 }
 }
 
 
 // CacheRatio2JSONString converts the cache ratio map to a JSON string
 // CacheRatio2JSONString converts the cache ratio map to a JSON string
 func CacheRatio2JSONString() string {
 func CacheRatio2JSONString() string {
-	GetCacheRatioMap()
+	cacheRatioMapMutex.RLock()
+	defer cacheRatioMapMutex.RUnlock()
 	jsonBytes, err := json.Marshal(cacheRatioMap)
 	jsonBytes, err := json.Marshal(cacheRatioMap)
 	if err != nil {
 	if err != nil {
 		common.SysError("error marshalling cache ratio: " + err.Error())
 		common.SysError("error marshalling cache ratio: " + err.Error())
@@ -84,10 +82,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
 
 
 // GetCacheRatio returns the cache ratio for a model
 // GetCacheRatio returns the cache ratio for a model
 func GetCacheRatio(name string) (float64, bool) {
 func GetCacheRatio(name string) (float64, bool) {
-	GetCacheRatioMap()
+	cacheRatioMapMutex.RLock()
+	defer cacheRatioMapMutex.RUnlock()
 	ratio, ok := cacheRatioMap[name]
 	ratio, ok := cacheRatioMap[name]
 	if !ok {
 	if !ok {
-		return 1, false // Default to 0.5 if not found
+		return 1, false // Default to 1 if not found
 	}
 	}
 	return ratio, true
 	return ratio, true
 }
 }

+ 54 - 33
setting/operation_setting/model-ratio.go

@@ -245,17 +245,41 @@ var defaultCompletionRatio = map[string]float64{
 	"gpt-4-all":      2,
 	"gpt-4-all":      2,
 }
 }
 
 
-func GetModelPriceMap() map[string]float64 {
+// InitModelSettings initializes all model related settings maps
+func InitModelSettings() {
+	// Initialize modelPriceMap
 	modelPriceMapMutex.Lock()
 	modelPriceMapMutex.Lock()
-	defer modelPriceMapMutex.Unlock()
-	if modelPriceMap == nil {
-		modelPriceMap = defaultModelPrice
-	}
+	modelPriceMap = defaultModelPrice
+	modelPriceMapMutex.Unlock()
+
+	// Initialize modelRatioMap
+	modelRatioMapMutex.Lock()
+	modelRatioMap = defaultModelRatio
+	modelRatioMapMutex.Unlock()
+
+	// Initialize CompletionRatio
+	CompletionRatioMutex.Lock()
+	CompletionRatio = defaultCompletionRatio
+	CompletionRatioMutex.Unlock()
+
+	// Initialize cacheRatioMap
+	cacheRatioMapMutex.Lock()
+	cacheRatioMap = defaultCacheRatio
+	cacheRatioMapMutex.Unlock()
+
+	common.SysLog("model settings initialized")
+}
+
+func GetModelPriceMap() map[string]float64 {
+	modelPriceMapMutex.RLock()
+	defer modelPriceMapMutex.RUnlock()
 	return modelPriceMap
 	return modelPriceMap
 }
 }
 
 
 func ModelPrice2JSONString() string {
 func ModelPrice2JSONString() string {
-	GetModelPriceMap()
+	modelPriceMapMutex.RLock()
+	defer modelPriceMapMutex.RUnlock()
+
 	jsonBytes, err := json.Marshal(modelPriceMap)
 	jsonBytes, err := json.Marshal(modelPriceMap)
 	if err != nil {
 	if err != nil {
 		common.SysError("error marshalling model price: " + err.Error())
 		common.SysError("error marshalling model price: " + err.Error())
@@ -272,7 +296,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
 
 
 // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
 // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
 func GetModelPrice(name string, printErr bool) (float64, bool) {
 func GetModelPrice(name string, printErr bool) (float64, bool) {
-	GetModelPriceMap()
+	modelPriceMapMutex.RLock()
+	defer modelPriceMapMutex.RUnlock()
+
 	if strings.HasPrefix(name, "gpt-4-gizmo") {
 	if strings.HasPrefix(name, "gpt-4-gizmo") {
 		name = "gpt-4-gizmo-*"
 		name = "gpt-4-gizmo-*"
 	}
 	}
@@ -289,24 +315,6 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
 	return price, true
 	return price, true
 }
 }
 
 
-func GetModelRatioMap() map[string]float64 {
-	modelRatioMapMutex.Lock()
-	defer modelRatioMapMutex.Unlock()
-	if modelRatioMap == nil {
-		modelRatioMap = defaultModelRatio
-	}
-	return modelRatioMap
-}
-
-func ModelRatio2JSONString() string {
-	GetModelRatioMap()
-	jsonBytes, err := json.Marshal(modelRatioMap)
-	if err != nil {
-		common.SysError("error marshalling model ratio: " + err.Error())
-	}
-	return string(jsonBytes)
-}
-
 func UpdateModelRatioByJSONString(jsonStr string) error {
 func UpdateModelRatioByJSONString(jsonStr string) error {
 	modelRatioMapMutex.Lock()
 	modelRatioMapMutex.Lock()
 	defer modelRatioMapMutex.Unlock()
 	defer modelRatioMapMutex.Unlock()
@@ -315,7 +323,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
 }
 }
 
 
 func GetModelRatio(name string) (float64, bool) {
 func GetModelRatio(name string) (float64, bool) {
-	GetModelRatioMap()
+	modelRatioMapMutex.RLock()
+	defer modelRatioMapMutex.RUnlock()
+
 	if strings.HasPrefix(name, "gpt-4-gizmo") {
 	if strings.HasPrefix(name, "gpt-4-gizmo") {
 		name = "gpt-4-gizmo-*"
 		name = "gpt-4-gizmo-*"
 	}
 	}
@@ -339,16 +349,15 @@ func GetDefaultModelRatioMap() map[string]float64 {
 }
 }
 
 
 func GetCompletionRatioMap() map[string]float64 {
 func GetCompletionRatioMap() map[string]float64 {
-	CompletionRatioMutex.Lock()
-	defer CompletionRatioMutex.Unlock()
-	if CompletionRatio == nil {
-		CompletionRatio = defaultCompletionRatio
-	}
+	CompletionRatioMutex.RLock()
+	defer CompletionRatioMutex.RUnlock()
 	return CompletionRatio
 	return CompletionRatio
 }
 }
 
 
 func CompletionRatio2JSONString() string {
 func CompletionRatio2JSONString() string {
-	GetCompletionRatioMap()
+	CompletionRatioMutex.RLock()
+	defer CompletionRatioMutex.RUnlock()
+
 	jsonBytes, err := json.Marshal(CompletionRatio)
 	jsonBytes, err := json.Marshal(CompletionRatio)
 	if err != nil {
 	if err != nil {
 		common.SysError("error marshalling completion ratio: " + err.Error())
 		common.SysError("error marshalling completion ratio: " + err.Error())
@@ -364,7 +373,8 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
 }
 }
 
 
 func GetCompletionRatio(name string) float64 {
 func GetCompletionRatio(name string) float64 {
-	GetCompletionRatioMap()
+	CompletionRatioMutex.RLock()
+	defer CompletionRatioMutex.RUnlock()
 
 
 	if strings.Contains(name, "/") {
 	if strings.Contains(name, "/") {
 		if ratio, ok := CompletionRatio[name]; ok {
 		if ratio, ok := CompletionRatio[name]; ok {
@@ -511,3 +521,14 @@ func GetAudioCompletionRatio(name string) float64 {
 	}
 	}
 	return 2
 	return 2
 }
 }
+
+func ModelRatio2JSONString() string {
+	modelRatioMapMutex.RLock()
+	defer modelRatioMapMutex.RUnlock()
+
+	jsonBytes, err := json.Marshal(modelRatioMap)
+	if err != nil {
+		common.SysError("error marshalling model ratio: " + err.Error())
+	}
+	return string(jsonBytes)
+}