relay.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package controller
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "log"
  6. "net/http"
  7. "one-api/common"
  8. "strconv"
  9. "strings"
  10. "github.com/gin-gonic/gin"
  11. )
  12. type Message struct {
  13. Role string `json:"role"`
  14. Content json.RawMessage `json:"content"`
  15. Name *string `json:"name,omitempty"`
  16. }
  17. type MediaMessage struct {
  18. Type string `json:"type"`
  19. Text string `json:"text"`
  20. ImageUrl any `json:"image_url,omitempty"`
  21. }
  22. type MessageImageUrl struct {
  23. Url string `json:"url"`
  24. Detail string `json:"detail"`
  25. }
  26. const (
  27. ContentTypeText = "text"
  28. ContentTypeImageURL = "image_url"
  29. )
  30. func (m Message) ParseContent() []MediaMessage {
  31. var contentList []MediaMessage
  32. var stringContent string
  33. if err := json.Unmarshal(m.Content, &stringContent); err == nil {
  34. contentList = append(contentList, MediaMessage{
  35. Type: ContentTypeText,
  36. Text: stringContent,
  37. })
  38. return contentList
  39. }
  40. var arrayContent []json.RawMessage
  41. if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
  42. for _, contentItem := range arrayContent {
  43. var contentMap map[string]any
  44. if err := json.Unmarshal(contentItem, &contentMap); err != nil {
  45. continue
  46. }
  47. switch contentMap["type"] {
  48. case ContentTypeText:
  49. if subStr, ok := contentMap["text"].(string); ok {
  50. contentList = append(contentList, MediaMessage{
  51. Type: ContentTypeText,
  52. Text: subStr,
  53. })
  54. }
  55. case ContentTypeImageURL:
  56. if subObj, ok := contentMap["image_url"].(map[string]any); ok {
  57. detail, ok := subObj["detail"]
  58. if ok {
  59. subObj["detail"] = detail.(string)
  60. } else {
  61. subObj["detail"] = "auto"
  62. }
  63. contentList = append(contentList, MediaMessage{
  64. Type: ContentTypeImageURL,
  65. ImageUrl: MessageImageUrl{
  66. Url: subObj["url"].(string),
  67. Detail: subObj["detail"].(string),
  68. },
  69. })
  70. }
  71. }
  72. }
  73. return contentList
  74. }
  75. return nil
  76. }
  77. const (
  78. RelayModeUnknown = iota
  79. RelayModeChatCompletions
  80. RelayModeCompletions
  81. RelayModeEmbeddings
  82. RelayModeModerations
  83. RelayModeImagesGenerations
  84. RelayModeEdits
  85. RelayModeMidjourneyImagine
  86. RelayModeMidjourneyDescribe
  87. RelayModeMidjourneyBlend
  88. RelayModeMidjourneyChange
  89. RelayModeMidjourneySimpleChange
  90. RelayModeMidjourneyNotify
  91. RelayModeMidjourneyTaskFetch
  92. RelayModeMidjourneyTaskFetchByCondition
  93. RelayModeAudio
  94. )
  95. // https://platform.openai.com/docs/api-reference/chat
  96. type GeneralOpenAIRequest struct {
  97. Model string `json:"model,omitempty"`
  98. Messages []Message `json:"messages,omitempty"`
  99. Prompt any `json:"prompt,omitempty"`
  100. Stream bool `json:"stream,omitempty"`
  101. MaxTokens uint `json:"max_tokens,omitempty"`
  102. Temperature float64 `json:"temperature,omitempty"`
  103. TopP float64 `json:"top_p,omitempty"`
  104. N int `json:"n,omitempty"`
  105. Input any `json:"input,omitempty"`
  106. Instruction string `json:"instruction,omitempty"`
  107. Size string `json:"size,omitempty"`
  108. Functions any `json:"functions,omitempty"`
  109. }
  110. func (r GeneralOpenAIRequest) ParseInput() []string {
  111. if r.Input == nil {
  112. return nil
  113. }
  114. var input []string
  115. switch r.Input.(type) {
  116. case string:
  117. input = []string{r.Input.(string)}
  118. case []any:
  119. input = make([]string, 0, len(r.Input.([]any)))
  120. for _, item := range r.Input.([]any) {
  121. if str, ok := item.(string); ok {
  122. input = append(input, str)
  123. }
  124. }
  125. }
  126. return input
  127. }
  128. type AudioRequest struct {
  129. Model string `json:"model"`
  130. Voice string `json:"voice"`
  131. Input string `json:"input"`
  132. }
  133. type ChatRequest struct {
  134. Model string `json:"model"`
  135. Messages []Message `json:"messages"`
  136. MaxTokens uint `json:"max_tokens"`
  137. }
  138. type TextRequest struct {
  139. Model string `json:"model"`
  140. Messages []Message `json:"messages"`
  141. Prompt string `json:"prompt"`
  142. MaxTokens uint `json:"max_tokens"`
  143. //Stream bool `json:"stream"`
  144. }
  145. type ImageRequest struct {
  146. Model string `json:"model"`
  147. Prompt string `json:"prompt"`
  148. N int `json:"n"`
  149. Size string `json:"size"`
  150. Quality string `json:"quality,omitempty"`
  151. ResponseFormat string `json:"response_format,omitempty"`
  152. Style string `json:"style,omitempty"`
  153. }
  154. type AudioResponse struct {
  155. Text string `json:"text,omitempty"`
  156. }
  157. type Usage struct {
  158. PromptTokens int `json:"prompt_tokens"`
  159. CompletionTokens int `json:"completion_tokens"`
  160. TotalTokens int `json:"total_tokens"`
  161. }
  162. type OpenAIError struct {
  163. Message string `json:"message"`
  164. Type string `json:"type"`
  165. Param string `json:"param"`
  166. Code any `json:"code"`
  167. }
  168. type OpenAIErrorWithStatusCode struct {
  169. OpenAIError
  170. StatusCode int `json:"status_code"`
  171. }
  172. type TextResponse struct {
  173. Choices []OpenAITextResponseChoice `json:"choices"`
  174. Usage `json:"usage"`
  175. Error OpenAIError `json:"error"`
  176. }
  177. type OpenAITextResponseChoice struct {
  178. Index int `json:"index"`
  179. Message `json:"message"`
  180. FinishReason string `json:"finish_reason"`
  181. }
  182. type OpenAITextResponse struct {
  183. Id string `json:"id"`
  184. Object string `json:"object"`
  185. Created int64 `json:"created"`
  186. Choices []OpenAITextResponseChoice `json:"choices"`
  187. Usage `json:"usage"`
  188. }
  189. type OpenAIEmbeddingResponseItem struct {
  190. Object string `json:"object"`
  191. Index int `json:"index"`
  192. Embedding []float64 `json:"embedding"`
  193. }
  194. type OpenAIEmbeddingResponse struct {
  195. Object string `json:"object"`
  196. Data []OpenAIEmbeddingResponseItem `json:"data"`
  197. Model string `json:"model"`
  198. Usage `json:"usage"`
  199. }
  200. type ImageResponse struct {
  201. Created int `json:"created"`
  202. Data []struct {
  203. Url string `json:"url"`
  204. B64Json string `json:"b64_json"`
  205. }
  206. }
  207. type ChatCompletionsStreamResponseChoice struct {
  208. Delta struct {
  209. Content string `json:"content"`
  210. } `json:"delta"`
  211. FinishReason *string `json:"finish_reason"`
  212. }
  213. type ChatCompletionsStreamResponse struct {
  214. Id string `json:"id"`
  215. Object string `json:"object"`
  216. Created int64 `json:"created"`
  217. Model string `json:"model"`
  218. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  219. }
  220. type ChatCompletionsStreamResponseSimple struct {
  221. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  222. }
  223. type CompletionsStreamResponse struct {
  224. Choices []struct {
  225. Text string `json:"text"`
  226. FinishReason string `json:"finish_reason"`
  227. } `json:"choices"`
  228. }
  229. type MidjourneyRequest struct {
  230. Prompt string `json:"prompt"`
  231. NotifyHook string `json:"notifyHook"`
  232. Action string `json:"action"`
  233. Index int `json:"index"`
  234. State string `json:"state"`
  235. TaskId string `json:"taskId"`
  236. Base64Array []string `json:"base64Array"`
  237. Content string `json:"content"`
  238. }
  239. type MidjourneyResponse struct {
  240. Code int `json:"code"`
  241. Description string `json:"description"`
  242. Properties interface{} `json:"properties"`
  243. Result string `json:"result"`
  244. }
  245. func Relay(c *gin.Context) {
  246. relayMode := RelayModeUnknown
  247. if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
  248. relayMode = RelayModeChatCompletions
  249. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
  250. relayMode = RelayModeCompletions
  251. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
  252. relayMode = RelayModeEmbeddings
  253. } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
  254. relayMode = RelayModeEmbeddings
  255. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
  256. relayMode = RelayModeModerations
  257. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
  258. relayMode = RelayModeImagesGenerations
  259. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
  260. relayMode = RelayModeEdits
  261. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
  262. relayMode = RelayModeAudio
  263. }
  264. var err *OpenAIErrorWithStatusCode
  265. switch relayMode {
  266. case RelayModeImagesGenerations:
  267. err = relayImageHelper(c, relayMode)
  268. case RelayModeAudio:
  269. err = relayAudioHelper(c, relayMode)
  270. default:
  271. err = relayTextHelper(c, relayMode)
  272. }
  273. if err != nil {
  274. requestId := c.GetString(common.RequestIdKey)
  275. retryTimesStr := c.Query("retry")
  276. retryTimes, _ := strconv.Atoi(retryTimesStr)
  277. if retryTimesStr == "" {
  278. retryTimes = common.RetryTimes
  279. }
  280. if retryTimes > 0 {
  281. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
  282. } else {
  283. if err.StatusCode == http.StatusTooManyRequests {
  284. //err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
  285. }
  286. err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
  287. c.JSON(err.StatusCode, gin.H{
  288. "error": err.OpenAIError,
  289. })
  290. }
  291. channelId := c.GetInt("channel_id")
  292. autoBan := c.GetBool("auto_ban")
  293. common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
  294. // https://platform.openai.com/docs/guides/error-codes/api-errors
  295. if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
  296. channelId := c.GetInt("channel_id")
  297. channelName := c.GetString("channel_name")
  298. disableChannel(channelId, channelName, err.Message)
  299. }
  300. }
  301. }
  302. func RelayMidjourney(c *gin.Context) {
  303. relayMode := RelayModeUnknown
  304. if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
  305. relayMode = RelayModeMidjourneyImagine
  306. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
  307. relayMode = RelayModeMidjourneyBlend
  308. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
  309. relayMode = RelayModeMidjourneyDescribe
  310. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
  311. relayMode = RelayModeMidjourneyNotify
  312. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
  313. relayMode = RelayModeMidjourneyChange
  314. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
  315. relayMode = RelayModeMidjourneyChange
  316. } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
  317. relayMode = RelayModeMidjourneyTaskFetch
  318. } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
  319. relayMode = RelayModeMidjourneyTaskFetchByCondition
  320. }
  321. var err *MidjourneyResponse
  322. switch relayMode {
  323. case RelayModeMidjourneyNotify:
  324. err = relayMidjourneyNotify(c)
  325. case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
  326. err = relayMidjourneyTask(c, relayMode)
  327. default:
  328. err = relayMidjourneySubmit(c, relayMode)
  329. }
  330. //err = relayMidjourneySubmit(c, relayMode)
  331. log.Println(err)
  332. if err != nil {
  333. retryTimesStr := c.Query("retry")
  334. retryTimes, _ := strconv.Atoi(retryTimesStr)
  335. if retryTimesStr == "" {
  336. retryTimes = common.RetryTimes
  337. }
  338. if retryTimes > 0 {
  339. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
  340. } else {
  341. if err.Code == 30 {
  342. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  343. }
  344. c.JSON(400, gin.H{
  345. "error": err.Description + " " + err.Result,
  346. })
  347. }
  348. channelId := c.GetInt("channel_id")
  349. common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
  350. //if shouldDisableChannel(&err.OpenAIError) {
  351. // channelId := c.GetInt("channel_id")
  352. // channelName := c.GetString("channel_name")
  353. // disableChannel(channelId, channelName, err.Result)
  354. //};''''''''''''''''''''''''''''''''
  355. }
  356. }
  357. func RelayNotImplemented(c *gin.Context) {
  358. err := OpenAIError{
  359. Message: "API not implemented",
  360. Type: "new_api_error",
  361. Param: "",
  362. Code: "api_not_implemented",
  363. }
  364. c.JSON(http.StatusNotImplemented, gin.H{
  365. "error": err,
  366. })
  367. }
  368. func RelayNotFound(c *gin.Context) {
  369. err := OpenAIError{
  370. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  371. Type: "invalid_request_error",
  372. Param: "",
  373. Code: "",
  374. }
  375. c.JSON(http.StatusNotFound, gin.H{
  376. "error": err,
  377. })
  378. }