Browse Source

feat: pricing page support multi groups #487

1808837298@qq.com 1 year ago
parent
commit
ed972eef06
5 changed files with 87 additions and 40 deletions
  1. 2 9
      controller/pricing.go
  2. 6 0
      model/ability.go
  3. 27 21
      model/pricing.go
  4. 1 1
      web/src/components/HeaderBar.js
  5. 51 9
      web/src/components/ModelPricing.js

+ 2 - 9
controller/pricing.go

@@ -7,18 +7,11 @@ import (
 )
 
 func GetPricing(c *gin.Context) {
-	userId := c.GetInt("id")
-	// if no login, get default group ratio
-	groupRatio := common.GetGroupRatio("default")
-	group, err := model.CacheGetUserGroup(userId)
-	if err == nil {
-		groupRatio = common.GetGroupRatio(group)
-	}
-	pricing := model.GetPricing(group)
+	pricing := model.GetPricing()
 	c.JSON(200, gin.H{
 		"success":     true,
 		"data":        pricing,
-		"group_ratio": groupRatio,
+		"group_ratio": common.GroupRatio,
 	})
 }
 

+ 6 - 0
model/ability.go

@@ -36,6 +36,12 @@ func GetEnabledModels() []string {
 	return models
 }
 
+func GetAllEnableAbilities() []Ability {
+	var abilities []Ability
+	DB.Find(&abilities, "enabled = ?", true)
+	return abilities
+}
+
 func getPriority(group string, model string, retry int) (int, error) {
 	groupCol := "`group`"
 	trueVal := "1"

+ 27 - 21
model/pricing.go

@@ -7,14 +7,13 @@ import (
 )
 
 type Pricing struct {
-	Available       bool     `json:"available"`
 	ModelName       string   `json:"model_name"`
 	QuotaType       int      `json:"quota_type"`
 	ModelRatio      float64  `json:"model_ratio"`
 	ModelPrice      float64  `json:"model_price"`
 	OwnerBy         string   `json:"owner_by"`
 	CompletionRatio float64  `json:"completion_ratio"`
-	EnableGroup     []string `json:"enable_group,omitempty"`
+	EnableGroup     []string `json:"enable_groups,omitempty"`
 }
 
 var (
@@ -23,40 +22,47 @@ var (
 	updatePricingLock  sync.Mutex
 )
 
-func GetPricing(group string) []Pricing {
+func GetPricing() []Pricing {
 	updatePricingLock.Lock()
 	defer updatePricingLock.Unlock()
 
 	if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
 		updatePricing()
 	}
-	if group != "" {
-		userPricingMap := make([]Pricing, 0)
-		models := GetGroupModels(group)
-		for _, pricing := range pricingMap {
-			if !common.StringsContains(models, pricing.ModelName) {
-				pricing.Available = false
-			}
-			userPricingMap = append(userPricingMap, pricing)
-		}
-		return userPricingMap
-	}
+	//if group != "" {
+	//	userPricingMap := make([]Pricing, 0)
+	//	models := GetGroupModels(group)
+	//	for _, pricing := range pricingMap {
+	//		if !common.StringsContains(models, pricing.ModelName) {
+	//			pricing.Available = false
+	//		}
+	//		userPricingMap = append(userPricingMap, pricing)
+	//	}
+	//	return userPricingMap
+	//}
 	return pricingMap
 }
 
 func updatePricing() {
 	//modelRatios := common.GetModelRatios()
-	enabledModels := GetEnabledModels()
-	allModels := make(map[string]int)
-	for i, model := range enabledModels {
-		allModels[model] = i
+	enableAbilities := GetAllEnableAbilities()
+	modelGroupsMap := make(map[string][]string)
+	for _, ability := range enableAbilities {
+		groups := modelGroupsMap[ability.Model]
+		if groups == nil {
+			groups = make([]string, 0)
+		}
+		if !common.StringsContains(groups, ability.Group) {
+			groups = append(groups, ability.Group)
+		}
+		modelGroupsMap[ability.Model] = groups
 	}
 
 	pricingMap = make([]Pricing, 0)
-	for model, _ := range allModels {
+	for model, groups := range modelGroupsMap {
 		pricing := Pricing{
-			Available: true,
-			ModelName: model,
+			ModelName:   model,
+			EnableGroup: groups,
 		}
 		modelPrice, findPrice := common.GetModelPrice(model, false)
 		if findPrice {

+ 1 - 1
web/src/components/HeaderBar.js

@@ -36,7 +36,7 @@ let buttons = [
     text: '首页',
     itemKey: 'home',
     to: '/',
-    icon: <IconHomeStroked />,
+    // icon: <IconHomeStroked />,
   },
   // {
   //   text: '模型价格',

+ 51 - 9
web/src/components/ModelPricing.js

@@ -1,5 +1,5 @@
 import React, { useContext, useEffect, useRef, useMemo, useState } from 'react';
-import { API, copy, showError, showSuccess } from '../helpers';
+import { API, copy, showError, showInfo, showSuccess } from '../helpers';
 
 import {
   Banner,
@@ -87,6 +87,7 @@ const ModelPricing = () => {
   const [selectedRowKeys, setSelectedRowKeys] = useState([]);
   const [modalImageUrl, setModalImageUrl] = useState('');
   const [isModalOpenurl, setIsModalOpenurl] = useState(false);
+  const [selectedGroup, setSelectedGroup] = useState('default');
 
   const rowSelection = useMemo(
       () => ({
@@ -120,7 +121,8 @@ const ModelPricing = () => {
       title: '可用性',
       dataIndex: 'available',
       render: (text, record, index) => {
-        return renderAvailable(text);
+         // if record.enable_groups contains selectedGroup, then available is true
+        return renderAvailable(record.enable_groups.includes(selectedGroup));
       },
       sorter: (a, b) => a.available - b.available,
     },
@@ -166,6 +168,43 @@ const ModelPricing = () => {
       },
       sorter: (a, b) => a.quota_type - b.quota_type,
     },
+    {
+      title: '可用分组',
+      dataIndex: 'enable_groups',
+      render: (text, record, index) => {
+        // enable_groups is a string array
+        return (
+          <Space>
+            {text.map((group) => {
+              if (group === selectedGroup) {
+                return (
+                  <Tag
+                    color='blue'
+                    size='large'
+                    prefixIcon={<IconVerify />}
+                  >
+                    {group}
+                  </Tag>
+                );
+              } else {
+                return (
+                  <Tag
+                    color='blue'
+                    size='large'
+                    onClick={() => {
+                      setSelectedGroup(group);
+                      showInfo('当前查看的分组为:' + group + ',倍率为:' + groupRatio[group]);
+                    }}
+                  >
+                    {group}
+                  </Tag>
+                );
+              }
+            })}
+          </Space>
+        );
+      },
+    },
     {
       title: () => (
         <span style={{'display':'flex','alignItems':'center'}}>
@@ -201,6 +240,8 @@ const ModelPricing = () => {
             <Text>模型:{record.quota_type === 0 ? text : '无'}</Text>
             <br />
             <Text>补全:{record.quota_type === 0 ? completionRatio : '无'}</Text>
+            <br />
+            <Text>分组:{groupRatio[selectedGroup]}</Text>
           </>
         );
         return <div>{content}</div>;
@@ -213,11 +254,11 @@ const ModelPricing = () => {
         let content = text;
         if (record.quota_type === 0) {
           // 这里的 *2 是因为 1倍率=0.002刀,请勿删除
-          let inputRatioPrice = record.model_ratio * 2 * record.group_ratio;
+          let inputRatioPrice = record.model_ratio * 2 * groupRatio[selectedGroup];
           let completionRatioPrice =
             record.model_ratio *
             record.completion_ratio * 2 *
-            record.group_ratio;
+            groupRatio[selectedGroup];
           content = (
             <>
               <Text>提示 ${inputRatioPrice} / 1M tokens</Text>
@@ -226,7 +267,7 @@ const ModelPricing = () => {
             </>
           );
         } else {
-          let price = parseFloat(text) * record.group_ratio;
+          let price = parseFloat(text) * groupRatio[selectedGroup];
           content = <>模型价格:${price}</>;
         }
         return <div>{content}</div>;
@@ -237,12 +278,12 @@ const ModelPricing = () => {
   const [models, setModels] = useState([]);
   const [loading, setLoading] = useState(true);
   const [userState, userDispatch] = useContext(UserContext);
-  const [groupRatio, setGroupRatio] = useState(1);
+  const [groupRatio, setGroupRatio] = useState({});
 
   const setModelsFormat = (models, groupRatio) => {
     for (let i = 0; i < models.length; i++) {
       models[i].key = models[i].model_name;
-      models[i].group_ratio = groupRatio;
+      models[i].group_ratio = groupRatio[models[i].model_name];
     }
     // sort by quota_type
     models.sort((a, b) => {
@@ -275,6 +316,7 @@ const ModelPricing = () => {
     const { success, message, data, group_ratio } = res.data;
     if (success) {
       setGroupRatio(group_ratio);
+      setSelectedGroup(userState.user ? userState.user.group : 'default')
       setModelsFormat(data, group_ratio);
     } else {
       showError(message);
@@ -307,14 +349,14 @@ const ModelPricing = () => {
             type="success"
             fullMode={false}
             closeIcon="null"
-            description={`您的分组为:${userState.user.group},分组倍率为:${groupRatio}`}
+            description={`您的默认分组为:${userState.user.group},分组倍率为:${groupRatio[userState.user.group]}`}
           />
         ) : (
           <Banner
             type='warning'
             fullMode={false}
             closeIcon="null"
-            description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio}`}
+            description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio['default']}`}
           />
         )}
         <br/>