relay.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. package controller
  import (
  	"encoding/json"
  	"fmt"
  	"log"
  	"net/http"
  	"net/url"
  	"strconv"
  	"strings"

  	"github.com/gin-gonic/gin"
  	"one-api/common"
  )
  // Message is a single chat message. Content is kept as raw JSON because it
  // may be either a JSON string or an array of typed content parts; use
  // StringContent or ParseContent to interpret it.
  type Message struct {
  	Role       string          `json:"role"`
  	Content    json.RawMessage `json:"content"`
  	Name       *string         `json:"name,omitempty"`
  	ToolCalls  any             `json:"tool_calls,omitempty"`
  	ToolCallId string          `json:"tool_call_id,omitempty"`
  }

  // MediaMessage is one normalized content part as returned by ParseContent.
  // For image parts, ImageUrl holds a MessageImageUrl value.
  type MediaMessage struct {
  	Type     string `json:"type"`
  	Text     string `json:"text"`
  	ImageUrl any    `json:"image_url,omitempty"`
  }

  // MessageImageUrl carries the URL and detail setting of an image content part.
  type MessageImageUrl struct {
  	Url    string `json:"url"`
  	Detail string `json:"detail"`
  }

  // Values of the "type" field that ParseContent recognizes in content parts.
  const (
  	ContentTypeText     = "text"
  	ContentTypeImageURL = "image_url"
  )

  // StringContent returns Content decoded as a JSON string. When Content is
  // not a JSON string, the raw JSON text is returned unchanged.
  func (m Message) StringContent() string {
  	var text string
  	if json.Unmarshal(m.Content, &text) != nil {
  		return string(m.Content)
  	}
  	return text
  }

  // ParseContent normalizes Content into a list of content parts.
  //
  // A JSON string yields a single text part. A JSON array is walked element by
  // element: "text" parts with a string "text" field and "image_url" parts
  // whose "image_url" object has a string "url" are kept; anything else is
  // skipped. An image part whose "detail" is absent or not a string gets the
  // detail "auto". Content that is neither a string nor an array yields nil.
  //
  // The content comes from the request body, so every field is checked before
  // use; malformed parts are dropped instead of panicking.
  func (m Message) ParseContent() []MediaMessage {
  	var text string
  	if err := json.Unmarshal(m.Content, &text); err == nil {
  		return []MediaMessage{{Type: ContentTypeText, Text: text}}
  	}

  	var rawParts []json.RawMessage
  	if err := json.Unmarshal(m.Content, &rawParts); err != nil {
  		return nil
  	}

  	var parts []MediaMessage
  	for _, rawPart := range rawParts {
  		var fields map[string]any
  		if err := json.Unmarshal(rawPart, &fields); err != nil {
  			continue
  		}
  		switch fields["type"] {
  		case ContentTypeText:
  			if s, ok := fields["text"].(string); ok {
  				parts = append(parts, MediaMessage{Type: ContentTypeText, Text: s})
  			}
  		case ContentTypeImageURL:
  			image, ok := fields["image_url"].(map[string]any)
  			if !ok {
  				continue
  			}
  			// An image part without a string URL is unusable; skip it.
  			imageURL, ok := image["url"].(string)
  			if !ok {
  				continue
  			}
  			// Missing, null or non-string detail falls back to the default.
  			detail, ok := image["detail"].(string)
  			if !ok {
  				detail = "auto"
  			}
  			parts = append(parts, MediaMessage{
  				Type:     ContentTypeImageURL,
  				ImageUrl: MessageImageUrl{Url: imageURL, Detail: detail},
  			})
  		}
  	}
  	return parts
  }
  // Relay modes select which upstream API family a request is routed to.
  // Values are assigned sequentially by iota starting at zero, so inserting or
  // reordering entries changes the numeric value of every later constant.
  const (
  	// Fallback when the request path matches no known prefix.
  	RelayModeUnknown = iota

  	// OpenAI-style endpoints.
  	RelayModeChatCompletions
  	RelayModeCompletions
  	RelayModeEmbeddings
  	RelayModeModerations
  	RelayModeImagesGenerations
  	RelayModeEdits

  	// Midjourney endpoints.
  	RelayModeMidjourneyImagine
  	RelayModeMidjourneyDescribe
  	RelayModeMidjourneyBlend
  	RelayModeMidjourneyChange
  	RelayModeMidjourneySimpleChange
  	RelayModeMidjourneyNotify
  	RelayModeMidjourneyTaskFetch
  	RelayModeMidjourneyTaskFetchByCondition

  	// Audio endpoints.
  	RelayModeAudioSpeech
  	RelayModeAudioTranscription
  	RelayModeAudioTranslation
  )
  // ResponseFormat is the "response_format" object of a chat request.
  // See https://platform.openai.com/docs/api-reference/chat
  type ResponseFormat struct {
  	Type string `json:"type,omitempty"`
  }
  110. type GeneralOpenAIRequest struct {
  111. Model string `json:"model,omitempty"`
  112. Messages []Message `json:"messages,omitempty"`
  113. Prompt any `json:"prompt,omitempty"`
  114. Stream bool `json:"stream,omitempty"`
  115. MaxTokens uint `json:"max_tokens,omitempty"`
  116. Temperature float64 `json:"temperature,omitempty"`
  117. TopP float64 `json:"top_p,omitempty"`
  118. N int `json:"n,omitempty"`
  119. Input any `json:"input,omitempty"`
  120. Instruction string `json:"instruction,omitempty"`
  121. Size string `json:"size,omitempty"`
  122. Functions any `json:"functions,omitempty"`
  123. FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
  124. PresencePenalty float64 `json:"presence_penalty,omitempty"`
  125. ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
  126. Seed float64 `json:"seed,omitempty"`
  127. Tools any `json:"tools,omitempty"`
  128. ToolChoice any `json:"tool_choice,omitempty"`
  129. User string `json:"user,omitempty"`
  130. LogProbs bool `json:"logprobs,omitempty"`
  131. TopLogProbs int `json:"top_logprobs,omitempty"`
  132. }
  133. func (r GeneralOpenAIRequest) ParseInput() []string {
  134. if r.Input == nil {
  135. return nil
  136. }
  137. var input []string
  138. switch r.Input.(type) {
  139. case string:
  140. input = []string{r.Input.(string)}
  141. case []any:
  142. input = make([]string, 0, len(r.Input.([]any)))
  143. for _, item := range r.Input.([]any) {
  144. if str, ok := item.(string); ok {
  145. input = append(input, str)
  146. }
  147. }
  148. }
  149. return input
  150. }
  // AudioRequest is the JSON body of an audio request: the model, the voice
  // and the input text.
  type AudioRequest struct {
  	Model string `json:"model"`
  	Voice string `json:"voice"`
  	Input string `json:"input"`
  }
  156. type ChatRequest struct {
  157. Model string `json:"model"`
  158. Messages []Message `json:"messages"`
  159. MaxTokens uint `json:"max_tokens"`
  160. }
  161. type TextRequest struct {
  162. Model string `json:"model"`
  163. Messages []Message `json:"messages"`
  164. Prompt string `json:"prompt"`
  165. MaxTokens uint `json:"max_tokens"`
  166. //Stream bool `json:"stream"`
  167. }
  // ImageRequest is the JSON body of an image generation request. Quality,
  // ResponseFormat and Style are omitted from the encoding when empty.
  type ImageRequest struct {
  	Model          string `json:"model"`
  	Prompt         string `json:"prompt"`
  	N              int    `json:"n"`
  	Size           string `json:"size"`
  	Quality        string `json:"quality,omitempty"`
  	ResponseFormat string `json:"response_format,omitempty"`
  	Style          string `json:"style,omitempty"`
  }
  // AudioResponse is an audio response body carrying a single text field,
  // which is omitted from the encoding when empty.
  type AudioResponse struct {
  	Text string `json:"text,omitempty"`
  }
  // Usage holds the token counts reported in a response: prompt tokens,
  // completion tokens and their total.
  type Usage struct {
  	PromptTokens     int `json:"prompt_tokens"`
  	CompletionTokens int `json:"completion_tokens"`
  	TotalTokens      int `json:"total_tokens"`
  }
  // OpenAIError is the error object placed under the "error" key of error
  // responses. Code is typed any, so it can hold whatever JSON value is given.
  type OpenAIError struct {
  	Message string `json:"message"`
  	Type    string `json:"type"`
  	Param   string `json:"param"`
  	Code    any    `json:"code"`
  }
  191. type OpenAIErrorWithStatusCode struct {
  192. OpenAIError
  193. StatusCode int `json:"status_code"`
  194. }
  195. type TextResponse struct {
  196. Choices []OpenAITextResponseChoice `json:"choices"`
  197. Usage `json:"usage"`
  198. Error OpenAIError `json:"error"`
  199. }
  200. type OpenAITextResponseChoice struct {
  201. Index int `json:"index"`
  202. Message `json:"message"`
  203. FinishReason string `json:"finish_reason"`
  204. }
  205. type OpenAITextResponse struct {
  206. Id string `json:"id"`
  207. Object string `json:"object"`
  208. Created int64 `json:"created"`
  209. Choices []OpenAITextResponseChoice `json:"choices"`
  210. Usage `json:"usage"`
  211. }
  // OpenAIEmbeddingResponseItem is one entry of an embedding response's
  // "data" array: the object kind, its index and the embedding vector.
  type OpenAIEmbeddingResponseItem struct {
  	Object    string    `json:"object"`
  	Index     int       `json:"index"`
  	Embedding []float64 `json:"embedding"`
  }
  217. type OpenAIEmbeddingResponse struct {
  218. Object string `json:"object"`
  219. Data []OpenAIEmbeddingResponseItem `json:"data"`
  220. Model string `json:"model"`
  221. Usage `json:"usage"`
  222. }
  // ImageResponse is an image response body: a creation timestamp and a list
  // of results, each carrying a URL and/or base64-encoded JSON payload.
  //
  // Data is tagged "data" so the struct encodes with the same lowercase key
  // it decodes from, matching every other field in this file.
  type ImageResponse struct {
  	Created int `json:"created"`
  	Data    []struct {
  		Url     string `json:"url"`
  		B64Json string `json:"b64_json"`
  	} `json:"data"`
  }
  // ChatCompletionsStreamResponseChoice is one choice of a streamed chat
  // chunk: the content delta and an optional finish reason. FinishReason is a
  // pointer and is omitted from the encoding when nil.
  type ChatCompletionsStreamResponseChoice struct {
  	Delta struct {
  		Content string `json:"content"`
  	} `json:"delta"`
  	FinishReason *string `json:"finish_reason,omitempty"`
  }
  236. type ChatCompletionsStreamResponse struct {
  237. Id string `json:"id"`
  238. Object string `json:"object"`
  239. Created int64 `json:"created"`
  240. Model string `json:"model"`
  241. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  242. }
  243. type ChatCompletionsStreamResponseSimple struct {
  244. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  245. }
  // CompletionsStreamResponse is one streamed completions chunk; each choice
  // carries the generated text and a finish reason.
  type CompletionsStreamResponse struct {
  	Choices []struct {
  		Text         string `json:"text"`
  		FinishReason string `json:"finish_reason"`
  	} `json:"choices"`
  }
  // MidjourneyRequest is the JSON body of a Midjourney request. Its keys use
  // camelCase (notifyHook, taskId, base64Array), unlike the OpenAI-style
  // types in this file.
  type MidjourneyRequest struct {
  	Prompt      string   `json:"prompt"`
  	NotifyHook  string   `json:"notifyHook"`
  	Action      string   `json:"action"`
  	Index       int      `json:"index"`
  	State       string   `json:"state"`
  	TaskId      string   `json:"taskId"`
  	Base64Array []string `json:"base64Array"`
  	Content     string   `json:"content"`
  }
  // MidjourneyResponse is the result object used by the Midjourney relay
  // helpers: a numeric code, a description, free-form properties and a result
  // string. RelayMidjourney treats a non-nil value as a failed attempt.
  type MidjourneyResponse struct {
  	Code        int    `json:"code"`
  	Description string `json:"description"`
  	Properties  any    `json:"properties"`
  	Result      string `json:"result"`
  }
  268. func Relay(c *gin.Context) {
  269. relayMode := RelayModeUnknown
  270. if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
  271. relayMode = RelayModeChatCompletions
  272. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
  273. relayMode = RelayModeCompletions
  274. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
  275. relayMode = RelayModeEmbeddings
  276. } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
  277. relayMode = RelayModeEmbeddings
  278. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
  279. relayMode = RelayModeModerations
  280. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
  281. relayMode = RelayModeImagesGenerations
  282. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
  283. relayMode = RelayModeEdits
  284. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
  285. relayMode = RelayModeAudioSpeech
  286. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
  287. relayMode = RelayModeAudioTranscription
  288. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
  289. relayMode = RelayModeAudioTranslation
  290. }
  291. var err *OpenAIErrorWithStatusCode
  292. switch relayMode {
  293. case RelayModeImagesGenerations:
  294. err = relayImageHelper(c, relayMode)
  295. case RelayModeAudioSpeech:
  296. fallthrough
  297. case RelayModeAudioTranslation:
  298. fallthrough
  299. case RelayModeAudioTranscription:
  300. err = relayAudioHelper(c, relayMode)
  301. default:
  302. err = relayTextHelper(c, relayMode)
  303. }
  304. if err != nil {
  305. requestId := c.GetString(common.RequestIdKey)
  306. retryTimesStr := c.Query("retry")
  307. retryTimes, _ := strconv.Atoi(retryTimesStr)
  308. if retryTimesStr == "" {
  309. retryTimes = common.RetryTimes
  310. }
  311. if retryTimes > 0 {
  312. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
  313. } else {
  314. if err.StatusCode == http.StatusTooManyRequests {
  315. //err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
  316. }
  317. err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
  318. c.JSON(err.StatusCode, gin.H{
  319. "error": err.OpenAIError,
  320. })
  321. }
  322. channelId := c.GetInt("channel_id")
  323. autoBan := c.GetBool("auto_ban")
  324. common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
  325. // https://platform.openai.com/docs/guides/error-codes/api-errors
  326. if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
  327. channelId := c.GetInt("channel_id")
  328. channelName := c.GetString("channel_name")
  329. disableChannel(channelId, channelName, err.Message)
  330. }
  331. }
  332. }
  333. func RelayMidjourney(c *gin.Context) {
  334. relayMode := RelayModeUnknown
  335. if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
  336. relayMode = RelayModeMidjourneyImagine
  337. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
  338. relayMode = RelayModeMidjourneyBlend
  339. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
  340. relayMode = RelayModeMidjourneyDescribe
  341. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
  342. relayMode = RelayModeMidjourneyNotify
  343. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
  344. relayMode = RelayModeMidjourneyChange
  345. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
  346. relayMode = RelayModeMidjourneyChange
  347. } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
  348. relayMode = RelayModeMidjourneyTaskFetch
  349. } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
  350. relayMode = RelayModeMidjourneyTaskFetchByCondition
  351. }
  352. var err *MidjourneyResponse
  353. switch relayMode {
  354. case RelayModeMidjourneyNotify:
  355. err = relayMidjourneyNotify(c)
  356. case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
  357. err = relayMidjourneyTask(c, relayMode)
  358. default:
  359. err = relayMidjourneySubmit(c, relayMode)
  360. }
  361. //err = relayMidjourneySubmit(c, relayMode)
  362. log.Println(err)
  363. if err != nil {
  364. retryTimesStr := c.Query("retry")
  365. retryTimes, _ := strconv.Atoi(retryTimesStr)
  366. if retryTimesStr == "" {
  367. retryTimes = common.RetryTimes
  368. }
  369. if retryTimes > 0 {
  370. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
  371. } else {
  372. if err.Code == 30 {
  373. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  374. }
  375. c.JSON(429, gin.H{
  376. "error": fmt.Sprintf("%s %s", err.Description, err.Result),
  377. "type": "upstream_error",
  378. })
  379. }
  380. channelId := c.GetInt("channel_id")
  381. common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
  382. //if shouldDisableChannel(&err.OpenAIError) {
  383. // channelId := c.GetInt("channel_id")
  384. // channelName := c.GetString("channel_name")
  385. // disableChannel(channelId, channelName, err.Result)
  386. //};''''''''''''''''''''''''''''''''
  387. }
  388. }
  389. func RelayNotImplemented(c *gin.Context) {
  390. err := OpenAIError{
  391. Message: "API not implemented",
  392. Type: "new_api_error",
  393. Param: "",
  394. Code: "api_not_implemented",
  395. }
  396. c.JSON(http.StatusNotImplemented, gin.H{
  397. "error": err,
  398. })
  399. }
  400. func RelayNotFound(c *gin.Context) {
  401. err := OpenAIError{
  402. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  403. Type: "invalid_request_error",
  404. Param: "",
  405. Code: "",
  406. }
  407. c.JSON(http.StatusNotFound, gin.H{
  408. "error": err,
  409. })
  410. }