Bläddra i källkod

feat: 本地重试

CaIon 1 år sedan
förälder
incheckning
4b60528c5f
11 ändrade filer med 215 tillägg och 103 borttagningar
  1. 23 4
      common/gin.go
  2. 84 27
      controller/relay.go
  3. 1 0
      dto/error.go
  4. 1 1
      middleware/auth.go
  5. 35 30
      middleware/distributor.go
  6. 9 14
      model/ability.go
  7. 24 18
      model/cache.go
  8. 1 0
      relay/common/relay_info.go
  9. 8 8
      relay/relay-text.go
  10. 23 1
      service/channel.go
  11. 6 0
      service/error.go

+ 23 - 4
common/gin.go

@@ -5,18 +5,37 @@ import (
 	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"io"
+	"strings"
 )
 
-func UnmarshalBodyReusable(c *gin.Context, v any) error {
+const KeyRequestBody = "key_request_body"
+
+func GetRequestBody(c *gin.Context) ([]byte, error) {
+	requestBody, _ := c.Get(KeyRequestBody)
+	if requestBody != nil {
+		return requestBody.([]byte), nil
+	}
 	requestBody, err := io.ReadAll(c.Request.Body)
 	if err != nil {
-		return err
+		return nil, err
 	}
-	err = c.Request.Body.Close()
+	_ = c.Request.Body.Close()
+	c.Set(KeyRequestBody, requestBody)
+	return requestBody.([]byte), nil
+}
+
+func UnmarshalBodyReusable(c *gin.Context, v any) error {
+	requestBody, err := GetRequestBody(c)
 	if err != nil {
 		return err
 	}
-	err = json.Unmarshal(requestBody, &v)
+	contentType := c.Request.Header.Get("Content-Type")
+	if strings.HasPrefix(contentType, "application/json") {
+		err = json.Unmarshal(requestBody, &v)
+	} else {
+		// skip for now
+		// TODO: someday non json request have variant model, we will need to implementation this
+	}
 	if err != nil {
 		return err
 	}

+ 84 - 27
controller/relay.go

@@ -1,21 +1,23 @@
 package controller
 
 import (
+	"bytes"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"io"
 	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/middleware"
+	"one-api/model"
 	"one-api/relay"
 	"one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
-	"strconv"
 )
 
-func Relay(c *gin.Context) {
-	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
 	case relayconstant.RelayModeImagesGenerations:
@@ -29,33 +31,88 @@ func Relay(c *gin.Context) {
 	default:
 		err = relay.TextHelper(c)
 	}
-	if err != nil {
-		requestId := c.GetString(common.RequestIdKey)
-		retryTimesStr := c.Query("retry")
-		retryTimes, _ := strconv.Atoi(retryTimesStr)
-		if retryTimesStr == "" {
-			retryTimes = common.RetryTimes
+	return err
+}
+
+func Relay(c *gin.Context) {
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	retryTimes := common.RetryTimes
+	requestId := c.GetString(common.RequestIdKey)
+	channelId := c.GetInt("channel_id")
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	openaiErr := relayHandler(c, relayMode)
+	retryLogStr := fmt.Sprintf("重试:%d", channelId)
+	if openaiErr != nil {
+		go processChannelError(c, channelId, openaiErr)
+	} else {
+		retryTimes = 0
+	}
+	for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
+		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+		if err != nil {
+			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			break
 		}
-		if retryTimes > 0 {
-			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
-		} else {
-			if err.StatusCode == http.StatusTooManyRequests {
-				//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
-			}
-			err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
-			c.JSON(err.StatusCode, gin.H{
-				"error": err.Error,
-			})
+		channelId = channel.Id
+		retryLogStr += fmt.Sprintf("->%d", channel.Id)
+		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+		requestBody, err := common.GetRequestBody(c)
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		openaiErr = relayHandler(c, relayMode)
+		if openaiErr != nil {
+			go processChannelError(c, channelId, openaiErr)
 		}
-		channelId := c.GetInt("channel_id")
-		autoBan := c.GetBool("auto_ban")
-		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
-		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
-			channelId := c.GetInt("channel_id")
-			channelName := c.GetString("channel_name")
-			service.DisableChannel(channelId, channelName, err.Error.Message)
+	}
+	common.LogInfo(c.Request.Context(), retryLogStr)
+
+	if openaiErr != nil {
+		if openaiErr.StatusCode == http.StatusTooManyRequests {
+			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 		}
+		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+		c.JSON(openaiErr.StatusCode, gin.H{
+			"error": openaiErr.Error,
+		})
+	}
+}
+
+func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
+	if openaiErr == nil {
+		return false
+	}
+	if retryTimes <= 0 {
+		return false
+	}
+	if _, ok := c.Get("specific_channel_id"); ok {
+		return false
+	}
+	if openaiErr.StatusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if openaiErr.StatusCode/100 == 5 {
+		return true
+	}
+	if openaiErr.StatusCode == http.StatusBadRequest {
+		return false
+	}
+	if openaiErr.LocalError {
+		return false
+	}
+	if openaiErr.StatusCode/100 == 2 {
+		return false
+	}
+	return true
+}
+
+func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
+	autoBan := c.GetBool("auto_ban")
+	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
+	if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
+		channelName := c.GetString("channel_name")
+		service.DisableChannel(channelId, channelName, err.Error.Message)
 	}
 }
 

+ 1 - 0
dto/error.go

@@ -10,6 +10,7 @@ type OpenAIError struct {
// OpenAIErrorWithStatusCode pairs an OpenAI-style error payload with the HTTP
// status code to return to the client.
type OpenAIErrorWithStatusCode struct {
	Error      OpenAIError `json:"error"`
	StatusCode int         `json:"status_code"`
	// LocalError marks errors generated by this service itself (set via
	// OpenAIErrorWrapperLocal); retry logic skips channel failover for them.
	LocalError bool
}
 
 type GeneralErrorResponse struct {

+ 1 - 1
middleware/auth.go

@@ -127,7 +127,7 @@ func TokenAuth() func(c *gin.Context) {
 		}
 		if len(parts) > 1 {
 			if model.IsAdmin(token.UserId) {
-				c.Set("channelId", parts[1])
+				c.Set("specific_channel_id", parts[1])
 			} else {
 				abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 				return

+ 35 - 30
middleware/distributor.go

@@ -23,7 +23,7 @@ func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
 		userId := c.GetInt("id")
 		var channel *model.Channel
-		channelId, ok := c.Get("channelId")
+		channelId, ok := c.Get("specific_channel_id")
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
@@ -131,7 +131,7 @@ func Distribute() func(c *gin.Context) {
 			userGroup, _ := model.CacheGetUserGroup(userId)
 			c.Set("group", userGroup)
 			if shouldSelectChannel {
-				channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+				channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
 				if err != nil {
 					message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 					// 如果错误,但是渠道不为空,说明是数据库一致性问题
@@ -147,36 +147,41 @@ func Distribute() func(c *gin.Context) {
 					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
 					return
 				}
-				c.Set("channel", channel.Type)
-				c.Set("channel_id", channel.Id)
-				c.Set("channel_name", channel.Name)
-				ban := true
-				// parse *int to bool
-				if channel.AutoBan != nil && *channel.AutoBan == 0 {
-					ban = false
-				}
-				if nil != channel.OpenAIOrganization {
-					c.Set("channel_organization", *channel.OpenAIOrganization)
-				}
-				c.Set("auto_ban", ban)
-				c.Set("model_mapping", channel.GetModelMapping())
-				c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-				c.Set("base_url", channel.GetBaseURL())
-				// TODO: api_version统一
-				switch channel.Type {
-				case common.ChannelTypeAzure:
-					c.Set("api_version", channel.Other)
-				case common.ChannelTypeXunfei:
-					c.Set("api_version", channel.Other)
-				//case common.ChannelTypeAIProxyLibrary:
-				//	c.Set("library_id", channel.Other)
-				case common.ChannelTypeGemini:
-					c.Set("api_version", channel.Other)
-				case common.ChannelTypeAli:
-					c.Set("plugin", channel.Other)
-				}
+				SetupContextForSelectedChannel(c, channel, modelRequest.Model)
 			}
 		}
 		c.Next()
 	}
 }
+
+func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+	c.Set("channel", channel.Type)
+	c.Set("channel_id", channel.Id)
+	c.Set("channel_name", channel.Name)
+	ban := true
+	// parse *int to bool
+	if channel.AutoBan != nil && *channel.AutoBan == 0 {
+		ban = false
+	}
+	if nil != channel.OpenAIOrganization {
+		c.Set("channel_organization", *channel.OpenAIOrganization)
+	}
+	c.Set("auto_ban", ban)
+	c.Set("model_mapping", channel.GetModelMapping())
+	c.Set("original_model", modelName) // for retry
+	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+	c.Set("base_url", channel.GetBaseURL())
+	// TODO: api_version统一
+	switch channel.Type {
+	case common.ChannelTypeAzure:
+		c.Set("api_version", channel.Other)
+	case common.ChannelTypeXunfei:
+		c.Set("api_version", channel.Other)
+	//case common.ChannelTypeAIProxyLibrary:
+	//	c.Set("library_id", channel.Other)
+	case common.ChannelTypeGemini:
+		c.Set("api_version", channel.Other)
+	case common.ChannelTypeAli:
+		c.Set("plugin", channel.Other)
+	}
+}

+ 9 - 14
model/ability.go

@@ -52,21 +52,16 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 		// Randomly choose one
 		weightSum := uint(0)
 		for _, ability_ := range abilities {
-			weightSum += ability_.Weight
+			weightSum += ability_.Weight + 10
 		}
-		if weightSum == 0 {
-			// All weight is 0, randomly choose one
-			channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
-		} else {
-			// Randomly choose one
-			weight := common.GetRandomInt(int(weightSum))
-			for _, ability_ := range abilities {
-				weight -= int(ability_.Weight)
-				//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
-				if weight <= 0 {
-					channel.Id = ability_.ChannelId
-					break
-				}
+		// Randomly choose one
+		weight := common.GetRandomInt(int(weightSum))
+		for _, ability_ := range abilities {
+			weight -= int(ability_.Weight)
+			//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
+			if weight <= 0 {
+				channel.Id = ability_.ChannelId
+				break
 			}
 		}
 	} else {

+ 24 - 18
model/cache.go

@@ -265,7 +265,7 @@ func SyncChannelCache(frequency int) {
 	}
 }
 
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
 	if strings.HasPrefix(model, "gpt-4-gizmo") {
 		model = "gpt-4-gizmo-*"
 	}
@@ -280,15 +280,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 	}
-	endIdx := len(channels)
-	// choose by priority
-	firstChannel := channels[0]
-	if firstChannel.GetPriority() > 0 {
-		for i := range channels {
-			if channels[i].GetPriority() != firstChannel.GetPriority() {
-				endIdx = i
-				break
-			}
+
+	uniquePriorities := make(map[int]bool)
+	for _, channel := range channels {
+		uniquePriorities[int(channel.GetPriority())] = true
+	}
+	var sortedUniquePriorities []int
+	for priority := range uniquePriorities {
+		sortedUniquePriorities = append(sortedUniquePriorities, priority)
+	}
+	sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
+
+	if retry >= len(uniquePriorities) {
+		retry = len(uniquePriorities) - 1
+	}
+	targetPriority := int64(sortedUniquePriorities[retry])
+
+	// get the priority for the given retry number
+	var targetChannels []*Channel
+	for _, channel := range channels {
+		if channel.GetPriority() == targetPriority {
+			targetChannels = append(targetChannels, channel)
 		}
 	}
 
@@ -296,20 +308,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	smoothingFactor := 10
 	// Calculate the total weight of all channels up to endIdx
 	totalWeight := 0
-	for _, channel := range channels[:endIdx] {
+	for _, channel := range targetChannels {
 		totalWeight += channel.GetWeight() + smoothingFactor
 	}
-
-	//if totalWeight == 0 {
-	//	// If all weights are 0, select a channel randomly
-	//	return channels[rand.Intn(endIdx)], nil
-	//}
-
 	// Generate a random value in the range [0, totalWeight)
 	randomWeight := rand.Intn(totalWeight)
 
 	// Find a channel based on its weight
-	for _, channel := range channels[:endIdx] {
+	for _, channel := range targetChannels {
 		randomWeight -= channel.GetWeight() + smoothingFactor
 		if randomWeight < 0 {
 			return channel, nil

+ 1 - 0
relay/common/relay_info.go

@@ -31,6 +31,7 @@ type RelayInfo struct {
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
+
 	tokenId := c.GetInt("token_id")
 	userId := c.GetInt("id")
 	group := c.GetString("group")

+ 8 - 8
relay/relay-text.go

@@ -72,7 +72,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	textRequest, err := getAndValidateTextRequest(c, relayInfo)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
-		return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
+		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 	}
 
 	// map model name
@@ -82,7 +82,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[textRequest.Model] != "" {
 			textRequest.Model = modelMap[textRequest.Model]
@@ -103,7 +103,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	// count messages token error 计算promptTokens错误
 	if err != nil {
 		if sensitiveTrigger {
-			return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+			return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
 		}
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
@@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 	if resp.StatusCode != http.StatusOK {
 		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
-		return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
+		return service.RelayErrorHandler(resp)
 	}
 
 	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
@@ -200,14 +200,14 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
 	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
 	if err != nil {
-		return 0, 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+		return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
-		return 0, 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 	if err != nil {
-		return 0, 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+		return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota > 100*preConsumedQuota {
 		// 用户额度充足,判断令牌额度是否充足
@@ -229,7 +229,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	if preConsumedQuota > 0 {
 		userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
 		if err != nil {
-			return 0, 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+			return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
 	return preConsumedQuota, userQuota, nil

+ 23 - 1
service/channel.go

@@ -6,6 +6,7 @@ import (
 	"one-api/common"
 	relaymodel "one-api/dto"
 	"one-api/model"
+	"strings"
 )
 
 // disable & notify
@@ -33,7 +34,28 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
 	if statusCode == http.StatusUnauthorized {
 		return true
 	}
-	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
+	switch err.Code {
+	case "invalid_api_key":
+		return true
+	case "account_deactivated":
+		return true
+	case "billing_not_active":
+		return true
+	}
+	switch err.Type {
+	case "insufficient_quota":
+		return true
+	// https://docs.anthropic.com/claude/reference/errors
+	case "authentication_error":
+		return true
+	case "permission_error":
+		return true
+	case "forbidden":
+		return true
+	}
+	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
+		return true
+	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
 		return true
 	}
 	return false

+ 6 - 0
service/error.go

@@ -46,6 +46,12 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
 	}
 }
 
+func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
+	openaiErr := OpenAIErrorWrapper(err, code, statusCode)
+	openaiErr.LocalError = true
+	return openaiErr
+}
+
 func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
 	errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 		StatusCode: resp.StatusCode,