فهرست منبع

feat: 完善模型价格获取逻辑

CaIon 1 سال پیش
والد
کامیت
93858c32d9
3فایلهای تغییر یافته به همراه14 افزوده شده و 14 حذف شده
  1. 8 8
      controller/model.go
  2. 5 5
      model/pricing.go
  3. 1 1
      router/api-router.go

+ 8 - 8
controller/model.go

@@ -108,8 +108,8 @@ func init() {
 		})
 	}
 	openAIModelsMap = make(map[string]dto.OpenAIModels)
-	for _, model := range openAIModels {
-		openAIModelsMap[model.Id] = model
+	for _, aiModel := range openAIModels {
+		openAIModelsMap[aiModel.Id] = aiModel
 	}
 	channelId2Models = make(map[int][]string)
 	for i := 1; i <= common.ChannelTypeDummy; i++ {
@@ -174,8 +174,8 @@ func DashboardListModels(c *gin.Context) {
 
 func RetrieveModel(c *gin.Context) {
 	modelId := c.Param("model")
-	if model, ok := openAIModelsMap[modelId]; ok {
-		c.JSON(200, model)
+	if aiModel, ok := openAIModelsMap[modelId]; ok {
+		c.JSON(200, aiModel)
 	} else {
 		openAIError := dto.OpenAIError{
 			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
@@ -191,12 +191,12 @@ func RetrieveModel(c *gin.Context) {
 
 func GetPricing(c *gin.Context) {
 	userId := c.GetInt("id")
-	user, _ := model.GetUserById(userId, true)
+	group, err := model.CacheGetUserGroup(userId)
 	groupRatio := common.GetGroupRatio("default")
-	if user != nil {
-		groupRatio = common.GetGroupRatio(user.Group)
+	if err != nil {
+		groupRatio = common.GetGroupRatio(group)
 	}
-	pricing := model.GetPricing(user, openAIModels)
+	pricing := model.GetPricing(group)
 	c.JSON(200, gin.H{
 		"success":     true,
 		"data":        pricing,

+ 5 - 5
model/pricing.go

@@ -13,16 +13,16 @@ var (
 	updatePricingLock  sync.Mutex
 )
 
-func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing {
+func GetPricing(group string) []dto.ModelPricing {
 	updatePricingLock.Lock()
 	defer updatePricingLock.Unlock()
 
 	if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
-		updatePricing(openAIModels)
+		updatePricing()
 	}
-	if user != nil {
+	if group != "" {
 		userPricingMap := make([]dto.ModelPricing, 0)
-		models := GetGroupModels(user.Group)
+		models := GetGroupModels(group)
 		for _, pricing := range pricingMap {
 			if !common.StringsContains(models, pricing.ModelName) {
 				pricing.Available = false
@@ -34,7 +34,7 @@ func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing
 	return pricingMap
 }
 
-func updatePricing(openAIModels []dto.OpenAIModels) {
+func updatePricing() {
 	//modelRatios := common.GetModelRatios()
 	enabledModels := GetEnabledModels()
 	allModels := make(map[string]int)

+ 1 - 1
router/api-router.go

@@ -20,7 +20,7 @@ func SetApiRouter(router *gin.Engine) {
 		apiRouter.GET("/about", controller.GetAbout)
 		//apiRouter.GET("/midjourney", controller.GetMidjourney)
 		apiRouter.GET("/home_page_content", controller.GetHomePageContent)
-		apiRouter.GET("/pricing", middleware.CriticalRateLimit(), middleware.TryUserAuth(), controller.GetPricing)
+		apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
 		apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)