CalciumIon пре 1 година
родитељ
комит
d4578e28b3
6 измењених фајлова са 20 додато и 12 уклоњено
  1. 1 1
      controller/channel-test.go
  2. 5 4
      controller/relay.go
  3. 1 0
      middleware/distributor.go
  4. 4 4
      relay/relay-audio.go
  5. 1 1
      relay/relay-image.go
  6. 8 2
      service/channel.go

+ 1 - 1
controller/channel-test.go

@@ -228,7 +228,7 @@ func testAllChannels(notify bool) error {
 					Error:      *openaiErr,
 					Error:      *openaiErr,
 					LocalError: false,
 					LocalError: false,
 				}
 				}
-				if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
+				if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
 					service.DisableChannel(channel.Id, channel.Name, err.Error())
 					service.DisableChannel(channel.Id, channel.Name, err.Error())
 				}
 				}
 				if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
 				if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {

+ 5 - 4
controller/relay.go

@@ -40,12 +40,13 @@ func Relay(c *gin.Context) {
 	retryTimes := common.RetryTimes
 	retryTimes := common.RetryTimes
 	requestId := c.GetString(common.RequestIdKey)
 	requestId := c.GetString(common.RequestIdKey)
 	channelId := c.GetInt("channel_id")
 	channelId := c.GetInt("channel_id")
+	channelType := c.GetInt("channel_type")
 	group := c.GetString("group")
 	group := c.GetString("group")
 	originalModel := c.GetString("original_model")
 	originalModel := c.GetString("original_model")
 	openaiErr := relayHandler(c, relayMode)
 	openaiErr := relayHandler(c, relayMode)
 	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	if openaiErr != nil {
 	if openaiErr != nil {
-		go processChannelError(c, channelId, openaiErr)
+		go processChannelError(c, channelId, channelType, openaiErr)
 	} else {
 	} else {
 		retryTimes = 0
 		retryTimes = 0
 	}
 	}
@@ -66,7 +67,7 @@ func Relay(c *gin.Context) {
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 		openaiErr = relayHandler(c, relayMode)
 		openaiErr = relayHandler(c, relayMode)
 		if openaiErr != nil {
 		if openaiErr != nil {
-			go processChannelError(c, channelId, openaiErr)
+			go processChannelError(c, channelId, channel.Type, openaiErr)
 		}
 		}
 	}
 	}
 	useChannel := c.GetStringSlice("use_channel")
 	useChannel := c.GetStringSlice("use_channel")
@@ -125,10 +126,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
 	return true
 	return true
 }
 }
 
 
-func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
+func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
 	autoBan := c.GetBool("auto_ban")
 	autoBan := c.GetBool("auto_ban")
 	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
 	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
-	if service.ShouldDisableChannel(err) && autoBan {
+	if service.ShouldDisableChannel(channelType, err) && autoBan {
 		channelName := c.GetString("channel_name")
 		channelName := c.GetString("channel_name")
 		service.DisableChannel(channelId, channelName, err.Error.Message)
 		service.DisableChannel(channelId, channelName, err.Error.Message)
 	}
 	}

+ 1 - 0
middleware/distributor.go

@@ -178,6 +178,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	c.Set("channel", channel.Type)
 	c.Set("channel", channel.Type)
 	c.Set("channel_id", channel.Id)
 	c.Set("channel_id", channel.Id)
 	c.Set("channel_name", channel.Name)
 	c.Set("channel_name", channel.Name)
+	c.Set("channel_type", channel.Type)
 	ban := true
 	ban := true
 	// parse *int to bool
 	// parse *int to bool
 	if channel.AutoBan != nil && *channel.AutoBan == 0 {
 	if channel.AutoBan != nil && *channel.AutoBan == 0 {

+ 4 - 4
relay/relay-audio.go

@@ -73,14 +73,14 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	userQuota, err := model.CacheGetUserQuota(userId)
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	}
 	if userQuota-preConsumedQuota < 0 {
 	if userQuota-preConsumedQuota < 0 {
-		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	}
 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 	if err != nil {
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 	}
 	}
 	if userQuota > 100*preConsumedQuota {
 	if userQuota > 100*preConsumedQuota {
 		// in this case, we do not pre-consume quota
 		// in this case, we do not pre-consume quota
@@ -90,7 +90,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	if preConsumedQuota > 0 {
 	if preConsumedQuota > 0 {
 		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 		}
 	}
 	}
 
 

+ 1 - 1
relay/relay-image.go

@@ -147,7 +147,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 	quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
 	quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
 
 
 	if userQuota-quota < 0 {
 	if userQuota-quota < 0 {
-		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	}
 
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

+ 8 - 2
service/channel.go

@@ -24,7 +24,7 @@ func EnableChannel(channelId int, channelName string) {
 	notifyRootUser(subject, content)
 	notifyRootUser(subject, content)
 }
 }
 
 
-func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
+func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	if !common.AutomaticDisableChannelEnabled {
 	if !common.AutomaticDisableChannelEnabled {
 		return false
 		return false
 	}
 	}
@@ -34,9 +34,15 @@ func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	if err.LocalError {
 	if err.LocalError {
 		return false
 		return false
 	}
 	}
-	if err.StatusCode == http.StatusUnauthorized || err.StatusCode == http.StatusForbidden {
+	if err.StatusCode == http.StatusUnauthorized {
 		return true
 		return true
 	}
 	}
+	if err.StatusCode == http.StatusForbidden {
+		switch channelType {
+		case common.ChannelTypeGemini:
+			return true
+		}
+	}
 	switch err.Error.Code {
 	switch err.Error.Code {
 	case "invalid_api_key":
 	case "invalid_api_key":
 		return true
 		return true