Просмотр исходного кода

feat: support aff now (close #75)

JustSong 2 лет назад
Родитель
Сommit
c5837c3bb7

+ 2 - 0
common/constants.go

@@ -55,6 +55,8 @@ var TurnstileSiteKey = ""
 var TurnstileSecretKey = ""
 
 var QuotaForNewUser = 0
+var QuotaForInviter = 0
+var QuotaForInvitee = 0
 var ChannelDisableThreshold = 5.0
 var AutomaticDisableChannelEnabled = false
 var QuotaRemindThreshold = 1000

+ 9 - 0
common/utils.go

@@ -157,6 +157,15 @@ func GenerateKey() string {
 	return string(key)
 }
 
+func GetRandomString(length int) string {
+	rand.Seed(time.Now().UnixNano())
+	key := make([]byte, length)
+	for i := 0; i < length; i++ {
+		key[i] = keyChars[rand.Intn(len(keyChars))]
+	}
+	return string(key)
+}
+
 func GetTimestamp() int64 {
 	return time.Now().Unix()
 }

+ 1 - 1
controller/github.go

@@ -125,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
 			user.Role = common.RoleCommonUser
 			user.Status = common.UserStatusEnabled
 
-			if err := user.Insert(); err != nil {
+			if err := user.Insert(0); err != nil {
 				c.JSON(http.StatusOK, gin.H{
 					"success": false,
 					"message": err.Error(),

+ 33 - 2
controller/user.go

@@ -150,15 +150,18 @@ func Register(c *gin.Context) {
 			return
 		}
 	}
+	affCode := user.AffCode // this code is the inviter's code, not the user's own code
+	inviterId, _ := model.GetUserIdByAffCode(affCode)
 	cleanUser := model.User{
 		Username:    user.Username,
 		Password:    user.Password,
 		DisplayName: user.Username,
+		InviterId:   inviterId,
 	}
 	if common.EmailVerificationEnabled {
 		cleanUser.Email = user.Email
 	}
-	if err := cleanUser.Insert(); err != nil {
+	if err := cleanUser.Insert(inviterId); err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
@@ -280,6 +283,34 @@ func GenerateAccessToken(c *gin.Context) {
 	return
 }
 
+func GetAffCode(c *gin.Context) {
+	id := c.GetInt("id")
+	user, err := model.GetUserById(id, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	if user.AffCode == "" {
+		user.AffCode = common.GetRandomString(4)
+		if err := user.Update(false); err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    user.AffCode,
+	})
+	return
+}
+
 func GetSelf(c *gin.Context) {
 	id := c.GetInt("id")
 	user, err := model.GetUserById(id, false)
@@ -495,7 +526,7 @@ func CreateUser(c *gin.Context) {
 		Password:    user.Password,
 		DisplayName: user.DisplayName,
 	}
-	if err := cleanUser.Insert(); err != nil {
+	if err := cleanUser.Insert(0); err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),

+ 1 - 1
controller/wechat.go

@@ -85,7 +85,7 @@ func WeChatAuth(c *gin.Context) {
 			user.Role = common.RoleCommonUser
 			user.Status = common.UserStatusEnabled
 
-			if err := user.Insert(); err != nil {
+			if err := user.Insert(0); err != nil {
 				c.JSON(http.StatusOK, gin.H{
 					"success": false,
 					"message": err.Error(),

+ 6 - 0
model/option.go

@@ -56,6 +56,8 @@ func InitOptionMap() {
 	common.OptionMap["TurnstileSiteKey"] = ""
 	common.OptionMap["TurnstileSecretKey"] = ""
 	common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
+	common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
+	common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
@@ -175,6 +177,10 @@ func updateOptionMap(key string, value string) (err error) {
 		common.TurnstileSecretKey = value
 	case "QuotaForNewUser":
 		common.QuotaForNewUser, _ = strconv.Atoi(value)
+	case "QuotaForInviter":
+		common.QuotaForInviter, _ = strconv.Atoi(value)
+	case "QuotaForInvitee":
+		common.QuotaForInvitee, _ = strconv.Atoi(value)
 	case "QuotaRemindThreshold":
 		common.QuotaRemindThreshold, _ = strconv.Atoi(value)
 	case "PreConsumedQuota":

+ 23 - 1
model/user.go

@@ -26,6 +26,8 @@ type User struct {
 	UsedQuota        int    `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`               // request number
 	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"`
+	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
+	InviterId        int    `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
 }
 
 func GetMaxUserId() int {
@@ -58,6 +60,15 @@ func GetUserById(id int, selectAll bool) (*User, error) {
 	return &user, err
 }
 
+func GetUserIdByAffCode(affCode string) (int, error) {
+	if affCode == "" {
+		return 0, errors.New("affCode 为空!")
+	}
+	var user User
+	err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
+	return user.Id, err
+}
+
 func DeleteUserById(id int) (err error) {
 	if id == 0 {
 		return errors.New("id 为空!")
@@ -66,7 +77,7 @@ func DeleteUserById(id int) (err error) {
 	return user.Delete()
 }
 
-func (user *User) Insert() error {
+func (user *User) Insert(inviterId int) error {
 	var err error
 	if user.Password != "" {
 		user.Password, err = common.Password2Hash(user.Password)
@@ -76,6 +87,7 @@ func (user *User) Insert() error {
 	}
 	user.Quota = common.QuotaForNewUser
 	user.AccessToken = common.GetUUID()
+	user.AffCode = common.GetRandomString(4)
 	result := DB.Create(user)
 	if result.Error != nil {
 		return result.Error
@@ -83,6 +95,16 @@ func (user *User) Insert() error {
 	if common.QuotaForNewUser > 0 {
 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %d 点额度", common.QuotaForNewUser))
 	}
+	if inviterId != 0 {
+		if common.QuotaForInvitee > 0 {
+			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
+			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %d 点额度", common.QuotaForInvitee))
+		}
+		if common.QuotaForInviter > 0 {
+			_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
+			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %d 点额度", common.QuotaForInviter))
+		}
+	}
 	return nil
 }
 

+ 1 - 0
router/api-router.go

@@ -37,6 +37,7 @@ func SetApiRouter(router *gin.Engine) {
 				selfRoute.PUT("/self", controller.UpdateSelf)
 				selfRoute.DELETE("/self", controller.DeleteSelf)
 				selfRoute.GET("/token", controller.GenerateAccessToken)
+				selfRoute.GET("/aff", controller.GetAffCode)
 				selfRoute.POST("/topup", controller.TopUp)
 			}
 

+ 14 - 1
web/src/components/PersonalSetting.js

@@ -1,7 +1,7 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react';
 import { Link } from 'react-router-dom';
-import { API, copy, showError, showInfo, showSuccess } from '../helpers';
+import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
 import Turnstile from 'react-turnstile';
 
 const PersonalSetting = () => {
@@ -45,6 +45,18 @@ const PersonalSetting = () => {
     }
   };
 
+  const getAffLink = async () => {
+    const res = await API.get('/api/user/aff');
+    const { success, message, data } = res.data;
+    if (success) {
+      let link = `${window.location.origin}/register?aff=${data}`;
+      await copy(link);
+      showNotice(`邀请链接已复制到剪切板:${link}`);
+    } else {
+      showError(message);
+    }
+  };
+
   const bindWeChat = async () => {
     if (inputs.wechat_verification_code === '') return;
     const res = await API.get(
@@ -110,6 +122,7 @@ const PersonalSetting = () => {
         更新个人信息
       </Button>
       <Button onClick={generateAccessToken}>生成系统访问令牌</Button>
+      <Button onClick={getAffLink}>复制邀请链接</Button>
       <Divider />
       <Header as='h3'>账号绑定</Header>
       {

+ 8 - 0
web/src/components/RegisterForm.js

@@ -27,6 +27,10 @@ const RegisterForm = () => {
   const [turnstileToken, setTurnstileToken] = useState('');
   const [loading, setLoading] = useState(false);
   const logo = getLogo();
+  let affCode = new URLSearchParams(window.location.search).get('aff');
+  if (affCode) {
+    localStorage.setItem('aff', affCode);
+  }
 
   useEffect(() => {
     let status = localStorage.getItem('status');
@@ -63,6 +67,10 @@ const RegisterForm = () => {
         return;
       }
       setLoading(true);
+      if (!affCode) {
+        affCode = localStorage.getItem('aff');
+      }
+      inputs.aff_code = affCode;
       const res = await API.post(
         `/api/user/register?turnstile=${turnstileToken}`,
         inputs

+ 33 - 1
web/src/components/SystemSetting.js

@@ -27,6 +27,8 @@ const SystemSetting = () => {
     TurnstileSecretKey: '',
     RegisterEnabled: '',
     QuotaForNewUser: 0,
+    QuotaForInviter: 0,
+    QuotaForInvitee: 0,
     QuotaRemindThreshold: 0,
     PreConsumedQuota: 0,
     ModelRatio: '',
@@ -34,7 +36,7 @@ const SystemSetting = () => {
     TopUpLink: '',
     AutomaticDisableChannelEnabled: '',
     ChannelDisableThreshold: 0,
-    LogConsumeEnabled: '',
+    LogConsumeEnabled: ''
   });
   const [originInputs, setOriginInputs] = useState({});
   let [loading, setLoading] = useState(false);
@@ -101,6 +103,8 @@ const SystemSetting = () => {
       name === 'TurnstileSiteKey' ||
       name === 'TurnstileSecretKey' ||
       name === 'QuotaForNewUser' ||
+      name === 'QuotaForInviter' ||
+      name === 'QuotaForInvitee' ||
       name === 'QuotaRemindThreshold' ||
       name === 'PreConsumedQuota' ||
       name === 'ModelRatio' ||
@@ -122,6 +126,12 @@ const SystemSetting = () => {
     if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
       await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
     }
+    if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
+      await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
+    }
+    if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
+      await updateOption('QuotaForInviter', inputs.QuotaForInviter);
+    }
     if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
       await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
     }
@@ -329,6 +339,28 @@ const SystemSetting = () => {
               placeholder='请求结束后多退少补'
             />
           </Form.Group>
+          <Form.Group widths={4}>
+            <Form.Input
+              label='邀请新用户奖励配额'
+              name='QuotaForInviter'
+              onChange={handleInputChange}
+              autoComplete='new-password'
+              value={inputs.QuotaForInviter}
+              type='number'
+              min='0'
+              placeholder='例如:100'
+            />
+            <Form.Input
+              label='新用户使用邀请码奖励配额'
+              name='QuotaForInvitee'
+              onChange={handleInputChange}
+              autoComplete='new-password'
+              value={inputs.QuotaForInvitee}
+              type='number'
+              min='0'
+              placeholder='例如:100'
+            />
+          </Form.Group>
           <Form.Group widths='equal'>
             <Form.TextArea
               label='模型倍率'