relay.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package controller
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "log"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. "one-api/relay"
  10. "one-api/relay/constant"
  11. relayconstant "one-api/relay/constant"
  12. "one-api/service"
  13. "strconv"
  14. "strings"
  15. )
  16. func Relay(c *gin.Context) {
  17. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  18. var err *dto.OpenAIErrorWithStatusCode
  19. switch relayMode {
  20. case relayconstant.RelayModeImagesGenerations:
  21. err = relay.RelayImageHelper(c, relayMode)
  22. case relayconstant.RelayModeAudioSpeech:
  23. fallthrough
  24. case relayconstant.RelayModeAudioTranslation:
  25. fallthrough
  26. case relayconstant.RelayModeAudioTranscription:
  27. err = relay.AudioHelper(c, relayMode)
  28. default:
  29. err = relay.TextHelper(c)
  30. }
  31. if err != nil {
  32. requestId := c.GetString(common.RequestIdKey)
  33. retryTimesStr := c.Query("retry")
  34. retryTimes, _ := strconv.Atoi(retryTimesStr)
  35. if retryTimesStr == "" {
  36. retryTimes = common.RetryTimes
  37. }
  38. if retryTimes > 0 {
  39. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
  40. } else {
  41. if err.StatusCode == http.StatusTooManyRequests {
  42. //err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  43. }
  44. err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
  45. c.JSON(err.StatusCode, gin.H{
  46. "error": err.Error,
  47. })
  48. }
  49. channelId := c.GetInt("channel_id")
  50. autoBan := c.GetBool("auto_ban")
  51. common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
  52. // https://platform.openai.com/docs/guides/error-codes/api-errors
  53. if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
  54. channelId := c.GetInt("channel_id")
  55. channelName := c.GetString("channel_name")
  56. service.DisableChannel(channelId, channelName, err.Error.Message)
  57. }
  58. }
  59. }
  60. func RelayMidjourney(c *gin.Context) {
  61. relayMode := relayconstant.RelayModeUnknown
  62. if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") {
  63. // midjourney plus
  64. relayMode = relayconstant.RelayModeMidjourneyAction
  65. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") {
  66. // midjourney plus
  67. relayMode = relayconstant.RelayModeMidjourneyModal
  68. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
  69. relayMode = relayconstant.RelayModeMidjourneyImagine
  70. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
  71. relayMode = relayconstant.RelayModeMidjourneyBlend
  72. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
  73. relayMode = relayconstant.RelayModeMidjourneyDescribe
  74. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
  75. relayMode = relayconstant.RelayModeMidjourneyNotify
  76. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
  77. relayMode = relayconstant.RelayModeMidjourneyChange
  78. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
  79. relayMode = relayconstant.RelayModeMidjourneyChange
  80. } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
  81. relayMode = relayconstant.RelayModeMidjourneyTaskFetch
  82. } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
  83. relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
  84. }
  85. var err *dto.MidjourneyResponse
  86. switch relayMode {
  87. case relayconstant.RelayModeMidjourneyNotify:
  88. err = relay.RelayMidjourneyNotify(c)
  89. case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
  90. err = relay.RelayMidjourneyTask(c, relayMode)
  91. //case relayconstant.RelayModeMidjourneyModal:
  92. // err = relay.RelayMidjournneyModal(c)
  93. default:
  94. err = relay.RelayMidjourneySubmit(c, relayMode)
  95. }
  96. //err = relayMidjourneySubmit(c, relayMode)
  97. log.Println(err)
  98. if err != nil {
  99. if err.Code == 30 {
  100. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  101. }
  102. c.JSON(429, gin.H{
  103. "error": fmt.Sprintf("%s %s", err.Description, err.Result),
  104. "type": "upstream_error",
  105. "code": err.Code,
  106. })
  107. channelId := c.GetInt("channel_id")
  108. common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
  109. }
  110. }
  111. func RelayNotImplemented(c *gin.Context) {
  112. err := dto.OpenAIError{
  113. Message: "API not implemented",
  114. Type: "new_api_error",
  115. Param: "",
  116. Code: "api_not_implemented",
  117. }
  118. c.JSON(http.StatusNotImplemented, gin.H{
  119. "error": err,
  120. })
  121. }
  122. func RelayNotFound(c *gin.Context) {
  123. err := dto.OpenAIError{
  124. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  125. Type: "invalid_request_error",
  126. Param: "",
  127. Code: "",
  128. }
  129. c.JSON(http.StatusNotFound, gin.H{
  130. "error": err,
  131. })
  132. }