|
|
@@ -1,21 +1,23 @@
|
|
|
package controller
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"fmt"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "io"
|
|
|
"log"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"one-api/dto"
|
|
|
+ "one-api/middleware"
|
|
|
+ "one-api/model"
|
|
|
"one-api/relay"
|
|
|
"one-api/relay/constant"
|
|
|
relayconstant "one-api/relay/constant"
|
|
|
"one-api/service"
|
|
|
- "strconv"
|
|
|
)
|
|
|
|
|
|
-func Relay(c *gin.Context) {
|
|
|
- relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
|
+func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|
|
var err *dto.OpenAIErrorWithStatusCode
|
|
|
switch relayMode {
|
|
|
case relayconstant.RelayModeImagesGenerations:
|
|
|
@@ -29,33 +31,88 @@ func Relay(c *gin.Context) {
|
|
|
default:
|
|
|
err = relay.TextHelper(c)
|
|
|
}
|
|
|
- if err != nil {
|
|
|
- requestId := c.GetString(common.RequestIdKey)
|
|
|
- retryTimesStr := c.Query("retry")
|
|
|
- retryTimes, _ := strconv.Atoi(retryTimesStr)
|
|
|
- if retryTimesStr == "" {
|
|
|
- retryTimes = common.RetryTimes
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func Relay(c *gin.Context) {
|
|
|
+ relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
|
+ retryTimes := common.RetryTimes
|
|
|
+ requestId := c.GetString(common.RequestIdKey)
|
|
|
+ channelId := c.GetInt("channel_id")
|
|
|
+ group := c.GetString("group")
|
|
|
+ originalModel := c.GetString("original_model")
|
|
|
+ openaiErr := relayHandler(c, relayMode)
|
|
|
+ retryLogStr := fmt.Sprintf("重试:%d", channelId)
|
|
|
+ if openaiErr != nil {
|
|
|
+ go processChannelError(c, channelId, openaiErr)
|
|
|
+ } else {
|
|
|
+ retryTimes = 0
|
|
|
+ }
|
|
|
+ for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
|
|
|
+ channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
|
|
+ if err != nil {
|
|
|
+ common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
|
|
+ break
|
|
|
}
|
|
|
- if retryTimes > 0 {
|
|
|
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
|
|
- } else {
|
|
|
- if err.StatusCode == http.StatusTooManyRequests {
|
|
|
- //err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
|
- }
|
|
|
- err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
|
|
|
- c.JSON(err.StatusCode, gin.H{
|
|
|
- "error": err.Error,
|
|
|
- })
|
|
|
+ channelId = channel.Id
|
|
|
+ retryLogStr += fmt.Sprintf("->%d", channel.Id)
|
|
|
+ common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
|
|
+ middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
|
+
|
|
|
+ requestBody, err := common.GetRequestBody(c)
|
|
|
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
|
+ openaiErr = relayHandler(c, relayMode)
|
|
|
+ if openaiErr != nil {
|
|
|
+ go processChannelError(c, channelId, openaiErr)
|
|
|
}
|
|
|
- channelId := c.GetInt("channel_id")
|
|
|
- autoBan := c.GetBool("auto_ban")
|
|
|
- common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
|
|
- // https://platform.openai.com/docs/guides/error-codes/api-errors
|
|
|
- if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
|
|
- channelId := c.GetInt("channel_id")
|
|
|
- channelName := c.GetString("channel_name")
|
|
|
- service.DisableChannel(channelId, channelName, err.Error.Message)
|
|
|
+ }
|
|
|
+ common.LogInfo(c.Request.Context(), retryLogStr)
|
|
|
+
|
|
|
+ if openaiErr != nil {
|
|
|
+ if openaiErr.StatusCode == http.StatusTooManyRequests {
|
|
|
+ openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
|
}
|
|
|
+ openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
|
|
+ c.JSON(openaiErr.StatusCode, gin.H{
|
|
|
+ "error": openaiErr.Error,
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
|
|
+ if openaiErr == nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if retryTimes <= 0 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if _, ok := c.Get("specific_channel_id"); ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if openaiErr.StatusCode == http.StatusTooManyRequests {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ if openaiErr.StatusCode/100 == 5 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ if openaiErr.StatusCode == http.StatusBadRequest {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if openaiErr.LocalError {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if openaiErr.StatusCode/100 == 2 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
|
|
|
+ autoBan := c.GetBool("auto_ban")
|
|
|
+ common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
|
|
+ if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
|
|
+ channelName := c.GetString("channel_name")
|
|
|
+ service.DisableChannel(channelId, channelName, err.Error.Message)
|
|
|
}
|
|
|
}
|
|
|
|