Просмотр исходного кода

✨ feat(channel): improve channel cache handling and add error checks for disabled channels

CaIon 7 месяцев назад
Родитель
Сommit
50b76f4466
3 измененных файлов с 55 добавлено и 12 удалено
  1. 6 1
      controller/channel-test.go
  2. 16 3
      model/channel.go
  3. 33 8
      model/channel_cache.go

+ 6 - 1
controller/channel-test.go

@@ -110,7 +110,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	}
 	cache.WriteContext(c)
 
-	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
 	c.Request.Header.Set("Content-Type", "application/json")
 	c.Set("channel", channel.Type)
 	c.Set("base_url", channel.GetBaseURL())
@@ -320,6 +320,11 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
+	//defer func() {
+	//	if channel.ChannelInfo.IsMultiKey {
+	//		go func() { _ = channel.SaveChannelInfo() }()
+	//	}
+	//}()
 	testModel := c.Query("model")
 	tik := time.Now()
 	result := testChannel(channel, testModel)

+ 16 - 3
model/channel.go

@@ -4,6 +4,7 @@ import (
 	"database/sql/driver"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"math/rand"
 	"one-api/common"
 	"one-api/constant"
@@ -122,15 +123,23 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
 		lock.Lock()
 		defer lock.Unlock()
 
+		channelInfo, err := CacheGetChannelInfo(channel.Id)
+		if err != nil {
+			return "", types.NewError(err, types.ErrorCodeGetChannelFailed)
+		}
+		//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
 		defer func() {
+			if common.DebugEnabled {
+				println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
+			}
 			if !common.MemoryCacheEnabled {
-				_ = channel.Save()
+				_ = channel.SaveChannelInfo()
 			} else {
 				// CacheUpdateChannel(channel)
 			}
 		}()
 		// Start from the saved polling index and look for the next enabled key
-		start := channel.ChannelInfo.MultiKeyPollingIndex
+		start := channelInfo.MultiKeyPollingIndex
 		if start < 0 || start >= len(keys) {
 			start = 0
 		}
@@ -150,6 +159,10 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
 	}
 }
 
+func (channel *Channel) SaveChannelInfo() error {
+	return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
+}
+
 func (channel *Channel) GetModels() []string {
 	if channel.Models == "" {
 		return []string{}
@@ -500,7 +513,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
 		if channelCache.ChannelInfo.IsMultiKey {
 			// 如果是多Key模式,更新缓存中的状态
 			handlerMultiKeyUpdate(channelCache, usingKey, status)
-			CacheUpdateChannel(channelCache)
+			//CacheUpdateChannel(channelCache)
 			//return true
 		} else {
 			// 如果缓存渠道存在,且状态已是目标状态,直接返回

+ 33 - 8
model/channel_cache.go

@@ -14,8 +14,8 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-var group2model2channels map[string]map[string][]int
-var channelsIDM map[int]*Channel
+var group2model2channels map[string]map[string][]int // enabled channel
+var channelsIDM map[int]*Channel                     // all channels include disabled
 var channelSyncLock sync.RWMutex
 
 func InitChannelCache() {
@@ -24,7 +24,7 @@ func InitChannelCache() {
 	}
 	newChannelId2channel := make(map[int]*Channel)
 	var channels []*Channel
-	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
+	DB.Find(&channels)
 	for _, channel := range channels {
 		newChannelId2channel[channel.Id] = channel
 	}
@@ -35,12 +35,13 @@ func InitChannelCache() {
 		groups[ability.Group] = true
 	}
 	newGroup2model2channels := make(map[string]map[string][]int)
-	newChannelsIDM := make(map[int]*Channel)
 	for group := range groups {
 		newGroup2model2channels[group] = make(map[string][]int)
 	}
 	for _, channel := range channels {
-		newChannelsIDM[channel.Id] = channel
+		if channel.Status != common.ChannelStatusEnabled {
+			continue // skip disabled channels
+		}
 		groups := strings.Split(channel.Group, ",")
 		for _, group := range groups {
 			models := strings.Split(channel.Models, ",")
@@ -57,7 +58,7 @@ func InitChannelCache() {
 	for group, model2channels := range newGroup2model2channels {
 		for model, channels := range model2channels {
 			sort.Slice(channels, func(i, j int) bool {
-				return newChannelsIDM[channels[i]].GetPriority() > newChannelsIDM[channels[j]].GetPriority()
+				return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
 			})
 			newGroup2model2channels[group][model] = channels
 		}
@@ -65,7 +66,7 @@ func InitChannelCache() {
 
 	channelSyncLock.Lock()
 	group2model2channels = newGroup2model2channels
-	channelsIDM = newChannelsIDM
+	channelsIDM = newChannelId2channel
 	channelSyncLock.Unlock()
 	common.SysLog("channels synced from database")
 }
@@ -203,11 +204,35 @@ func CacheGetChannel(id int) (*Channel, error) {
 
 	c, ok := channelsIDM[id]
 	if !ok {
-		return nil, fmt.Errorf("当前渠道# %d,已不存在", id)
+		return nil, fmt.Errorf("渠道# %d,已不存在", id)
+	}
+	if c.Status != common.ChannelStatusEnabled {
+		return nil, fmt.Errorf("渠道# %d,已被禁用", id)
 	}
 	return c, nil
 }
 
+func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
+	if !common.MemoryCacheEnabled {
+		channel, err := GetChannelById(id, true)
+		if err != nil {
+			return nil, err
+		}
+		return &channel.ChannelInfo, nil
+	}
+	channelSyncLock.RLock()
+	defer channelSyncLock.RUnlock()
+
+	c, ok := channelsIDM[id]
+	if !ok {
+		return nil, fmt.Errorf("渠道# %d,已不存在", id)
+	}
+	if c.Status != common.ChannelStatusEnabled {
+		return nil, fmt.Errorf("渠道# %d,已被禁用", id)
+	}
+	return &c.ChannelInfo, nil
+}
+
 func CacheUpdateChannelStatus(id int, status int) {
 	if !common.MemoryCacheEnabled {
 		return