Просмотр исходного кода

feat: 添加令牌ip白名单功能

CalciumIon 1 год назад
Родитель
Сommit
f505afdc10
6 измененных файлов с 57 добавлено и 3 удалено
  1. 5 0
      common/utils.go
  2. 2 0
      controller/token.go
  3. 1 0
      middleware/auth.go
  4. 8 0
      middleware/distributor.go
  5. 24 1
      model/token.go
  6. 17 2
      web/src/pages/Token/EditToken.js

+ 5 - 0
common/utils.go

@@ -128,6 +128,11 @@ func IntMax(a int, b int) int {
 	}
 }
 
+func IsIP(s string) bool {
+	ip := net.ParseIP(s)
+	return ip != nil
+}
+
 func GetUUID() string {
 	code := uuid.New().String()
 	code = strings.Replace(code, "-", "", -1)

+ 2 - 0
controller/token.go

@@ -134,6 +134,7 @@ func AddToken(c *gin.Context) {
 		UnlimitedQuota:     token.UnlimitedQuota,
 		ModelLimitsEnabled: token.ModelLimitsEnabled,
 		ModelLimits:        token.ModelLimits,
+		AllowIps:           token.AllowIps,
 	}
 	err = cleanToken.Insert()
 	if err != nil {
@@ -221,6 +222,7 @@ func UpdateToken(c *gin.Context) {
 		cleanToken.UnlimitedQuota = token.UnlimitedQuota
 		cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
 		cleanToken.ModelLimits = token.ModelLimits
+		cleanToken.AllowIps = token.AllowIps
 	}
 	err = cleanToken.Update()
 	if err != nil {

+ 1 - 0
middleware/auth.go

@@ -175,6 +175,7 @@ func TokenAuth() func(c *gin.Context) {
 		} else {
 			c.Set("token_model_limit_enabled", false)
 		}
+		c.Set("allow_ips", token.GetIpLimitsMap())
 		if len(parts) > 1 {
 			if model.IsAdmin(token.UserId) {
 				c.Set("specific_channel_id", parts[1])

+ 8 - 0
middleware/distributor.go

@@ -22,6 +22,14 @@ type ModelRequest struct {
 
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
+		allowIpsMap := c.GetStringMap("allow_ips")
+		if len(allowIpsMap) != 0 {
+			clientIp := c.ClientIP()
+			if _, ok := allowIpsMap[clientIp]; !ok {
+				abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
+				return
+			}
+		}
 		userId := c.GetInt("id")
 		var channel *model.Channel
 		channelId, ok := c.Get("specific_channel_id")

+ 24 - 1
model/token.go

@@ -23,10 +23,33 @@ type Token struct {
 	UnlimitedQuota     bool           `json:"unlimited_quota" gorm:"default:false"`
 	ModelLimitsEnabled bool           `json:"model_limits_enabled" gorm:"default:false"`
 	ModelLimits        string         `json:"model_limits" gorm:"type:varchar(1024);default:''"`
+	AllowIps           *string        `json:"allow_ips" gorm:"default:''"`
 	UsedQuota          int            `json:"used_quota" gorm:"default:0"` // used quota
 	DeletedAt          gorm.DeletedAt `gorm:"index"`
 }
 
+func (token *Token) GetIpLimitsMap() map[string]any {
+	// delete empty spaces
+	//split with \n
+	ipLimitsMap := make(map[string]any)
+	if token.AllowIps == nil {
+		return ipLimitsMap
+	}
+	cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
+	if cleanIps == "" {
+		return ipLimitsMap
+	}
+	ips := strings.Split(cleanIps, "\n")
+	for _, ip := range ips {
+		ip = strings.TrimSpace(ip)
+		ip = strings.ReplaceAll(ip, ",", "")
+		if common.IsIP(ip) {
+			ipLimitsMap[ip] = true
+		}
+	}
+	return ipLimitsMap
+}
+
 func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
 	var tokens []*Token
 	var err error
@@ -130,7 +153,7 @@ func (token *Token) Insert() error {
 // Update Make sure your token's fields is completed, because this will update non-zero values
 func (token *Token) Update() error {
 	var err error
-	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
+	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits", "allow_ips").Updates(token).Error
 	return err
 }
 

+ 17 - 2
web/src/pages/Token/EditToken.js

@@ -18,8 +18,8 @@ import {
   Select,
   SideSheet,
   Space,
-  Spin,
-  Typography,
+  Spin, TextArea,
+  Typography
 } from '@douyinfe/semi-ui';
 import Title from '@douyinfe/semi-ui/lib/es/typography/title';
 import { Divider } from 'semantic-ui-react';
@@ -34,6 +34,7 @@ const EditToken = (props) => {
     unlimited_quota: false,
     model_limits_enabled: false,
     model_limits: [],
+    allow_ips: '',
   };
   const [inputs, setInputs] = useState(originInputs);
   const {
@@ -43,6 +44,7 @@ const EditToken = (props) => {
     unlimited_quota,
     model_limits_enabled,
     model_limits,
+    allow_ips
   } = inputs;
   // const [visible, setVisible] = useState(false);
   const [models, setModels] = useState({});
@@ -374,6 +376,19 @@ const EditToken = (props) => {
             </Button>
           </div>
           <Divider />
+          <div style={{ marginTop: 10 }}>
+            <Typography.Text>IP白名单(请勿过度信任此功能)</Typography.Text>
+          </div>
+          <TextArea
+            label='IP白名单'
+            name='allow_ips'
+            placeholder={'允许的IP,一行一个'}
+            onChange={(value) => {
+              handleInputChange('allow_ips', value);
+            }}
+            value={inputs.allow_ips}
+            style={{ fontFamily: 'JetBrains Mono, Consolas' }}
+          />
           <div style={{ marginTop: 10, display: 'flex' }}>
             <Space>
               <Checkbox