Просмотр исходного кода

fix: 修复用户可选分组不能选择用户分组 (close #528)

1808837298@qq.com 1 год назад
Родитель
Сommit
f599c65944

+ 18 - 0
common/user_groups.go

@@ -22,6 +22,24 @@ func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
 	return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
 }
 
+func GetUserUsableGroups(userGroup string) map[string]string {
+	if userGroup == "" {
+		// 如果userGroup为空,返回UserUsableGroups
+		return UserUsableGroups
+	}
+	// 如果userGroup不在UserUsableGroups中,返回UserUsableGroups + userGroup
+	if _, ok := UserUsableGroups[userGroup]; !ok {
+		appendUserUsableGroups := make(map[string]string)
+		for k, v := range UserUsableGroups {
+			appendUserUsableGroups[k] = v
+		}
+		appendUserUsableGroups[userGroup] = "用户分组"
+		return appendUserUsableGroups
+	}
+	// 如果userGroup在UserUsableGroups中,返回UserUsableGroups
+	return UserUsableGroups
+}
+
 func GroupInUserUsableGroups(groupName string) bool {
 	_, ok := UserUsableGroups[groupName]
 	return ok

+ 7 - 2
controller/group.go

@@ -4,6 +4,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
+	"one-api/model"
 )
 
 func GetGroups(c *gin.Context) {
@@ -20,10 +21,14 @@ func GetGroups(c *gin.Context) {
 
 func GetUserGroups(c *gin.Context) {
 	usableGroups := make(map[string]string)
+	userGroup := ""
+	userId := c.GetInt("id")
+	userGroup, _ = model.CacheGetUserGroup(userId)
 	for groupName, _ := range common.GroupRatio {
 		// UserUsableGroups contains the groups that the user can use
-		if _, ok := common.UserUsableGroups[groupName]; ok {
-			usableGroups[groupName] = common.UserUsableGroups[groupName]
+		userUsableGroups := common.GetUserUsableGroups(userGroup)
+		if _, ok := userUsableGroups[groupName]; ok {
+			usableGroups[groupName] = userUsableGroups[groupName]
 		}
 	}
 	c.JSON(http.StatusOK, gin.H{

+ 1 - 1
middleware/distributor.go

@@ -42,7 +42,7 @@ func Distribute() func(c *gin.Context) {
 		tokenGroup := c.GetString("token_group")
 		if tokenGroup != "" {
 			// check common.UserUsableGroups[userGroup]
-			if _, ok := common.UserUsableGroups[tokenGroup]; !ok {
+			if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
 				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
 				return
 			}

+ 1 - 0
router/api-router.go

@@ -44,6 +44,7 @@ func SetApiRouter(router *gin.Engine) {
 			selfRoute := userRoute.Group("/")
 			selfRoute.Use(middleware.UserAuth())
 			{
+				selfRoute.GET("/self/groups", controller.GetUserGroups)
 				selfRoute.GET("/self", controller.GetSelf)
 				selfRoute.GET("/models", controller.GetUserModels)
 				selfRoute.PUT("/self", controller.UpdateSelf)

+ 10 - 10
web/src/components/LogsTable.js

@@ -250,7 +250,7 @@ const LogsTable = () => {
       title: '类型',
       dataIndex: 'type',
       render: (text, record, index) => {
-        return <div>{renderType(text)}</div>;
+        return <>{renderType(text)}</>;
       },
     },
     {
@@ -258,7 +258,7 @@ const LogsTable = () => {
       dataIndex: 'model_name',
       render: (text, record, index) => {
         return record.type === 0 || record.type === 2 ? (
-          <div>
+          <>
             <Tag
               color={stringToColor(text)}
               size='large'
@@ -269,7 +269,7 @@ const LogsTable = () => {
               {' '}
               {text}{' '}
             </Tag>
-          </div>
+          </>
         ) : (
           <></>
         );
@@ -282,22 +282,22 @@ const LogsTable = () => {
         if (record.is_stream) {
           let other = getLogOther(record.other);
           return (
-            <div>
+            <>
               <Space>
                 {renderUseTime(text)}
                 {renderFirstUseTime(other.frt)}
                 {renderIsStream(record.is_stream)}
               </Space>
-            </div>
+            </>
           );
         } else {
           return (
-            <div>
+            <>
               <Space>
                 {renderUseTime(text)}
                 {renderIsStream(record.is_stream)}
               </Space>
-            </div>
+            </>
           );
         }
       },
@@ -307,7 +307,7 @@ const LogsTable = () => {
       dataIndex: 'prompt_tokens',
       render: (text, record, index) => {
         return record.type === 0 || record.type === 2 ? (
-          <div>{<span> {text} </span>}</div>
+          <>{<span> {text} </span>}</>
         ) : (
           <></>
         );
@@ -319,7 +319,7 @@ const LogsTable = () => {
       render: (text, record, index) => {
         return parseInt(text) > 0 &&
           (record.type === 0 || record.type === 2) ? (
-          <div>{<span> {text} </span>}</div>
+          <>{<span> {text} </span>}</>
         ) : (
           <></>
         );
@@ -330,7 +330,7 @@ const LogsTable = () => {
       dataIndex: 'quota',
       render: (text, record, index) => {
         return record.type === 0 || record.type === 2 ? (
-          <div>{renderQuota(text, 6)}</div>
+          <>{renderQuota(text, 6)}</>
         ) : (
           <></>
         );

+ 1 - 1
web/src/pages/Token/EditToken.js

@@ -92,7 +92,7 @@ const EditToken = (props) => {
   };
 
   const loadGroups = async () => {
-    let res = await API.get(`/api/user/groups`);
+    let res = await API.get(`/api/user/self/groups`);
     const { success, message, data } = res.data;
     if (success) {
       // return data is a map, key is group name, value is group description