adaptor.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. package replicate
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "mime/multipart"
  9. "net/http"
  10. "net/textproto"
  11. "strconv"
  12. "strings"
  13. "github.com/QuantumNous/new-api/common"
  14. "github.com/QuantumNous/new-api/constant"
  15. "github.com/QuantumNous/new-api/dto"
  16. "github.com/QuantumNous/new-api/relay/channel"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  19. "github.com/QuantumNous/new-api/service"
  20. "github.com/QuantumNous/new-api/types"
  21. "github.com/gin-gonic/gin"
  22. "github.com/samber/lo"
  23. )
  24. type Adaptor struct {
  25. }
  26. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  27. }
  28. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  29. if info == nil {
  30. return "", errors.New("replicate adaptor: relay info is nil")
  31. }
  32. if info.ChannelBaseUrl == "" {
  33. info.ChannelBaseUrl = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
  34. }
  35. requestPath := info.RequestURLPath
  36. if requestPath == "" {
  37. return info.ChannelBaseUrl, nil
  38. }
  39. return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestPath, info.ChannelType), nil
  40. }
  41. func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
  42. if info == nil {
  43. return errors.New("replicate adaptor: relay info is nil")
  44. }
  45. if info.ApiKey == "" {
  46. return errors.New("replicate adaptor: api key is required")
  47. }
  48. channel.SetupApiRequestHeader(info, c, req)
  49. req.Set("Authorization", "Bearer "+info.ApiKey)
  50. req.Set("Prefer", "wait")
  51. if req.Get("Content-Type") == "" {
  52. req.Set("Content-Type", "application/json")
  53. }
  54. if req.Get("Accept") == "" {
  55. req.Set("Accept", "application/json")
  56. }
  57. return nil
  58. }
  59. func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  60. if info == nil {
  61. return nil, errors.New("replicate adaptor: relay info is nil")
  62. }
  63. if strings.TrimSpace(request.Prompt) == "" {
  64. if v := c.PostForm("prompt"); strings.TrimSpace(v) != "" {
  65. request.Prompt = v
  66. }
  67. }
  68. if strings.TrimSpace(request.Prompt) == "" {
  69. return nil, errors.New("replicate adaptor: prompt is required")
  70. }
  71. modelName := strings.TrimSpace(info.UpstreamModelName)
  72. if modelName == "" {
  73. modelName = strings.TrimSpace(request.Model)
  74. }
  75. if modelName == "" {
  76. modelName = ModelFlux11Pro
  77. }
  78. info.UpstreamModelName = modelName
  79. info.RequestURLPath = fmt.Sprintf("/v1/models/%s/predictions", modelName)
  80. inputPayload := make(map[string]any)
  81. inputPayload["prompt"] = request.Prompt
  82. if size := strings.TrimSpace(request.Size); size != "" {
  83. if aspect, width, height, ok := mapOpenAISizeToFlux(size); ok {
  84. if aspect != "" {
  85. if aspect == "custom" {
  86. inputPayload["aspect_ratio"] = "custom"
  87. if width > 0 {
  88. inputPayload["width"] = width
  89. }
  90. if height > 0 {
  91. inputPayload["height"] = height
  92. }
  93. } else {
  94. inputPayload["aspect_ratio"] = aspect
  95. }
  96. }
  97. }
  98. }
  99. if len(request.OutputFormat) > 0 {
  100. var outputFormat string
  101. if err := json.Unmarshal(request.OutputFormat, &outputFormat); err == nil && strings.TrimSpace(outputFormat) != "" {
  102. inputPayload["output_format"] = outputFormat
  103. }
  104. }
  105. if imageN := lo.FromPtrOr(request.N, uint(0)); imageN > 0 {
  106. inputPayload["num_outputs"] = int(imageN)
  107. }
  108. if strings.EqualFold(request.Quality, "hd") || strings.EqualFold(request.Quality, "high") {
  109. inputPayload["prompt_upsampling"] = true
  110. }
  111. if info.RelayMode == relayconstant.RelayModeImagesEdits {
  112. imageURL, err := uploadFileFromForm(c, info, "image", "image[]", "image_prompt")
  113. if err != nil {
  114. return nil, err
  115. }
  116. if imageURL == "" {
  117. return nil, errors.New("replicate adaptor: image file is required for edits")
  118. }
  119. inputPayload["image_prompt"] = imageURL
  120. }
  121. if len(request.ExtraFields) > 0 {
  122. var extra map[string]any
  123. if err := common.Unmarshal(request.ExtraFields, &extra); err != nil {
  124. return nil, fmt.Errorf("replicate adaptor: failed to decode extra_fields: %w", err)
  125. }
  126. for key, val := range extra {
  127. inputPayload[key] = val
  128. }
  129. }
  130. for key, raw := range request.Extra {
  131. if strings.EqualFold(key, "input") {
  132. var extraInput map[string]any
  133. if err := common.Unmarshal(raw, &extraInput); err != nil {
  134. return nil, fmt.Errorf("replicate adaptor: failed to decode extra input: %w", err)
  135. }
  136. for k, v := range extraInput {
  137. inputPayload[k] = v
  138. }
  139. continue
  140. }
  141. if raw == nil {
  142. continue
  143. }
  144. var val any
  145. if err := common.Unmarshal(raw, &val); err != nil {
  146. return nil, fmt.Errorf("replicate adaptor: failed to decode extra field %s: %w", key, err)
  147. }
  148. inputPayload[key] = val
  149. }
  150. return map[string]any{
  151. "input": inputPayload,
  152. }, nil
  153. }
  154. func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
  155. return channel.DoApiRequest(a, c, info, requestBody)
  156. }
  157. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (any, *types.NewAPIError) {
  158. if resp == nil {
  159. return nil, types.NewError(errors.New("replicate adaptor: empty response"), types.ErrorCodeBadResponse)
  160. }
  161. responseBody, err := io.ReadAll(resp.Body)
  162. if err != nil {
  163. return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
  164. }
  165. _ = resp.Body.Close()
  166. var prediction PredictionResponse
  167. if err := common.Unmarshal(responseBody, &prediction); err != nil {
  168. return nil, types.NewError(fmt.Errorf("replicate adaptor: failed to decode response: %w", err), types.ErrorCodeBadResponseBody)
  169. }
  170. if prediction.Error != nil {
  171. errMsg := prediction.Error.Message
  172. if errMsg == "" {
  173. errMsg = prediction.Error.Detail
  174. }
  175. if errMsg == "" {
  176. errMsg = prediction.Error.Code
  177. }
  178. if errMsg == "" {
  179. errMsg = "replicate adaptor: prediction error"
  180. }
  181. return nil, types.NewError(errors.New(errMsg), types.ErrorCodeBadResponse)
  182. }
  183. if prediction.Status != "" && !strings.EqualFold(prediction.Status, "succeeded") {
  184. return nil, types.NewError(fmt.Errorf("replicate adaptor: prediction status %q", prediction.Status), types.ErrorCodeBadResponse)
  185. }
  186. var urls []string
  187. appendOutput := func(value string) {
  188. value = strings.TrimSpace(value)
  189. if value == "" {
  190. return
  191. }
  192. urls = append(urls, value)
  193. }
  194. switch output := prediction.Output.(type) {
  195. case string:
  196. appendOutput(output)
  197. case []any:
  198. for _, item := range output {
  199. if str, ok := item.(string); ok {
  200. appendOutput(str)
  201. }
  202. }
  203. case nil:
  204. // no output
  205. default:
  206. if str, ok := output.(fmt.Stringer); ok {
  207. appendOutput(str.String())
  208. }
  209. }
  210. if len(urls) == 0 {
  211. return nil, types.NewError(errors.New("replicate adaptor: empty prediction output"), types.ErrorCodeBadResponseBody)
  212. }
  213. var imageReq *dto.ImageRequest
  214. if info != nil {
  215. if req, ok := info.Request.(*dto.ImageRequest); ok {
  216. imageReq = req
  217. }
  218. }
  219. wantsBase64 := imageReq != nil && strings.EqualFold(imageReq.ResponseFormat, "b64_json")
  220. imageResponse := dto.ImageResponse{
  221. Created: common.GetTimestamp(),
  222. Data: make([]dto.ImageData, 0),
  223. }
  224. if wantsBase64 {
  225. converted, convErr := downloadImagesToBase64(urls)
  226. if convErr != nil {
  227. return nil, types.NewError(convErr, types.ErrorCodeBadResponse)
  228. }
  229. for _, content := range converted {
  230. if content == "" {
  231. continue
  232. }
  233. imageResponse.Data = append(imageResponse.Data, dto.ImageData{B64Json: content})
  234. }
  235. } else {
  236. for _, url := range urls {
  237. if url == "" {
  238. continue
  239. }
  240. imageResponse.Data = append(imageResponse.Data, dto.ImageData{Url: url})
  241. }
  242. }
  243. if len(imageResponse.Data) == 0 {
  244. return nil, types.NewError(errors.New("replicate adaptor: no usable image data"), types.ErrorCodeBadResponse)
  245. }
  246. responseBytes, err := common.Marshal(imageResponse)
  247. if err != nil {
  248. return nil, types.NewError(fmt.Errorf("replicate adaptor: encode response failed: %w", err), types.ErrorCodeBadResponseBody)
  249. }
  250. c.Writer.Header().Set("Content-Type", "application/json")
  251. c.Writer.WriteHeader(http.StatusOK)
  252. _, _ = c.Writer.Write(responseBytes)
  253. usage := &dto.Usage{}
  254. return usage, nil
  255. }
  256. func (a *Adaptor) GetModelList() []string {
  257. return ModelList
  258. }
  259. func (a *Adaptor) GetChannelName() string {
  260. return ChannelName
  261. }
  262. func downloadImagesToBase64(urls []string) ([]string, error) {
  263. results := make([]string, 0, len(urls))
  264. for _, url := range urls {
  265. if strings.TrimSpace(url) == "" {
  266. continue
  267. }
  268. _, data, err := service.GetImageFromUrl(url)
  269. if err != nil {
  270. return nil, fmt.Errorf("replicate adaptor: failed to download image from %s: %w", url, err)
  271. }
  272. results = append(results, data)
  273. }
  274. return results, nil
  275. }
  276. func mapOpenAISizeToFlux(size string) (aspect string, width int, height int, ok bool) {
  277. parts := strings.Split(size, "x")
  278. if len(parts) != 2 {
  279. return "", 0, 0, false
  280. }
  281. w, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
  282. h, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
  283. if err1 != nil || err2 != nil || w <= 0 || h <= 0 {
  284. return "", 0, 0, false
  285. }
  286. switch {
  287. case w == h:
  288. return "1:1", 0, 0, true
  289. case w == 1792 && h == 1024:
  290. return "16:9", 0, 0, true
  291. case w == 1024 && h == 1792:
  292. return "9:16", 0, 0, true
  293. case w == 1536 && h == 1024:
  294. return "3:2", 0, 0, true
  295. case w == 1024 && h == 1536:
  296. return "2:3", 0, 0, true
  297. }
  298. rw, rh := reduceRatio(w, h)
  299. ratioStr := fmt.Sprintf("%d:%d", rw, rh)
  300. switch ratioStr {
  301. case "1:1", "16:9", "9:16", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3":
  302. return ratioStr, 0, 0, true
  303. }
  304. width = normalizeFluxDimension(w)
  305. height = normalizeFluxDimension(h)
  306. return "custom", width, height, true
  307. }
  308. func reduceRatio(w, h int) (int, int) {
  309. g := gcd(w, h)
  310. if g == 0 {
  311. return w, h
  312. }
  313. return w / g, h / g
  314. }
  315. func gcd(a, b int) int {
  316. for b != 0 {
  317. a, b = b, a%b
  318. }
  319. if a < 0 {
  320. return -a
  321. }
  322. return a
  323. }
  324. func normalizeFluxDimension(value int) int {
  325. const (
  326. minDim = 256
  327. maxDim = 1440
  328. step = 32
  329. )
  330. if value < minDim {
  331. value = minDim
  332. }
  333. if value > maxDim {
  334. value = maxDim
  335. }
  336. remainder := value % step
  337. if remainder != 0 {
  338. if remainder >= step/2 {
  339. value += step - remainder
  340. } else {
  341. value -= remainder
  342. }
  343. }
  344. if value < minDim {
  345. value = minDim
  346. }
  347. if value > maxDim {
  348. value = maxDim
  349. }
  350. return value
  351. }
  352. func uploadFileFromForm(c *gin.Context, info *relaycommon.RelayInfo, fieldCandidates ...string) (string, error) {
  353. if info == nil {
  354. return "", errors.New("replicate adaptor: relay info is nil")
  355. }
  356. mf := c.Request.MultipartForm
  357. if mf == nil {
  358. if _, err := c.MultipartForm(); err != nil {
  359. return "", fmt.Errorf("replicate adaptor: parse multipart form failed: %w", err)
  360. }
  361. mf = c.Request.MultipartForm
  362. }
  363. if mf == nil || len(mf.File) == 0 {
  364. return "", nil
  365. }
  366. if len(fieldCandidates) == 0 {
  367. fieldCandidates = []string{"image", "image[]", "image_prompt"}
  368. }
  369. var fileHeader *multipart.FileHeader
  370. for _, key := range fieldCandidates {
  371. if files := mf.File[key]; len(files) > 0 {
  372. fileHeader = files[0]
  373. break
  374. }
  375. }
  376. if fileHeader == nil {
  377. for _, files := range mf.File {
  378. if len(files) > 0 {
  379. fileHeader = files[0]
  380. break
  381. }
  382. }
  383. }
  384. if fileHeader == nil {
  385. return "", nil
  386. }
  387. file, err := fileHeader.Open()
  388. if err != nil {
  389. return "", fmt.Errorf("replicate adaptor: failed to open image file: %w", err)
  390. }
  391. defer file.Close()
  392. var body bytes.Buffer
  393. writer := multipart.NewWriter(&body)
  394. hdr := make(textproto.MIMEHeader)
  395. hdr.Set("Content-Disposition", fmt.Sprintf("form-data; name=\"content\"; filename=\"%s\"", fileHeader.Filename))
  396. contentType := fileHeader.Header.Get("Content-Type")
  397. if contentType == "" {
  398. contentType = "application/octet-stream"
  399. }
  400. hdr.Set("Content-Type", contentType)
  401. part, err := writer.CreatePart(hdr)
  402. if err != nil {
  403. writer.Close()
  404. return "", fmt.Errorf("replicate adaptor: create upload form failed: %w", err)
  405. }
  406. if _, err := io.Copy(part, file); err != nil {
  407. writer.Close()
  408. return "", fmt.Errorf("replicate adaptor: copy image content failed: %w", err)
  409. }
  410. formContentType := writer.FormDataContentType()
  411. writer.Close()
  412. baseURL := info.ChannelBaseUrl
  413. if baseURL == "" {
  414. baseURL = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
  415. }
  416. uploadURL := relaycommon.GetFullRequestURL(baseURL, "/v1/files", info.ChannelType)
  417. req, err := http.NewRequest(http.MethodPost, uploadURL, &body)
  418. if err != nil {
  419. return "", fmt.Errorf("replicate adaptor: create upload request failed: %w", err)
  420. }
  421. req.Header.Set("Content-Type", formContentType)
  422. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  423. resp, err := service.GetHttpClient().Do(req)
  424. if err != nil {
  425. return "", fmt.Errorf("replicate adaptor: upload image failed: %w", err)
  426. }
  427. defer resp.Body.Close()
  428. respBody, err := io.ReadAll(resp.Body)
  429. if err != nil {
  430. return "", fmt.Errorf("replicate adaptor: read upload response failed: %w", err)
  431. }
  432. if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
  433. return "", fmt.Errorf("replicate adaptor: upload image failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
  434. }
  435. var uploadResp FileUploadResponse
  436. if err := common.Unmarshal(respBody, &uploadResp); err != nil {
  437. return "", fmt.Errorf("replicate adaptor: decode upload response failed: %w", err)
  438. }
  439. if uploadResp.Urls.Get == "" {
  440. return "", errors.New("replicate adaptor: upload response missing url")
  441. }
  442. return uploadResp.Urls.Get, nil
  443. }
  444. func (a *Adaptor) ConvertOpenAIRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeneralOpenAIRequest) (any, error) {
  445. return nil, errors.New("replicate adaptor: ConvertOpenAIRequest is not implemented")
  446. }
  447. func (a *Adaptor) ConvertRerankRequest(*gin.Context, int, dto.RerankRequest) (any, error) {
  448. return nil, errors.New("replicate adaptor: ConvertRerankRequest is not implemented")
  449. }
  450. func (a *Adaptor) ConvertEmbeddingRequest(*gin.Context, *relaycommon.RelayInfo, dto.EmbeddingRequest) (any, error) {
  451. return nil, errors.New("replicate adaptor: ConvertEmbeddingRequest is not implemented")
  452. }
  453. func (a *Adaptor) ConvertAudioRequest(*gin.Context, *relaycommon.RelayInfo, dto.AudioRequest) (io.Reader, error) {
  454. return nil, errors.New("replicate adaptor: ConvertAudioRequest is not implemented")
  455. }
  456. func (a *Adaptor) ConvertOpenAIResponsesRequest(*gin.Context, *relaycommon.RelayInfo, dto.OpenAIResponsesRequest) (any, error) {
  457. return nil, errors.New("replicate adaptor: ConvertOpenAIResponsesRequest is not implemented")
  458. }
  459. func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
  460. return nil, errors.New("replicate adaptor: ConvertClaudeRequest is not implemented")
  461. }
  462. func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
  463. return nil, errors.New("replicate adaptor: ConvertGeminiRequest is not implemented")
  464. }