Просмотр исходного кода

feat: use cache to avoid database access (#158)

JustSong 2 лет назад
Родитель
Сommit
3d76a974d1
5 измененных файлов с 123 добавлено и 4 удалено
  1. 15 0
      common/redis.go
  2. 6 0
      main.go
  3. 2 2
      middleware/distributor.go
  4. 99 0
      model/cache.go
  5. 1 2
      model/token.go

+ 15 - 0
common/redis.go

@@ -37,3 +37,18 @@ func ParseRedisOption() *redis.Options {
 	}
 	return opt
 }
+
+func RedisSet(key string, value string, expiration time.Duration) error {
+	ctx := context.Background()
+	return RDB.Set(ctx, key, value, expiration).Err()
+}
+
+func RedisGet(key string) (string, error) {
+	ctx := context.Background()
+	return RDB.Get(ctx, key).Result()
+}
+
+func RedisDel(key string) error {
+	ctx := context.Background()
+	return RDB.Del(ctx, key).Err()
+}

+ 6 - 0
main.go

@@ -47,12 +47,18 @@ func main() {
 
 	// Initialize options
 	model.InitOptionMap()
+	if common.RedisEnabled {
+		model.InitChannelCache()
+	}
 	if os.Getenv("SYNC_FREQUENCY") != "" {
 		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
 		if err != nil {
 			common.FatalLog(err)
 		}
 		go model.SyncOptions(frequency)
+		if common.RedisEnabled {
+			go model.SyncChannelCache(frequency)
+		}
 	}
 
 	// Initialize HTTP server

+ 2 - 2
middleware/distributor.go

@@ -17,7 +17,7 @@ type ModelRequest struct {
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
 		userId := c.GetInt("id")
-		userGroup, _ := model.GetUserGroup(userId)
+		userGroup, _ := model.CacheGetUserGroup(userId)
 		c.Set("group", userGroup)
 		var channel *model.Channel
 		channelId, ok := c.Get("channelId")
@@ -73,7 +73,7 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "text-moderation-stable"
 				}
 			}
-			channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				c.JSON(200, gin.H{
 					"error": gin.H{

+ 99 - 0
model/cache.go

@@ -0,0 +1,99 @@
+package model
+
+import (
+	"encoding/json"
+	"fmt"
+	"one-api/common"
+	"sync"
+	"time"
+)
+
+const (
+	TokenCacheSeconds        = 60 * 60
+	UserId2GroupCacheSeconds = 60 * 60
+)
+
+func CacheGetTokenByKey(key string) (*Token, error) {
+	var token Token
+	if !common.RedisEnabled {
+		err := DB.Where("`key` = ?", key).First(token).Error
+		return &token, err
+	}
+	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
+	if err != nil {
+		err := DB.Where("`key` = ?", key).First(token).Error
+		if err != nil {
+			return nil, err
+		}
+		jsonBytes, err := json.Marshal(token)
+		if err != nil {
+			return nil, err
+		}
+		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second)
+		if err != nil {
+			common.SysError("Redis set token error: " + err.Error())
+		}
+	}
+	err = json.Unmarshal([]byte(tokenObjectString), &token)
+	return &token, err
+}
+
+func CacheGetUserGroup(id int) (group string, err error) {
+	if !common.RedisEnabled {
+		return GetUserGroup(id)
+	}
+	group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
+	if err != nil {
+		group, err = GetUserGroup(id)
+		if err != nil {
+			return "", err
+		}
+		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second)
+		if err != nil {
+			common.SysError("Redis set user group error: " + err.Error())
+		}
+	}
+	return group, err
+}
+
+var channelId2channel map[int]*Channel
+var channelSyncLock sync.RWMutex
+var group2model2channels map[string]map[string][]*Channel
+
+func InitChannelCache() {
+	channelSyncLock.Lock()
+	defer channelSyncLock.Unlock()
+	channelId2channel = make(map[int]*Channel)
+	var channels []*Channel
+	DB.Find(&channels)
+	for _, channel := range channels {
+		channelId2channel[channel.Id] = channel
+	}
+	var abilities []*Ability
+	DB.Find(&abilities)
+	groups := make(map[string]bool)
+	for _, ability := range abilities {
+		groups[ability.Group] = true
+	}
+	group2model2channels = make(map[string]map[string][]*Channel)
+	for group := range groups {
+		group2model2channels[group] = make(map[string][]*Channel)
+		// TODO: implement this
+	}
+}
+
+func SyncChannelCache(frequency int) {
+	for {
+		time.Sleep(time.Duration(frequency) * time.Second)
+		common.SysLog("Syncing channels from database")
+		InitChannelCache()
+	}
+}
+
+func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+	if !common.RedisEnabled {
+		return GetRandomSatisfiedChannel(group, model)
+	}
+	// TODO: implement this
+	return nil, nil
+}

+ 1 - 2
model/token.go

@@ -36,8 +36,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
 	if key == "" {
 		return nil, errors.New("未提供 token")
 	}
-	token = &Token{}
-	err = DB.Where("`key` = ?", key).First(token).Error
+	token, err = CacheGetTokenByKey(key)
 	if err == nil {
 		if token.Status != common.TokenStatusEnabled {
 			return nil, errors.New("该 token 状态不可用")