Просмотр исходного кода

✨ feat(channel): enhance channel status management

CaIon 8 месяцев назад
Родитель
Сommit
cd8c23c0ab

+ 1 - 0
constant/context_key.go

@@ -29,6 +29,7 @@ const (
 	ContextKeyChannelModelMapping      ContextKey = "model_mapping"
 	ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
 	ContextKeyChannelIsMultiKey        ContextKey = "channel_is_multi_key"
+	ContextKeyChannelKey               ContextKey = "channel_key"
 
 	/* user related keys */
 	ContextKeyUserId      ContextKey = "id"

+ 4 - 4
controller/channel-billing.go

@@ -452,14 +452,14 @@ func updateAllChannelsBalance() error {
 		//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
 		//	continue
 		//}
-		balance, err := updateChannelBalance(channel)
+		_, err := updateChannelBalance(channel)
 		if err != nil {
 			continue
 		} else {
 			// err is nil & balance <= 0 means quota is used up
-			if balance <= 0 {
-				service.DisableChannel(channel.Id, channel.Name, "余额不足")
-			}
+			//if balance <= 0 {
+			//	service.DisableChannel(channel.Id, channel.Name, "余额不足")
+			//}
 		}
 		time.Sleep(common.RequestInterval)
 	}

+ 122 - 35
controller/channel-test.go

@@ -30,22 +30,43 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) {
+type testResult struct {
+	context     *gin.Context
+	localErr    error
+	newAPIError *types.NewAPIError
+}
+
+func testChannel(channel *model.Channel, testModel string) testResult {
 	tik := time.Now()
 	if channel.Type == constant.ChannelTypeMidjourney {
-		return errors.New("midjourney channel test is not supported"), nil
+		return testResult{
+			localErr:    errors.New("midjourney channel test is not supported"),
+			newAPIError: nil,
+		}
 	}
 	if channel.Type == constant.ChannelTypeMidjourneyPlus {
-		return errors.New("midjourney plus channel test is not supported"), nil
+		return testResult{
+			localErr:    errors.New("midjourney plus channel test is not supported"),
+			newAPIError: nil,
+		}
 	}
 	if channel.Type == constant.ChannelTypeSunoAPI {
-		return errors.New("suno channel test is not supported"), nil
+		return testResult{
+			localErr:    errors.New("suno channel test is not supported"),
+			newAPIError: nil,
+		}
 	}
 	if channel.Type == constant.ChannelTypeKling {
-		return errors.New("kling channel test is not supported"), nil
+		return testResult{
+			localErr:    errors.New("kling channel test is not supported"),
+			newAPIError: nil,
+		}
 	}
 	if channel.Type == constant.ChannelTypeJimeng {
-		return errors.New("jimeng channel test is not supported"), nil
+		return testResult{
+			localErr:    errors.New("jimeng channel test is not supported"),
+			newAPIError: nil,
+		}
 	}
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
@@ -82,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
 
 	cache, err := model.GetUserCache(1)
 	if err != nil {
-		return err, nil
+		return testResult{
+			localErr:    err,
+			newAPIError: nil,
+		}
 	}
 	cache.WriteContext(c)
 
@@ -93,20 +117,35 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
 	group, _ := model.GetUserGroup(1, false)
 	c.Set("group", group)
 
-	middleware.SetupContextForSelectedChannel(c, channel, testModel)
+	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
+	if newAPIError != nil {
+		return testResult{
+			context:     c,
+			localErr:    newAPIError,
+			newAPIError: newAPIError,
+		}
+	}
 
 	info := relaycommon.GenRelayInfo(c)
 
 	err = helper.ModelMappedHelper(c, info, nil)
 	if err != nil {
-		return err, types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
+		}
 	}
 	testModel = info.UpstreamModelName
 
 	apiType, _ := common.ChannelType2APIType(channel.Type)
 	adaptor := relay.GetAdaptor(apiType)
 	if adaptor == nil {
-		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType)
+		return testResult{
+			context:     c,
+			localErr:    fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
+			newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
+		}
 	}
 
 	request := buildTestRequest(testModel)
@@ -117,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
 
 	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
 	if err != nil {
-		return err, types.NewError(err, types.ErrorCodeModelPriceError)
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
+		}
 	}
 
 	adaptor.Init(info)
 
 	convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
 	if err != nil {
-		return err, types.NewError(err, types.ErrorCodeConvertRequestFailed)
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
+		}
 	}
 	jsonData, err := json.Marshal(convertedRequest)
 	if err != nil {
-		return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed)
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
+		}
 	}
 	requestBody := bytes.NewBuffer(jsonData)
 	c.Request.Body = io.NopCloser(requestBody)
 	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
-		return err, types.NewError(err, types.ErrorCodeDoRequestFailed)
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
+		}
 	}
 	var httpResp *http.Response
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 			err := service.RelayErrorHandler(httpResp, true)
-			return err, types.NewError(err, types.ErrorCodeBadResponse)
+			return testResult{
+				context:     c,
+				localErr:    err,
+				newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
+			}
 		}
 	}
 	usageA, respErr := adaptor.DoResponse(c, httpResp, info)
 	if respErr != nil {
-		return respErr, respErr
+		return testResult{
+			context:     c,
+			localErr:    respErr,
+			newAPIError: respErr,
+		}
 	}
 	if usageA == nil {
-		return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody)
+		return testResult{
+			context:     c,
+			localErr:    errors.New("usage is nil"),
+			newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
+		}
 	}
 	usage := usageA.(*dto.Usage)
 	result := w.Result()
 	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
-		return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
+		}
 	}
 	info.PromptTokens = usage.PromptTokens
 
@@ -188,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
 		Other:            other,
 	})
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
-	return nil, nil
+	return testResult{
+		context:     c,
+		localErr:    nil,
+		newAPIError: nil,
+	}
 }
 
 func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
@@ -247,15 +322,23 @@ func TestChannel(c *gin.Context) {
 	}
 	testModel := c.Query("model")
 	tik := time.Now()
-	_, newAPIError := testChannel(channel, testModel)
+	result := testChannel(channel, testModel)
+	if result.localErr != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": result.localErr.Error(),
+			"time":    0.0,
+		})
+		return
+	}
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	go channel.UpdateResponseTime(milliseconds)
 	consumedTime := float64(milliseconds) / 1000.0
-	if newAPIError != nil {
+	if result.newAPIError != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
-			"message": newAPIError.Error(),
+			"message": result.newAPIError.Error(),
 			"time":    consumedTime,
 		})
 		return
@@ -280,9 +363,9 @@ func testAllChannels(notify bool) error {
 	}
 	testAllChannelsRunning = true
 	testAllChannelsLock.Unlock()
-	channels, err := model.GetAllChannels(0, 0, true, false)
-	if err != nil {
-		return err
+	channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
+	if getChannelErr != nil {
+		return getChannelErr
 	}
 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 	if disableThreshold == 0 {
@@ -299,30 +382,34 @@ func testAllChannels(notify bool) error {
 		for _, channel := range channels {
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
-			err, newAPIError := testChannel(channel, "")
+			result := testChannel(channel, "")
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
 
 			shouldBanChannel := false
-
+			newAPIError := result.newAPIError
 			// request error disables the channel
-			if err != nil {
-				shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError)
+			if newAPIError != nil {
+				shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
 			}
 
-			if milliseconds > disableThreshold {
-				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-				shouldBanChannel = true
+			// 当错误检查通过,才检查响应时间
+			if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
+				if milliseconds > disableThreshold {
+					err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+					newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
+					shouldBanChannel = true
+				}
 			}
 
 			// disable channel
 			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
-				service.DisableChannel(channel.Id, channel.Name, err.Error())
+				go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 			}
 
 			// enable channel
-			if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) {
-				service.EnableChannel(channel.Id, channel.Name)
+			if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
+				service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
 			}
 
 			channel.UpdateResponseTime(milliseconds)

+ 2 - 0
controller/channel.go

@@ -497,6 +497,7 @@ func AddChannel(c *gin.Context) {
 				})
 				return
 			}
+			addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
 			addChannelRequest.Channel.Key = strings.Join(array, "\n")
 		} else {
 			cleanKeys := make([]string, 0)
@@ -507,6 +508,7 @@ func AddChannel(c *gin.Context) {
 				key = strings.TrimSpace(key)
 				cleanKeys = append(cleanKeys, key)
 			}
+			addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
 			addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
 		}
 		keys = []string{addChannelRequest.Channel.Key}

+ 29 - 26
controller/relay.go

@@ -80,7 +80,7 @@ func Relay(c *gin.Context) {
 		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
 			common.LogError(c, err.Error())
-			newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
+			newAPIError = err
 			break
 		}
 
@@ -90,7 +90,7 @@ func Relay(c *gin.Context) {
 			return // 成功处理请求,直接返回
 		}
 
-		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
+		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 
 		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
 			break
@@ -103,10 +103,10 @@ func Relay(c *gin.Context) {
 	}
 
 	if newAPIError != nil {
-		if newAPIError.StatusCode == http.StatusTooManyRequests {
-			common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
-			newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
-		}
+		//if newAPIError.StatusCode == http.StatusTooManyRequests {
+		//	common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
+		//	newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
+		//}
 		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
 		c.JSON(newAPIError.StatusCode, gin.H{
 			"error": newAPIError.ToOpenAIError(),
@@ -143,7 +143,7 @@ func WssRelay(c *gin.Context) {
 		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
 			common.LogError(c, err.Error())
-			newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
+			newAPIError = err
 			break
 		}
 
@@ -153,7 +153,7 @@ func WssRelay(c *gin.Context) {
 			return // 成功处理请求,直接返回
 		}
 
-		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
+		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 
 		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
 			break
@@ -166,9 +166,9 @@ func WssRelay(c *gin.Context) {
 	}
 
 	if newAPIError != nil {
-		if newAPIError.StatusCode == http.StatusTooManyRequests {
-			newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
-		}
+		//if newAPIError.StatusCode == http.StatusTooManyRequests {
+		//	newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
+		//}
 		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
 		helper.WssError(c, ws, newAPIError.ToOpenAIError())
 	}
@@ -185,7 +185,7 @@ func RelayClaude(c *gin.Context) {
 		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
 			common.LogError(c, err.Error())
-			newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
+			newAPIError = err
 			break
 		}
 
@@ -195,7 +195,7 @@ func RelayClaude(c *gin.Context) {
 			return // 成功处理请求,直接返回
 		}
 
-		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
+		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 
 		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
 			break
@@ -243,7 +243,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
 	c.Set("use_channel", useChannel)
 }
 
-func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
+func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
 	if retryCount == 0 {
 		autoBan := c.GetBool("auto_ban")
 		autoBanInt := 1
@@ -260,11 +260,14 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
 	channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
 	if err != nil {
 		if group == "auto" {
-			return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
+			return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
 		}
-		return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
+		return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
+	}
+	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+	if newAPIError != nil {
+		return nil, newAPIError
 	}
-	middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 	return channel, nil
 }
 
@@ -314,12 +317,12 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
 	return true
 }
 
-func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *types.NewAPIError) {
+func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
 	// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
 	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
-	common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error()))
-	if service.ShouldDisableChannel(channelType, err) && autoBan {
-		service.DisableChannel(channelId, channelName, err.Error())
+	common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
+	if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+		service.DisableChannel(channelError, err.Error())
 	}
 }
 
@@ -392,10 +395,10 @@ func RelayTask(c *gin.Context) {
 		retryTimes = 0
 	}
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
-		channel, err := getChannel(c, group, originalModel, i)
-		if err != nil {
-			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
-			taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
+		channel, newAPIError := getChannel(c, group, originalModel, i)
+		if newAPIError != nil {
+			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
+			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
 			break
 		}
 		channelId = channel.Id
@@ -405,7 +408,7 @@ func RelayTask(c *gin.Context) {
 		common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
 		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
-		requestBody, err := common.GetRequestBody(c)
+		requestBody, _ := common.GetRequestBody(c)
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 		taskErr = taskRelayHandler(c, relayMode)
 	}

+ 11 - 3
middleware/distributor.go

@@ -12,6 +12,7 @@ import (
 	"one-api/service"
 	"one-api/setting"
 	"one-api/setting/ratio_setting"
+	"one-api/types"
 	"strconv"
 	"strings"
 	"time"
@@ -249,10 +250,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 	return &modelRequest, shouldSelectChannel, nil
 }
 
-func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
 	c.Set("original_model", modelName) // for retry
 	if channel == nil {
-		return
+		return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
 	}
 	common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
 	common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
@@ -270,7 +271,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
 
 	}
-	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+	key, newAPIError := channel.GetNextEnabledKey()
+	if newAPIError != nil {
+		return newAPIError
+	}
+	// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
+	common.SetContextKey(c, constant.ContextKeyChannelKey, key)
 	common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
 
 	// TODO: api_version统一
@@ -292,6 +299,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	case constant.ChannelTypeCoze:
 		c.Set("bot_id", channel.Other)
 	}
+	return nil
 }
 
 // extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名

+ 13 - 0
model/cache.go

@@ -203,3 +203,16 @@ func CacheUpdateChannelStatus(id int, status int) {
 		channel.Status = status
 	}
 }
+
+func CacheUpdateChannel(channel *Channel) {
+	if !common.MemoryCacheEnabled {
+		return
+	}
+	channelSyncLock.Lock()
+	defer channelSyncLock.Unlock()
+
+	if channel == nil {
+		return
+	}
+	channelsIDM[channel.Id] = channel
+}

+ 81 - 31
model/channel.go

@@ -3,11 +3,12 @@ package model
 import (
 	"database/sql/driver"
 	"encoding/json"
-	"fmt"
+	"errors"
 	"math/rand"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	"one-api/types"
 	"strings"
 	"sync"
 
@@ -48,6 +49,7 @@ type Channel struct {
 
 type ChannelInfo struct {
 	IsMultiKey           bool                  `json:"is_multi_key"`            // 是否多Key模式
+	MultiKeySize         int                   `json:"multi_key_size"`          // 多Key模式下的Key数量
 	MultiKeyStatusList   map[int]int           `json:"multi_key_status_list"`   // key状态列表,key index -> status
 	MultiKeyPollingIndex int                   `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
 	MultiKeyMode         constant.MultiKeyMode `json:"multi_key_mode"`
@@ -73,7 +75,7 @@ func (channel *Channel) getKeys() []string {
 	return keys
 }
 
-func (channel *Channel) GetNextEnabledKey() (string, error) {
+func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
 	// If not in multi-key mode, return the original key string directly.
 	if !channel.ChannelInfo.IsMultiKey {
 		return channel.Key, nil
@@ -83,7 +85,7 @@ func (channel *Channel) GetNextEnabledKey() (string, error) {
 	keys := channel.getKeys()
 	if len(keys) == 0 {
 		// No keys available, return error, should disable the channel
-		return "", fmt.Errorf("no valid keys in channel")
+		return "", types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
 	}
 
 	statusList := channel.ChannelInfo.MultiKeyStatusList
@@ -404,48 +406,94 @@ func (channel *Channel) Delete() error {
 
 var channelStatusLock sync.Mutex
 
-func UpdateChannelStatusById(id int, status int, reason string) bool {
+func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
+	keys := channel.getKeys()
+	if len(keys) == 0 {
+		channel.Status = status
+	} else {
+		var keyIndex int
+		for i, key := range keys {
+			if key == usingKey {
+				keyIndex = i
+				break
+			}
+		}
+		if channel.ChannelInfo.MultiKeyStatusList == nil {
+			channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
+		}
+		if status == common.ChannelStatusEnabled {
+			delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
+		} else {
+			channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
+		}
+		if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
+			channel.Status = common.ChannelStatusAutoDisabled
+			info := channel.GetOtherInfo()
+			info["status_reason"] = "All keys are disabled"
+			info["status_time"] = common.GetTimestamp()
+			channel.SetOtherInfo(info)
+		}
+	}
+}
+
+func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
 	if common.MemoryCacheEnabled {
 		channelStatusLock.Lock()
 		defer channelStatusLock.Unlock()
 
-		channelCache, _ := CacheGetChannel(id)
-		// 如果缓存渠道存在,且状态已是目标状态,直接返回
-		if channelCache != nil && channelCache.Status == status {
+		channelCache, _ := CacheGetChannel(channelId)
+		if channelCache == nil {
 			return false
 		}
-		// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
-		if channelCache == nil && status != common.ChannelStatusEnabled {
-			return false
+		if channelCache.ChannelInfo.IsMultiKey {
+			// 如果是多Key模式,更新缓存中的状态
+			handlerMultiKeyUpdate(channelCache, usingKey, status)
+			CacheUpdateChannel(channelCache)
+			//return true
+		} else {
+			// 如果缓存渠道存在,且状态已是目标状态,直接返回
+			if channelCache.Status == status {
+				return false
+			}
+			// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
+			if status != common.ChannelStatusEnabled {
+				return false
+			}
+			CacheUpdateChannelStatus(channelId, status)
 		}
-		CacheUpdateChannelStatus(id, status)
 	}
-	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
+
+	shouldUpdateAbilities := false
+	defer func() {
+		if shouldUpdateAbilities {
+			err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
+			if err != nil {
+				common.SysError("failed to update ability status: " + err.Error())
+			}
+		}
+	}()
+	channel, err := GetChannelById(channelId, true)
 	if err != nil {
-		common.SysError("failed to update ability status: " + err.Error())
 		return false
-	}
-	channel, err := GetChannelById(id, true)
-	if err != nil {
-		// find channel by id error, directly update status
-		result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
-		if result.Error != nil {
-			common.SysError("failed to update channel status: " + result.Error.Error())
-			return false
-		}
-		if result.RowsAffected == 0 {
-			return false
-		}
 	} else {
 		if channel.Status == status {
 			return false
 		}
-		// find channel by id success, update status and other info
-		info := channel.GetOtherInfo()
-		info["status_reason"] = reason
-		info["status_time"] = common.GetTimestamp()
-		channel.SetOtherInfo(info)
-		channel.Status = status
+
+		if channel.ChannelInfo.IsMultiKey {
+			beforeStatus := channel.Status
+			handlerMultiKeyUpdate(channel, usingKey, status)
+			if beforeStatus != channel.Status {
+				shouldUpdateAbilities = true
+			}
+		} else {
+			info := channel.GetOtherInfo()
+			info["status_reason"] = reason
+			info["status_time"] = common.GetTimestamp()
+			channel.SetOtherInfo(info)
+			channel.Status = status
+			shouldUpdateAbilities = true
+		}
 		err = channel.Save()
 		if err != nil {
 			common.SysError("failed to update channel status: " + err.Error())
@@ -628,6 +676,8 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
 		err := json.Unmarshal([]byte(*channel.Setting), &setting)
 		if err != nil {
 			common.SysError("failed to unmarshal setting: " + err.Error())
+			channel.Setting = nil // 清空设置以避免后续错误
+			_ = channel.Save()    // 保存修改
 		}
 	}
 	return setting

+ 2 - 1
relay/channel/tencent/adaptor.go

@@ -6,6 +6,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
@@ -63,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	apiKey := c.Request.Header.Get("Authorization")
+	apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
 	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 	appId, secretId, secretKey, err := parseTencentConfig(apiKey)
 	a.AppID = appId

+ 1 - 1
relay/common/relay_info.go

@@ -247,7 +247,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		IsModelMapped: false,
 		ApiType:       apiType,
 		ApiVersion:    c.GetString("api_version"),
-		ApiKey:        strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+		ApiKey:        common.GetContextKeyString(c, constant.ContextKeyChannelKey),
 		Organization:  c.GetString("channel_organization"),
 
 		ChannelCreateTime: c.GetInt64("channel_create_time"),

+ 1 - 1
relay/relay-mj.go

@@ -575,7 +575,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			common.SysError("get_channel_null: " + err.Error())
 		}
 		if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
-			model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
+			model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
 		}
 	}
 	if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {

+ 8 - 11
service/channel.go

@@ -17,17 +17,17 @@ func formatNotifyType(channelId int, status int) string {
 }
 
 // disable & notify
-func DisableChannel(channelId int, channelName string, reason string) {
-	success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
+func DisableChannel(channelError types.ChannelError, reason string) {
+	success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason)
 	if success {
-		subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
-		content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
-		NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content)
+		subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId)
+		content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason)
+		NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content)
 	}
 }
 
-func EnableChannel(channelId int, channelName string) {
-	success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
+func EnableChannel(channelId int, usingKey string, channelName string) {
+	success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "")
 	if success {
 		subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
 		content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
@@ -87,13 +87,10 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
 	return search
 }
 
-func ShouldEnableChannel(err error, newAPIError *types.NewAPIError, status int) bool {
+func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool {
 	if !common.AutomaticEnableChannelEnabled {
 		return false
 	}
-	if err != nil {
-		return false
-	}
 	if newAPIError != nil {
 		return false
 	}

+ 1 - 1
service/midjourney.go

@@ -204,7 +204,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
 	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	auth := c.Request.Header.Get("Authorization")
+	auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
 	if auth != "" {
 		auth = strings.TrimPrefix(auth, "Bearer ")
 		req.Header.Set("mj-api-secret", auth)

+ 21 - 0
types/channel_error.go

@@ -0,0 +1,21 @@
+package types
+
+type ChannelError struct {
+	ChannelId   int    `json:"channel_id"`
+	ChannelType int    `json:"channel_type"`
+	ChannelName string `json:"channel_name"`
+	IsMultiKey  bool   `json:"is_multi_key"`
+	AutoBan     bool   `json:"auto_ban"`
+	UsingKey    string `json:"using_key"`
+}
+
+func NewChannelError(channelId int, channelType int, channelName string, isMultiKey bool, usingKey string, autoBan bool) *ChannelError {
+	return &ChannelError{
+		ChannelId:   channelId,
+		ChannelType: channelType,
+		ChannelName: channelName,
+		IsMultiKey:  isMultiKey,
+		AutoBan:     autoBan,
+		UsingKey:    usingKey,
+	}
+}

+ 1 - 0
types/error.go

@@ -50,6 +50,7 @@ const (
 	ErrorCodeChannelModelMappedError     ErrorCode = "channel:model_mapped_error"
 	ErrorCodeChannelAwsClientError       ErrorCode = "channel:aws_client_error"
 	ErrorCodeChannelInvalidKey           ErrorCode = "channel:invalid_key"
+	ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
 
 	// client request error
 	ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"

+ 65 - 5
web/src/components/table/ChannelsTable.js

@@ -42,6 +42,7 @@ import {
   IconTreeTriangleDown,
   IconSearch,
   IconMore,
+  IconList
 } from '@douyinfe/semi-icons';
 import { loadChannelModels, isMobile, copy } from '../../helpers';
 import EditTagModal from '../../pages/Channel/EditTagModal.js';
@@ -53,7 +54,7 @@ const ChannelsTable = () => {
 
   let type2label = undefined;
 
-  const renderType = (type) => {
+  const renderType = (type, multiKey = false) => {
     if (!type2label) {
       type2label = new Map();
       for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
@@ -61,12 +62,24 @@ const ChannelsTable = () => {
       }
       type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' };
     }
+    
+    let icon = getChannelIcon(type);
+    
+    if (multiKey) {
+      icon = (
+        <div className="flex items-center gap-1">
+          <IconList className="text-blue-500" />
+          {icon}
+        </div>
+      );
+    }
+    
     return (
       <Tag
         size='large'
         color={type2label[type]?.color}
         shape='circle'
-        prefixIcon={getChannelIcon(type)}
+        prefixIcon={icon}
       >
         {type2label[type]?.label}
       </Tag>
@@ -86,7 +99,19 @@ const ChannelsTable = () => {
     );
   };
 
-  const renderStatus = (status) => {
+  const renderStatus = (status, channelInfo = undefined) => {
+    if (channelInfo) {
+      if (channelInfo.is_multi_key) {
+        let keySize = channelInfo.multi_key_size;
+        let enabledKeySize = keySize;
+        if (channelInfo.multi_key_status_list) {
+          // multi_key_status_list is a map, key is key, value is status
+          // get multi_key_status_list length
+          enabledKeySize = keySize - Object.keys(channelInfo.multi_key_status_list).length;
+        }
+        return renderMultiKeyStatus(status, keySize, enabledKeySize);
+      }
+    }
     switch (status) {
       case 1:
         return (
@@ -115,6 +140,36 @@ const ChannelsTable = () => {
     }
   };
 
+  const renderMultiKeyStatus = (status, keySize, enabledKeySize) => {
+    switch (status) {
+      case 1:
+        return (
+          <Tag size='large' color='green' shape='circle'>
+            {t('已启用')} {enabledKeySize}/{keySize}
+          </Tag>
+        );
+      case 2:
+        return (
+          <Tag size='large' color='red' shape='circle'>
+            {t('已禁用')} {enabledKeySize}/{keySize}
+          </Tag>
+        );
+      case 3:
+        return (
+          <Tag size='large' color='yellow' shape='circle'>
+            {t('自动禁用')} {enabledKeySize}/{keySize}
+          </Tag>
+        );
+      default:
+        return (
+          <Tag size='large' color='grey' shape='circle'>
+            {t('未知状态')} {enabledKeySize}/{keySize}
+          </Tag>
+        );
+    }
+  }
+
+
   const renderResponseTime = (responseTime) => {
     let time = responseTime / 1000;
     time = time.toFixed(2) + t(' 秒');
@@ -281,6 +336,11 @@ const ChannelsTable = () => {
       dataIndex: 'type',
       render: (text, record, index) => {
         if (record.children === undefined) {
+          if (record.channel_info) {
+            if (record.channel_info.is_multi_key) {
+              return <>{renderType(text, record.channel_info)}</>;
+            }
+          }
           return <>{renderType(text)}</>;
         } else {
           return <>{renderTagType()}</>;
@@ -304,12 +364,12 @@ const ChannelsTable = () => {
               <Tooltip
                 content={t('原因:') + reason + t(',时间:') + timestamp2string(time)}
               >
-                {renderStatus(text)}
+                {renderStatus(text, record.channel_info)}
               </Tooltip>
             </div>
           );
         } else {
-          return renderStatus(text);
+          return renderStatus(text, record.channel_info);
         }
       },
     },