relay.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. package controller
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/gorilla/websocket"
  8. "io"
  9. "log"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/dto"
  13. "one-api/middleware"
  14. "one-api/model"
  15. "one-api/relay"
  16. "one-api/relay/constant"
  17. relayconstant "one-api/relay/constant"
  18. "one-api/service"
  19. "one-api/setting"
  20. "strings"
  21. )
  22. func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
  23. var err *dto.OpenAIErrorWithStatusCode
  24. switch relayMode {
  25. case relayconstant.RelayModeImagesGenerations:
  26. err = relay.ImageHelper(c, relayMode)
  27. case relayconstant.RelayModeAudioSpeech:
  28. fallthrough
  29. case relayconstant.RelayModeAudioTranslation:
  30. fallthrough
  31. case relayconstant.RelayModeAudioTranscription:
  32. err = relay.AudioHelper(c)
  33. case relayconstant.RelayModeRerank:
  34. err = relay.RerankHelper(c, relayMode)
  35. default:
  36. err = relay.TextHelper(c)
  37. }
  38. return err
  39. }
  40. func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
  41. var err *dto.OpenAIErrorWithStatusCode
  42. switch relayMode {
  43. default:
  44. err = relay.TextHelper(c)
  45. }
  46. return err
  47. }
  48. func Playground(c *gin.Context) {
  49. var openaiErr *dto.OpenAIErrorWithStatusCode
  50. defer func() {
  51. if openaiErr != nil {
  52. c.JSON(openaiErr.StatusCode, gin.H{
  53. "error": openaiErr.Error,
  54. })
  55. }
  56. }()
  57. useAccessToken := c.GetBool("use_access_token")
  58. if useAccessToken {
  59. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
  60. return
  61. }
  62. playgroundRequest := &dto.PlayGroundRequest{}
  63. err := common.UnmarshalBodyReusable(c, playgroundRequest)
  64. if err != nil {
  65. openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
  66. return
  67. }
  68. if playgroundRequest.Model == "" {
  69. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
  70. return
  71. }
  72. c.Set("original_model", playgroundRequest.Model)
  73. group := playgroundRequest.Group
  74. userGroup := c.GetString("group")
  75. if group == "" {
  76. group = userGroup
  77. } else {
  78. if !setting.GroupInUserUsableGroups(group) && group != userGroup {
  79. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
  80. return
  81. }
  82. c.Set("group", group)
  83. }
  84. c.Set("token_name", "playground-"+group)
  85. channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
  86. if err != nil {
  87. message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
  88. openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
  89. return
  90. }
  91. middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
  92. Relay(c)
  93. }
  94. func Relay(c *gin.Context) {
  95. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  96. requestId := c.GetString(common.RequestIdKey)
  97. group := c.GetString("group")
  98. originalModel := c.GetString("original_model")
  99. var openaiErr *dto.OpenAIErrorWithStatusCode
  100. for i := 0; i <= common.RetryTimes; i++ {
  101. channel, err := getChannel(c, group, originalModel, i)
  102. if err != nil {
  103. common.LogError(c, err.Error())
  104. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
  105. break
  106. }
  107. openaiErr = relayRequest(c, relayMode, channel)
  108. if openaiErr == nil {
  109. return // 成功处理请求,直接返回
  110. }
  111. go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
  112. if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
  113. break
  114. }
  115. }
  116. useChannel := c.GetStringSlice("use_channel")
  117. if len(useChannel) > 1 {
  118. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  119. common.LogInfo(c, retryLogStr)
  120. }
  121. if openaiErr != nil {
  122. if openaiErr.StatusCode == http.StatusTooManyRequests {
  123. openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  124. }
  125. openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
  126. c.JSON(openaiErr.StatusCode, gin.H{
  127. "error": openaiErr.Error,
  128. })
  129. }
  130. }
// upgrader is the WebSocket upgrader used by WssRelay to turn an incoming
// HTTP request into a WebSocket connection.
var upgrader = websocket.Upgrader{
	// Subprotocols accepted during the WS handshake. A client that sends
	// Sec-WebSocket-Protocol must find its protocol declared here.
	// TODO: add other protocols.
	Subprotocols: []string{"realtime"},
	// NOTE(review): every Origin is accepted, i.e. cross-origin connections
	// are allowed unconditionally — confirm this is intended.
	CheckOrigin: func(r *http.Request) bool {
		return true // allow cross-origin requests
	},
}
  137. func WssRelay(c *gin.Context) {
  138. // 将 HTTP 连接升级为 WebSocket 连接
  139. ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  140. defer ws.Close()
  141. if err != nil {
  142. openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
  143. service.WssError(c, ws, openaiErr.Error)
  144. return
  145. }
  146. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  147. requestId := c.GetString(common.RequestIdKey)
  148. group := c.GetString("group")
  149. //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
  150. originalModel := c.GetString("original_model")
  151. var openaiErr *dto.OpenAIErrorWithStatusCode
  152. for i := 0; i <= common.RetryTimes; i++ {
  153. channel, err := getChannel(c, group, originalModel, i)
  154. if err != nil {
  155. common.LogError(c, err.Error())
  156. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
  157. break
  158. }
  159. openaiErr = wssRequest(c, ws, relayMode, channel)
  160. if openaiErr == nil {
  161. return // 成功处理请求,直接返回
  162. }
  163. go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
  164. if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
  165. break
  166. }
  167. }
  168. useChannel := c.GetStringSlice("use_channel")
  169. if len(useChannel) > 1 {
  170. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  171. common.LogInfo(c, retryLogStr)
  172. }
  173. if openaiErr != nil {
  174. if openaiErr.StatusCode == http.StatusTooManyRequests {
  175. openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  176. }
  177. openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
  178. service.WssError(c, ws, openaiErr.Error)
  179. }
  180. }
  181. func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
  182. addUsedChannel(c, channel.Id)
  183. requestBody, _ := common.GetRequestBody(c)
  184. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  185. return relayHandler(c, relayMode)
  186. }
  187. func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
  188. addUsedChannel(c, channel.Id)
  189. requestBody, _ := common.GetRequestBody(c)
  190. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  191. return relay.WssHelper(c, ws)
  192. }
  193. func addUsedChannel(c *gin.Context, channelId int) {
  194. useChannel := c.GetStringSlice("use_channel")
  195. useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
  196. c.Set("use_channel", useChannel)
  197. }
  198. func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
  199. if retryCount == 0 {
  200. autoBan := c.GetBool("auto_ban")
  201. autoBanInt := 1
  202. if !autoBan {
  203. autoBanInt = 0
  204. }
  205. return &model.Channel{
  206. Id: c.GetInt("channel_id"),
  207. Type: c.GetInt("channel_type"),
  208. Name: c.GetString("channel_name"),
  209. AutoBan: &autoBanInt,
  210. }, nil
  211. }
  212. channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
  213. if err != nil {
  214. return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
  215. }
  216. middleware.SetupContextForSelectedChannel(c, channel, originalModel)
  217. return channel, nil
  218. }
  219. func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
  220. if openaiErr == nil {
  221. return false
  222. }
  223. if openaiErr.LocalError {
  224. return false
  225. }
  226. if retryTimes <= 0 {
  227. return false
  228. }
  229. if _, ok := c.Get("specific_channel_id"); ok {
  230. return false
  231. }
  232. if openaiErr.StatusCode == http.StatusTooManyRequests {
  233. return true
  234. }
  235. if openaiErr.StatusCode == 307 {
  236. return true
  237. }
  238. if openaiErr.StatusCode/100 == 5 {
  239. // 超时不重试
  240. if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
  241. return false
  242. }
  243. return true
  244. }
  245. if openaiErr.StatusCode == http.StatusBadRequest {
  246. channelType := c.GetInt("channel_type")
  247. if channelType == common.ChannelTypeAnthropic {
  248. return true
  249. }
  250. return false
  251. }
  252. if openaiErr.StatusCode == 408 {
  253. // azure处理超时不重试
  254. return false
  255. }
  256. if openaiErr.StatusCode/100 == 2 {
  257. return false
  258. }
  259. return true
  260. }
  261. func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
  262. // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
  263. // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
  264. common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
  265. if service.ShouldDisableChannel(channelType, err) && autoBan {
  266. service.DisableChannel(channelId, channelName, err.Error.Message)
  267. }
  268. }
  269. func RelayMidjourney(c *gin.Context) {
  270. relayMode := c.GetInt("relay_mode")
  271. var err *dto.MidjourneyResponse
  272. switch relayMode {
  273. case relayconstant.RelayModeMidjourneyNotify:
  274. err = relay.RelayMidjourneyNotify(c)
  275. case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
  276. err = relay.RelayMidjourneyTask(c, relayMode)
  277. case relayconstant.RelayModeMidjourneyTaskImageSeed:
  278. err = relay.RelayMidjourneyTaskImageSeed(c)
  279. case relayconstant.RelayModeSwapFace:
  280. err = relay.RelaySwapFace(c)
  281. default:
  282. err = relay.RelayMidjourneySubmit(c, relayMode)
  283. }
  284. //err = relayMidjourneySubmit(c, relayMode)
  285. log.Println(err)
  286. if err != nil {
  287. statusCode := http.StatusBadRequest
  288. if err.Code == 30 {
  289. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  290. statusCode = http.StatusTooManyRequests
  291. }
  292. c.JSON(statusCode, gin.H{
  293. "description": fmt.Sprintf("%s %s", err.Description, err.Result),
  294. "type": "upstream_error",
  295. "code": err.Code,
  296. })
  297. channelId := c.GetInt("channel_id")
  298. common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
  299. }
  300. }
  301. func RelayNotImplemented(c *gin.Context) {
  302. err := dto.OpenAIError{
  303. Message: "API not implemented",
  304. Type: "new_api_error",
  305. Param: "",
  306. Code: "api_not_implemented",
  307. }
  308. c.JSON(http.StatusNotImplemented, gin.H{
  309. "error": err,
  310. })
  311. }
  312. func RelayNotFound(c *gin.Context) {
  313. err := dto.OpenAIError{
  314. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  315. Type: "invalid_request_error",
  316. Param: "",
  317. Code: "",
  318. }
  319. c.JSON(http.StatusNotFound, gin.H{
  320. "error": err,
  321. })
  322. }
  323. func RelayTask(c *gin.Context) {
  324. retryTimes := common.RetryTimes
  325. channelId := c.GetInt("channel_id")
  326. relayMode := c.GetInt("relay_mode")
  327. group := c.GetString("group")
  328. originalModel := c.GetString("original_model")
  329. c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
  330. taskErr := taskRelayHandler(c, relayMode)
  331. if taskErr == nil {
  332. retryTimes = 0
  333. }
  334. for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
  335. channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
  336. if err != nil {
  337. common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
  338. break
  339. }
  340. channelId = channel.Id
  341. useChannel := c.GetStringSlice("use_channel")
  342. useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
  343. c.Set("use_channel", useChannel)
  344. common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
  345. middleware.SetupContextForSelectedChannel(c, channel, originalModel)
  346. requestBody, err := common.GetRequestBody(c)
  347. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  348. taskErr = taskRelayHandler(c, relayMode)
  349. }
  350. useChannel := c.GetStringSlice("use_channel")
  351. if len(useChannel) > 1 {
  352. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  353. common.LogInfo(c, retryLogStr)
  354. }
  355. if taskErr != nil {
  356. if taskErr.StatusCode == http.StatusTooManyRequests {
  357. taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
  358. }
  359. c.JSON(taskErr.StatusCode, taskErr)
  360. }
  361. }
  362. func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
  363. var err *dto.TaskError
  364. switch relayMode {
  365. case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
  366. err = relay.RelayTaskFetch(c, relayMode)
  367. default:
  368. err = relay.RelayTaskSubmit(c, relayMode)
  369. }
  370. return err
  371. }
  372. func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
  373. if taskErr == nil {
  374. return false
  375. }
  376. if retryTimes <= 0 {
  377. return false
  378. }
  379. if _, ok := c.Get("specific_channel_id"); ok {
  380. return false
  381. }
  382. if taskErr.StatusCode == http.StatusTooManyRequests {
  383. return true
  384. }
  385. if taskErr.StatusCode == 307 {
  386. return true
  387. }
  388. if taskErr.StatusCode/100 == 5 {
  389. // 超时不重试
  390. if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
  391. return false
  392. }
  393. return true
  394. }
  395. if taskErr.StatusCode == http.StatusBadRequest {
  396. return false
  397. }
  398. if taskErr.StatusCode == 408 {
  399. // azure处理超时不重试
  400. return false
  401. }
  402. if taskErr.LocalError {
  403. return false
  404. }
  405. if taskErr.StatusCode/100 == 2 {
  406. return false
  407. }
  408. return true
  409. }