Просмотр исходного кода

refactor: migrate group ratio and user usable groups logic to new setting package

- Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services.
- Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability.
- Updated related functions to ensure consistent behavior with the new setting package integration.
CalciumIon 1 год назад
Родитель
Сommit
4fc1fe318e

+ 4 - 4
controller/group.go

@@ -3,13 +3,13 @@ package controller
 import (
 	"github.com/gin-gonic/gin"
 	"net/http"
-	"one-api/common"
 	"one-api/model"
+	"one-api/setting"
 )
 
 func GetGroups(c *gin.Context) {
 	groupNames := make([]string, 0)
-	for groupName, _ := range common.GroupRatio {
+	for groupName, _ := range setting.GetGroupRatioCopy() {
 		groupNames = append(groupNames, groupName)
 	}
 	c.JSON(http.StatusOK, gin.H{
@@ -24,9 +24,9 @@ func GetUserGroups(c *gin.Context) {
 	userGroup := ""
 	userId := c.GetInt("id")
 	userGroup, _ = model.CacheGetUserGroup(userId)
-	for groupName, _ := range common.GroupRatio {
+	for groupName, _ := range setting.GetGroupRatioCopy() {
 		// UserUsableGroups contains the groups that the user can use
-		userUsableGroups := common.GetUserUsableGroups(userGroup)
+		userUsableGroups := setting.GetUserUsableGroups(userGroup)
 		if _, ok := userUsableGroups[groupName]; ok {
 			usableGroups[groupName] = userUsableGroups[groupName]
 		}

+ 2 - 1
controller/option.go

@@ -5,6 +5,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"one-api/setting"
 	"strings"
 
 	"github.com/gin-gonic/gin"
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
 			return
 		}
 	case "GroupRatio":
-		err = common.CheckGroupRatio(option.Value)
+		err = setting.CheckGroupRatio(option.Value)
 		if err != nil {
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,

+ 4 - 3
controller/pricing.go

@@ -4,6 +4,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/model"
+	"one-api/setting"
 )
 
 func GetPricing(c *gin.Context) {
@@ -11,7 +12,7 @@ func GetPricing(c *gin.Context) {
 	userId, exists := c.Get("id")
 	usableGroup := map[string]string{}
 	groupRatio := map[string]float64{}
-	for s, f := range common.GroupRatio {
+	for s, f := range setting.GetGroupRatioCopy() {
 		groupRatio[s] = f
 	}
 	var group string
@@ -22,9 +23,9 @@ func GetPricing(c *gin.Context) {
 		}
 	}
 
-	usableGroup = common.GetUserUsableGroups(group)
+	usableGroup = setting.GetUserUsableGroups(group)
 	// check groupRatio contains usableGroup
-	for group := range common.GroupRatio {
+	for group := range setting.GetGroupRatioCopy() {
 		if _, ok := usableGroup[group]; !ok {
 			delete(groupRatio, group)
 		}

+ 2 - 1
controller/relay.go

@@ -17,6 +17,7 @@ import (
 	"one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
+	"one-api/setting"
 	"strings"
 )
 
@@ -83,7 +84,7 @@ func Playground(c *gin.Context) {
 	if group == "" {
 		group = userGroup
 	} else {
-		if !common.GroupInUserUsableGroups(group) && group != userGroup {
+		if !setting.GroupInUserUsableGroups(group) && group != userGroup {
 			openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
 			return
 		}

+ 3 - 2
middleware/distributor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/model"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
+	"one-api/setting"
 	"strconv"
 	"strings"
 	"time"
@@ -43,12 +44,12 @@ func Distribute() func(c *gin.Context) {
 		tokenGroup := c.GetString("token_group")
 		if tokenGroup != "" {
 			// check common.UserUsableGroups[userGroup]
-			if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
+			if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
 				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
 				return
 			}
 			// check group in common.GroupRatio
-			if _, ok := common.GroupRatio[tokenGroup]; !ok {
+			if !setting.ContainsGroupRatio(tokenGroup) {
 				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
 				return
 			}

+ 4 - 4
model/option.go

@@ -87,8 +87,8 @@ func InitOptionMap() {
 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 	common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
-	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
-	common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
+	common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+	common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
 	common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	common.OptionMap["ChatLink"] = common.ChatLink
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
 	case "ModelRatio":
 		err = common.UpdateModelRatioByJSONString(value)
 	case "GroupRatio":
-		err = common.UpdateGroupRatioByJSONString(value)
+		err = setting.UpdateGroupRatioByJSONString(value)
 	case "UserUsableGroups":
-		err = common.UpdateUserUsableGroupsByJSONString(value)
+		err = setting.UpdateUserUsableGroupsByJSONString(value)
 	case "CompletionRatio":
 		err = common.UpdateCompletionRatioByJSONString(value)
 	case "ModelPrice":

+ 1 - 1
relay/relay-audio.go

@@ -74,7 +74,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 
 	modelRatio := common.GetModelRatio(audioRequest.Model)
-	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)

+ 1 - 1
relay/relay-image.go

@@ -99,7 +99,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		modelPrice = 0.0025 * modelRatio
 	}
 
-	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
 
 	sizeRatio := 1.0

+ 2 - 2
relay/relay-mj.go

@@ -168,7 +168,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 			modelPrice = defaultPrice
 		}
 	}
-	groupRatio := common.GetGroupRatio(group)
+	groupRatio := setting.GetGroupRatio(group)
 	ratio := modelPrice * groupRatio
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
@@ -474,7 +474,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			modelPrice = defaultPrice
 		}
 	}
-	groupRatio := common.GetGroupRatio(group)
+	groupRatio := setting.GetGroupRatio(group)
 	ratio := modelPrice * groupRatio
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {

+ 1 - 1
relay/relay-text.go

@@ -94,7 +94,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
 	modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
-	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 
 	var preConsumedQuota int
 	var ratio float64

+ 2 - 1
relay/relay_rerank.go

@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"one-api/setting"
 )
 
 func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
 
 	relayInfo.UpstreamModelName = rerankRequest.Model
 	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
-	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 
 	var preConsumedQuota int
 	var ratio float64

+ 2 - 1
relay/relay_task.go

@@ -16,6 +16,7 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
+	"one-api/setting"
 )
 
 /*
@@ -48,7 +49,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
 	}
 
 	// 预扣
-	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 	ratio := modelPrice * groupRatio
 	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
 	if err != nil {

+ 2 - 1
relay/websocket.go

@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"one-api/setting"
 )
 
 //func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
 	}
 	//relayInfo.UpstreamModelName = textRequest.Model
 	modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
-	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 
 	var preConsumedQuota int
 	var ratio float64

+ 2 - 1
service/quota.go

@@ -9,6 +9,7 @@ import (
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
+	"one-api/setting"
 	"strings"
 	"time"
 )
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 	completionRatio := common.GetCompletionRatio(modelName)
 	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
 	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
-	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 	modelRatio := common.GetModelRatio(modelName)
 
 	ratio := groupRatio * modelRatio

+ 22 - 8
common/group-ratio.go → setting/group_ratio.go

@@ -1,33 +1,47 @@
-package common
+package setting
 
 import (
 	"encoding/json"
 	"errors"
+	"one-api/common"
 )
 
-var GroupRatio = map[string]float64{
+var groupRatio = map[string]float64{
 	"default": 1,
 	"vip":     1,
 	"svip":    1,
 }
 
+func GetGroupRatioCopy() map[string]float64 {
+	groupRatioCopy := make(map[string]float64)
+	for k, v := range groupRatio {
+		groupRatioCopy[k] = v
+	}
+	return groupRatioCopy
+}
+
+func ContainsGroupRatio(name string) bool {
+	_, ok := groupRatio[name]
+	return ok
+}
+
 func GroupRatio2JSONString() string {
-	jsonBytes, err := json.Marshal(GroupRatio)
+	jsonBytes, err := json.Marshal(groupRatio)
 	if err != nil {
-		SysError("error marshalling model ratio: " + err.Error())
+		common.SysError("error marshalling model ratio: " + err.Error())
 	}
 	return string(jsonBytes)
 }
 
 func UpdateGroupRatioByJSONString(jsonStr string) error {
-	GroupRatio = make(map[string]float64)
-	return json.Unmarshal([]byte(jsonStr), &GroupRatio)
+	groupRatio = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &groupRatio)
 }
 
 func GetGroupRatio(name string) float64 {
-	ratio, ok := GroupRatio[name]
+	ratio, ok := groupRatio[name]
 	if !ok {
-		SysError("group ratio not found: " + name)
+		common.SysError("group ratio not found: " + name)
 		return 1
 	}
 	return ratio

+ 3 - 2
common/user_groups.go → setting/user_usable_group.go

@@ -1,7 +1,8 @@
-package common
+package setting
 
 import (
 	"encoding/json"
+	"one-api/common"
 )
 
 var UserUsableGroups = map[string]string{
@@ -12,7 +13,7 @@ var UserUsableGroups = map[string]string{
 func UserUsableGroups2JSONString() string {
 	jsonBytes, err := json.Marshal(UserUsableGroups)
 	if err != nil {
-		SysError("error marshalling user groups: " + err.Error())
+		common.SysError("error marshalling user groups: " + err.Error())
 	}
 	return string(jsonBytes)
 }