relay.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. package controller
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/gorilla/websocket"
  8. "io"
  9. "log"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/dto"
  13. "one-api/middleware"
  14. "one-api/model"
  15. "one-api/relay"
  16. "one-api/relay/constant"
  17. relayconstant "one-api/relay/constant"
  18. "one-api/service"
  19. "strings"
  20. )
  21. func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
  22. var err *dto.OpenAIErrorWithStatusCode
  23. switch relayMode {
  24. case relayconstant.RelayModeImagesGenerations:
  25. err = relay.ImageHelper(c, relayMode)
  26. case relayconstant.RelayModeAudioSpeech:
  27. fallthrough
  28. case relayconstant.RelayModeAudioTranslation:
  29. fallthrough
  30. case relayconstant.RelayModeAudioTranscription:
  31. err = relay.AudioHelper(c)
  32. case relayconstant.RelayModeRerank:
  33. err = relay.RerankHelper(c, relayMode)
  34. case relayconstant.RelayModeEmbeddings:
  35. err = relay.EmbeddingHelper(c,relayMode)
  36. default:
  37. err = relay.TextHelper(c)
  38. }
  39. return err
  40. }
  41. func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
  42. var err *dto.OpenAIErrorWithStatusCode
  43. switch relayMode {
  44. default:
  45. err = relay.TextHelper(c)
  46. }
  47. return err
  48. }
  49. func Relay(c *gin.Context) {
  50. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  51. requestId := c.GetString(common.RequestIdKey)
  52. group := c.GetString("group")
  53. originalModel := c.GetString("original_model")
  54. var openaiErr *dto.OpenAIErrorWithStatusCode
  55. //获取request body 并输出到日志
  56. requestBody, _ := common.GetRequestBody(c)
  57. common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
  58. for i := 0; i <= common.RetryTimes; i++ {
  59. channel, err := getChannel(c, group, originalModel, i)
  60. if err != nil {
  61. common.LogError(c, err.Error())
  62. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
  63. break
  64. }
  65. openaiErr = relayRequest(c, relayMode, channel)
  66. if openaiErr == nil {
  67. return // 成功处理请求,直接返回
  68. }
  69. go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
  70. if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
  71. break
  72. }
  73. }
  74. useChannel := c.GetStringSlice("use_channel")
  75. if len(useChannel) > 1 {
  76. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  77. common.LogInfo(c, retryLogStr)
  78. }
  79. if openaiErr != nil {
  80. if openaiErr.StatusCode == http.StatusTooManyRequests {
  81. openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  82. }
  83. openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
  84. c.JSON(openaiErr.StatusCode, gin.H{
  85. "error": openaiErr.Error,
  86. })
  87. }
  88. }
  89. var upgrader = websocket.Upgrader{
  90. Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
  91. CheckOrigin: func(r *http.Request) bool {
  92. return true // 允许跨域
  93. },
  94. }
  95. func WssRelay(c *gin.Context) {
  96. // 将 HTTP 连接升级为 WebSocket 连接
  97. ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  98. defer ws.Close()
  99. if err != nil {
  100. openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
  101. service.WssError(c, ws, openaiErr.Error)
  102. return
  103. }
  104. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  105. requestId := c.GetString(common.RequestIdKey)
  106. group := c.GetString("group")
  107. //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
  108. originalModel := c.GetString("original_model")
  109. var openaiErr *dto.OpenAIErrorWithStatusCode
  110. for i := 0; i <= common.RetryTimes; i++ {
  111. channel, err := getChannel(c, group, originalModel, i)
  112. if err != nil {
  113. common.LogError(c, err.Error())
  114. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
  115. break
  116. }
  117. openaiErr = wssRequest(c, ws, relayMode, channel)
  118. if openaiErr == nil {
  119. return // 成功处理请求,直接返回
  120. }
  121. go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
  122. if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
  123. break
  124. }
  125. }
  126. useChannel := c.GetStringSlice("use_channel")
  127. if len(useChannel) > 1 {
  128. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  129. common.LogInfo(c, retryLogStr)
  130. }
  131. if openaiErr != nil {
  132. if openaiErr.StatusCode == http.StatusTooManyRequests {
  133. openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  134. }
  135. openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
  136. service.WssError(c, ws, openaiErr.Error)
  137. }
  138. }
  139. func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
  140. common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
  141. addUsedChannel(c, channel.Id)
  142. requestBody, _ := common.GetRequestBody(c)
  143. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  144. return relayHandler(c, relayMode)
  145. }
  146. func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
  147. addUsedChannel(c, channel.Id)
  148. requestBody, _ := common.GetRequestBody(c)
  149. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  150. return relay.WssHelper(c, ws)
  151. }
  152. func addUsedChannel(c *gin.Context, channelId int) {
  153. useChannel := c.GetStringSlice("use_channel")
  154. useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
  155. c.Set("use_channel", useChannel)
  156. }
  157. func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
  158. if retryCount == 0 {
  159. autoBan := c.GetBool("auto_ban")
  160. autoBanInt := 1
  161. if !autoBan {
  162. autoBanInt = 0
  163. }
  164. return &model.Channel{
  165. Id: c.GetInt("channel_id"),
  166. Type: c.GetInt("channel_type"),
  167. Name: c.GetString("channel_name"),
  168. AutoBan: &autoBanInt,
  169. }, nil
  170. }
  171. channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
  172. if err != nil {
  173. return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
  174. }
  175. middleware.SetupContextForSelectedChannel(c, channel, originalModel)
  176. return channel, nil
  177. }
  178. func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
  179. if openaiErr == nil {
  180. return false
  181. }
  182. if openaiErr.LocalError {
  183. return false
  184. }
  185. if retryTimes <= 0 {
  186. return false
  187. }
  188. if _, ok := c.Get("specific_channel_id"); ok {
  189. return false
  190. }
  191. if openaiErr.StatusCode == http.StatusTooManyRequests {
  192. return true
  193. }
  194. if openaiErr.StatusCode == 307 {
  195. return true
  196. }
  197. if openaiErr.StatusCode/100 == 5 {
  198. // 超时不重试
  199. if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
  200. return false
  201. }
  202. return true
  203. }
  204. if openaiErr.StatusCode == http.StatusBadRequest {
  205. channelType := c.GetInt("channel_type")
  206. if channelType == common.ChannelTypeAnthropic {
  207. return true
  208. }
  209. return false
  210. }
  211. if openaiErr.StatusCode == 408 {
  212. // azure处理超时不重试
  213. return false
  214. }
  215. if openaiErr.StatusCode/100 == 2 {
  216. return false
  217. }
  218. return true
  219. }
  220. func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
  221. // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
  222. // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
  223. common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
  224. if service.ShouldDisableChannel(channelType, err) && autoBan {
  225. service.DisableChannel(channelId, channelName, err.Error.Message)
  226. }
  227. }
  228. func RelayMidjourney(c *gin.Context) {
  229. relayMode := c.GetInt("relay_mode")
  230. var err *dto.MidjourneyResponse
  231. switch relayMode {
  232. case relayconstant.RelayModeMidjourneyNotify:
  233. err = relay.RelayMidjourneyNotify(c)
  234. case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
  235. err = relay.RelayMidjourneyTask(c, relayMode)
  236. case relayconstant.RelayModeMidjourneyTaskImageSeed:
  237. err = relay.RelayMidjourneyTaskImageSeed(c)
  238. case relayconstant.RelayModeSwapFace:
  239. err = relay.RelaySwapFace(c)
  240. default:
  241. err = relay.RelayMidjourneySubmit(c, relayMode)
  242. }
  243. //err = relayMidjourneySubmit(c, relayMode)
  244. log.Println(err)
  245. if err != nil {
  246. statusCode := http.StatusBadRequest
  247. if err.Code == 30 {
  248. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  249. statusCode = http.StatusTooManyRequests
  250. }
  251. c.JSON(statusCode, gin.H{
  252. "description": fmt.Sprintf("%s %s", err.Description, err.Result),
  253. "type": "upstream_error",
  254. "code": err.Code,
  255. })
  256. channelId := c.GetInt("channel_id")
  257. common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
  258. }
  259. }
  260. func RelayNotImplemented(c *gin.Context) {
  261. err := dto.OpenAIError{
  262. Message: "API not implemented",
  263. Type: "new_api_error",
  264. Param: "",
  265. Code: "api_not_implemented",
  266. }
  267. c.JSON(http.StatusNotImplemented, gin.H{
  268. "error": err,
  269. })
  270. }
  271. func RelayNotFound(c *gin.Context) {
  272. err := dto.OpenAIError{
  273. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  274. Type: "invalid_request_error",
  275. Param: "",
  276. Code: "",
  277. }
  278. c.JSON(http.StatusNotFound, gin.H{
  279. "error": err,
  280. })
  281. }
  282. func RelayTask(c *gin.Context) {
  283. retryTimes := common.RetryTimes
  284. channelId := c.GetInt("channel_id")
  285. relayMode := c.GetInt("relay_mode")
  286. group := c.GetString("group")
  287. originalModel := c.GetString("original_model")
  288. c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
  289. taskErr := taskRelayHandler(c, relayMode)
  290. if taskErr == nil {
  291. retryTimes = 0
  292. }
  293. for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
  294. channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
  295. if err != nil {
  296. common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
  297. break
  298. }
  299. channelId = channel.Id
  300. useChannel := c.GetStringSlice("use_channel")
  301. useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
  302. c.Set("use_channel", useChannel)
  303. common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
  304. middleware.SetupContextForSelectedChannel(c, channel, originalModel)
  305. requestBody, err := common.GetRequestBody(c)
  306. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  307. taskErr = taskRelayHandler(c, relayMode)
  308. }
  309. useChannel := c.GetStringSlice("use_channel")
  310. if len(useChannel) > 1 {
  311. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  312. common.LogInfo(c, retryLogStr)
  313. }
  314. if taskErr != nil {
  315. if taskErr.StatusCode == http.StatusTooManyRequests {
  316. taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
  317. }
  318. c.JSON(taskErr.StatusCode, taskErr)
  319. }
  320. }
  321. func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
  322. var err *dto.TaskError
  323. switch relayMode {
  324. case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
  325. err = relay.RelayTaskFetch(c, relayMode)
  326. default:
  327. err = relay.RelayTaskSubmit(c, relayMode)
  328. }
  329. return err
  330. }
  331. func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
  332. if taskErr == nil {
  333. return false
  334. }
  335. if retryTimes <= 0 {
  336. return false
  337. }
  338. if _, ok := c.Get("specific_channel_id"); ok {
  339. return false
  340. }
  341. if taskErr.StatusCode == http.StatusTooManyRequests {
  342. return true
  343. }
  344. if taskErr.StatusCode == 307 {
  345. return true
  346. }
  347. if taskErr.StatusCode/100 == 5 {
  348. // 超时不重试
  349. if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
  350. return false
  351. }
  352. return true
  353. }
  354. if taskErr.StatusCode == http.StatusBadRequest {
  355. return false
  356. }
  357. if taskErr.StatusCode == 408 {
  358. // azure处理超时不重试
  359. return false
  360. }
  361. if taskErr.LocalError {
  362. return false
  363. }
  364. if taskErr.StatusCode/100 == 2 {
  365. return false
  366. }
  367. return true
  368. }