relay.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. package controller
import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"

	"one-api/common"
)
  12. type Message struct {
  13. Role string `json:"role"`
  14. Content json.RawMessage `json:"content"`
  15. Name *string `json:"name,omitempty"`
  16. ToolCalls any `json:"tool_calls,omitempty"`
  17. ToolCallId string `json:"tool_call_id,omitempty"`
  18. }
  19. type MediaMessage struct {
  20. Type string `json:"type"`
  21. Text string `json:"text"`
  22. ImageUrl any `json:"image_url,omitempty"`
  23. }
  24. type MessageImageUrl struct {
  25. Url string `json:"url"`
  26. Detail string `json:"detail"`
  27. }
  28. const (
  29. ContentTypeText = "text"
  30. ContentTypeImageURL = "image_url"
  31. )
  32. func (m Message) ParseContent() []MediaMessage {
  33. var contentList []MediaMessage
  34. var stringContent string
  35. if err := json.Unmarshal(m.Content, &stringContent); err == nil {
  36. contentList = append(contentList, MediaMessage{
  37. Type: ContentTypeText,
  38. Text: stringContent,
  39. })
  40. return contentList
  41. }
  42. var arrayContent []json.RawMessage
  43. if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
  44. for _, contentItem := range arrayContent {
  45. var contentMap map[string]any
  46. if err := json.Unmarshal(contentItem, &contentMap); err != nil {
  47. continue
  48. }
  49. switch contentMap["type"] {
  50. case ContentTypeText:
  51. if subStr, ok := contentMap["text"].(string); ok {
  52. contentList = append(contentList, MediaMessage{
  53. Type: ContentTypeText,
  54. Text: subStr,
  55. })
  56. }
  57. case ContentTypeImageURL:
  58. if subObj, ok := contentMap["image_url"].(map[string]any); ok {
  59. detail, ok := subObj["detail"]
  60. if ok {
  61. subObj["detail"] = detail.(string)
  62. } else {
  63. subObj["detail"] = "auto"
  64. }
  65. contentList = append(contentList, MediaMessage{
  66. Type: ContentTypeImageURL,
  67. ImageUrl: MessageImageUrl{
  68. Url: subObj["url"].(string),
  69. Detail: subObj["detail"].(string),
  70. },
  71. })
  72. }
  73. }
  74. }
  75. return contentList
  76. }
  77. return nil
  78. }
  79. const (
  80. RelayModeUnknown = iota
  81. RelayModeChatCompletions
  82. RelayModeCompletions
  83. RelayModeEmbeddings
  84. RelayModeModerations
  85. RelayModeImagesGenerations
  86. RelayModeEdits
  87. RelayModeMidjourneyImagine
  88. RelayModeMidjourneyDescribe
  89. RelayModeMidjourneyBlend
  90. RelayModeMidjourneyChange
  91. RelayModeMidjourneySimpleChange
  92. RelayModeMidjourneyNotify
  93. RelayModeMidjourneyTaskFetch
  94. RelayModeMidjourneyTaskFetchByCondition
  95. RelayModeAudioSpeech
  96. RelayModeAudioTranscription
  97. RelayModeAudioTranslation
  98. )
  99. // https://platform.openai.com/docs/api-reference/chat
  100. type ResponseFormat struct {
  101. Type string `json:"type,omitempty"`
  102. }
  103. type GeneralOpenAIRequest struct {
  104. Model string `json:"model,omitempty"`
  105. Messages []Message `json:"messages,omitempty"`
  106. Prompt any `json:"prompt,omitempty"`
  107. Stream bool `json:"stream,omitempty"`
  108. MaxTokens uint `json:"max_tokens,omitempty"`
  109. Temperature float64 `json:"temperature,omitempty"`
  110. TopP float64 `json:"top_p,omitempty"`
  111. N int `json:"n,omitempty"`
  112. Input any `json:"input,omitempty"`
  113. Instruction string `json:"instruction,omitempty"`
  114. Size string `json:"size,omitempty"`
  115. Functions any `json:"functions,omitempty"`
  116. FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
  117. PresencePenalty float64 `json:"presence_penalty,omitempty"`
  118. ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
  119. Seed float64 `json:"seed,omitempty"`
  120. Tools any `json:"tools,omitempty"`
  121. ToolChoice any `json:"tool_choice,omitempty"`
  122. User string `json:"user,omitempty"`
  123. }
  124. func (r GeneralOpenAIRequest) ParseInput() []string {
  125. if r.Input == nil {
  126. return nil
  127. }
  128. var input []string
  129. switch r.Input.(type) {
  130. case string:
  131. input = []string{r.Input.(string)}
  132. case []any:
  133. input = make([]string, 0, len(r.Input.([]any)))
  134. for _, item := range r.Input.([]any) {
  135. if str, ok := item.(string); ok {
  136. input = append(input, str)
  137. }
  138. }
  139. }
  140. return input
  141. }
  142. type AudioRequest struct {
  143. Model string `json:"model"`
  144. Voice string `json:"voice"`
  145. Input string `json:"input"`
  146. }
  147. type ChatRequest struct {
  148. Model string `json:"model"`
  149. Messages []Message `json:"messages"`
  150. MaxTokens uint `json:"max_tokens"`
  151. }
  152. type TextRequest struct {
  153. Model string `json:"model"`
  154. Messages []Message `json:"messages"`
  155. Prompt string `json:"prompt"`
  156. MaxTokens uint `json:"max_tokens"`
  157. //Stream bool `json:"stream"`
  158. }
  159. type ImageRequest struct {
  160. Model string `json:"model"`
  161. Prompt string `json:"prompt"`
  162. N int `json:"n"`
  163. Size string `json:"size"`
  164. Quality string `json:"quality,omitempty"`
  165. ResponseFormat string `json:"response_format,omitempty"`
  166. Style string `json:"style,omitempty"`
  167. }
  168. type AudioResponse struct {
  169. Text string `json:"text,omitempty"`
  170. }
  171. type Usage struct {
  172. PromptTokens int `json:"prompt_tokens"`
  173. CompletionTokens int `json:"completion_tokens"`
  174. TotalTokens int `json:"total_tokens"`
  175. }
  176. type OpenAIError struct {
  177. Message string `json:"message"`
  178. Type string `json:"type"`
  179. Param string `json:"param"`
  180. Code any `json:"code"`
  181. }
  182. type OpenAIErrorWithStatusCode struct {
  183. OpenAIError
  184. StatusCode int `json:"status_code"`
  185. }
  186. type TextResponse struct {
  187. Choices []OpenAITextResponseChoice `json:"choices"`
  188. Usage `json:"usage"`
  189. Error OpenAIError `json:"error"`
  190. }
  191. type OpenAITextResponseChoice struct {
  192. Index int `json:"index"`
  193. Message `json:"message"`
  194. FinishReason string `json:"finish_reason"`
  195. }
  196. type OpenAITextResponse struct {
  197. Id string `json:"id"`
  198. Object string `json:"object"`
  199. Created int64 `json:"created"`
  200. Choices []OpenAITextResponseChoice `json:"choices"`
  201. Usage `json:"usage"`
  202. }
  203. type OpenAIEmbeddingResponseItem struct {
  204. Object string `json:"object"`
  205. Index int `json:"index"`
  206. Embedding []float64 `json:"embedding"`
  207. }
  208. type OpenAIEmbeddingResponse struct {
  209. Object string `json:"object"`
  210. Data []OpenAIEmbeddingResponseItem `json:"data"`
  211. Model string `json:"model"`
  212. Usage `json:"usage"`
  213. }
  214. type ImageResponse struct {
  215. Created int `json:"created"`
  216. Data []struct {
  217. Url string `json:"url"`
  218. B64Json string `json:"b64_json"`
  219. }
  220. }
  221. type ChatCompletionsStreamResponseChoice struct {
  222. Delta struct {
  223. Content string `json:"content"`
  224. } `json:"delta"`
  225. FinishReason *string `json:"finish_reason,omitempty"`
  226. }
  227. type ChatCompletionsStreamResponse struct {
  228. Id string `json:"id"`
  229. Object string `json:"object"`
  230. Created int64 `json:"created"`
  231. Model string `json:"model"`
  232. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  233. }
  234. type ChatCompletionsStreamResponseSimple struct {
  235. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  236. }
  237. type CompletionsStreamResponse struct {
  238. Choices []struct {
  239. Text string `json:"text"`
  240. FinishReason string `json:"finish_reason"`
  241. } `json:"choices"`
  242. }
  243. type MidjourneyRequest struct {
  244. Prompt string `json:"prompt"`
  245. NotifyHook string `json:"notifyHook"`
  246. Action string `json:"action"`
  247. Index int `json:"index"`
  248. State string `json:"state"`
  249. TaskId string `json:"taskId"`
  250. Base64Array []string `json:"base64Array"`
  251. Content string `json:"content"`
  252. }
  253. type MidjourneyResponse struct {
  254. Code int `json:"code"`
  255. Description string `json:"description"`
  256. Properties interface{} `json:"properties"`
  257. Result string `json:"result"`
  258. }
  259. func Relay(c *gin.Context) {
  260. relayMode := RelayModeUnknown
  261. if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
  262. relayMode = RelayModeChatCompletions
  263. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
  264. relayMode = RelayModeCompletions
  265. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
  266. relayMode = RelayModeEmbeddings
  267. } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
  268. relayMode = RelayModeEmbeddings
  269. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
  270. relayMode = RelayModeModerations
  271. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
  272. relayMode = RelayModeImagesGenerations
  273. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
  274. relayMode = RelayModeEdits
  275. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
  276. relayMode = RelayModeAudioSpeech
  277. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
  278. relayMode = RelayModeAudioTranscription
  279. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
  280. relayMode = RelayModeAudioTranslation
  281. }
  282. var err *OpenAIErrorWithStatusCode
  283. switch relayMode {
  284. case RelayModeImagesGenerations:
  285. err = relayImageHelper(c, relayMode)
  286. case RelayModeAudioSpeech:
  287. fallthrough
  288. case RelayModeAudioTranslation:
  289. fallthrough
  290. case RelayModeAudioTranscription:
  291. err = relayAudioHelper(c, relayMode)
  292. default:
  293. err = relayTextHelper(c, relayMode)
  294. }
  295. if err != nil {
  296. requestId := c.GetString(common.RequestIdKey)
  297. retryTimesStr := c.Query("retry")
  298. retryTimes, _ := strconv.Atoi(retryTimesStr)
  299. if retryTimesStr == "" {
  300. retryTimes = common.RetryTimes
  301. }
  302. if retryTimes > 0 {
  303. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
  304. } else {
  305. if err.StatusCode == http.StatusTooManyRequests {
  306. //err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
  307. }
  308. err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
  309. c.JSON(err.StatusCode, gin.H{
  310. "error": err.OpenAIError,
  311. })
  312. }
  313. channelId := c.GetInt("channel_id")
  314. autoBan := c.GetBool("auto_ban")
  315. common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
  316. // https://platform.openai.com/docs/guides/error-codes/api-errors
  317. if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
  318. channelId := c.GetInt("channel_id")
  319. channelName := c.GetString("channel_name")
  320. disableChannel(channelId, channelName, err.Message)
  321. }
  322. }
  323. }
  324. func RelayMidjourney(c *gin.Context) {
  325. relayMode := RelayModeUnknown
  326. if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
  327. relayMode = RelayModeMidjourneyImagine
  328. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
  329. relayMode = RelayModeMidjourneyBlend
  330. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
  331. relayMode = RelayModeMidjourneyDescribe
  332. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
  333. relayMode = RelayModeMidjourneyNotify
  334. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
  335. relayMode = RelayModeMidjourneyChange
  336. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
  337. relayMode = RelayModeMidjourneyChange
  338. } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
  339. relayMode = RelayModeMidjourneyTaskFetch
  340. } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
  341. relayMode = RelayModeMidjourneyTaskFetchByCondition
  342. }
  343. var err *MidjourneyResponse
  344. switch relayMode {
  345. case RelayModeMidjourneyNotify:
  346. err = relayMidjourneyNotify(c)
  347. case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
  348. err = relayMidjourneyTask(c, relayMode)
  349. default:
  350. err = relayMidjourneySubmit(c, relayMode)
  351. }
  352. //err = relayMidjourneySubmit(c, relayMode)
  353. log.Println(err)
  354. if err != nil {
  355. retryTimesStr := c.Query("retry")
  356. retryTimes, _ := strconv.Atoi(retryTimesStr)
  357. if retryTimesStr == "" {
  358. retryTimes = common.RetryTimes
  359. }
  360. if retryTimes > 0 {
  361. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
  362. } else {
  363. if err.Code == 30 {
  364. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  365. }
  366. c.JSON(400, gin.H{
  367. "error": err.Description + " " + err.Result,
  368. })
  369. }
  370. channelId := c.GetInt("channel_id")
  371. common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
  372. //if shouldDisableChannel(&err.OpenAIError) {
  373. // channelId := c.GetInt("channel_id")
  374. // channelName := c.GetString("channel_name")
  375. // disableChannel(channelId, channelName, err.Result)
  376. //};''''''''''''''''''''''''''''''''
  377. }
  378. }
  379. func RelayNotImplemented(c *gin.Context) {
  380. err := OpenAIError{
  381. Message: "API not implemented",
  382. Type: "new_api_error",
  383. Param: "",
  384. Code: "api_not_implemented",
  385. }
  386. c.JSON(http.StatusNotImplemented, gin.H{
  387. "error": err,
  388. })
  389. }
  390. func RelayNotFound(c *gin.Context) {
  391. err := OpenAIError{
  392. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  393. Type: "invalid_request_error",
  394. Param: "",
  395. Code: "",
  396. }
  397. c.JSON(http.StatusNotFound, gin.H{
  398. "error": err,
  399. })
  400. }