adaptor.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. package jimeng
  2. import (
  3. "bytes"
  4. "crypto/hmac"
  5. "crypto/sha256"
  6. "encoding/base64"
  7. "encoding/hex"
  8. "encoding/json"
  9. "fmt"
  10. "io"
  11. "net/http"
  12. "net/url"
  13. "sort"
  14. "strings"
  15. "time"
  16. "github.com/QuantumNous/new-api/common"
  17. "github.com/QuantumNous/new-api/model"
  18. "github.com/samber/lo"
  19. "github.com/gin-gonic/gin"
  20. "github.com/pkg/errors"
  21. "github.com/QuantumNous/new-api/constant"
  22. "github.com/QuantumNous/new-api/dto"
  23. "github.com/QuantumNous/new-api/relay/channel"
  24. relaycommon "github.com/QuantumNous/new-api/relay/common"
  25. "github.com/QuantumNous/new-api/service"
  26. )
  27. // ============================
  28. // Request / Response structures
  29. // ============================
  // requestPayload is the JSON body marshalled by BuildRequestBody and sent to
  // the submit action. Fields tagged omitempty are dropped when empty; seed and
  // aspect_ratio are always serialized, even at their zero values.
  type requestPayload struct {
  	// ReqKey carries the upstream req_key (initially the requested model name).
  	ReqKey string `json:"req_key"`
  	// BinaryDataBase64 holds base64-encoded input images.
  	BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
  	// ImageUrls holds input images given as URLs.
  	ImageUrls []string `json:"image_urls,omitempty"`
  	// Prompt is the text prompt.
  	Prompt string `json:"prompt,omitempty"`
  	// Seed is always sent (no omitempty).
  	Seed int64 `json:"seed"`
  	// AspectRatio is always sent (no omitempty).
  	AspectRatio string `json:"aspect_ratio"`
  	// Frames is the frame count derived from the requested duration.
  	Frames int `json:"frames,omitempty"`
  }
  // responsePayload is the JSON envelope decoded by DoResponse from the
  // upstream reply to a task submission.
  type responsePayload struct {
  	// Code is the upstream result code; DoResponse treats 10000 as success.
  	Code int `json:"code"`
  	// Message is the upstream message, surfaced as the error text on failure.
  	Message string `json:"message"`
  	// RequestId is the upstream request identifier.
  	RequestId string `json:"request_id"`
  	// Data carries the identifier of the created task.
  	Data struct {
  		TaskID string `json:"task_id"`
  	} `json:"data"`
  }
  // responseTask is the JSON envelope of a task-result query, decoded by
  // ParseTaskResult and ConvertToOpenAIVideo.
  type responseTask struct {
  	// Code is the upstream result code; 10000 is treated as success.
  	Code int `json:"code"`
  	// Data is the task-specific part of the reply.
  	Data struct {
  		BinaryDataBase64 []any  `json:"binary_data_base64"`
  		ImageUrls        any    `json:"image_urls"`
  		RespData         string `json:"resp_data"`
  		// Status is the task state; ParseTaskResult handles "in_queue" and "done".
  		Status string `json:"status"`
  		// VideoUrl is the URL of the generated video.
  		VideoUrl string `json:"video_url"`
  	} `json:"data"`
  	// Message is the upstream message, used as the failure reason.
  	Message     string `json:"message"`
  	RequestId   string `json:"request_id"`
  	Status      int    `json:"status"`
  	TimeElapsed string `json:"time_elapsed"`
  }
  // MaxFileSize is the largest single uploaded file accepted, in bytes:
  // 4 MiB + 700 KiB (4,911,104 bytes, roughly 4.7 MB). Jimeng caps a single
  // file at 4.7 MB, see https://www.volcengine.com/docs/85621/1747301
  const MaxFileSize int64 = (4*1024 + 700) * 1024
  65. // ============================
  66. // Adaptor implementation
  67. // ============================
  // TaskAdaptor is the Jimeng task adaptor. Its unexported fields are filled
  // in by Init from the relay info.
  type TaskAdaptor struct {
  	// ChannelType is copied from the relay info.
  	ChannelType int

  	// accessKey and secretKey are the two halves of an "access_key|secret_key"
  	// API key; they are used to sign requests.
  	accessKey string
  	secretKey string

  	// baseURL is the channel base URL that request URLs are built on.
  	baseURL string
  }
  74. func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
  75. a.ChannelType = info.ChannelType
  76. a.baseURL = info.ChannelBaseUrl
  77. // apiKey format: "access_key|secret_key"
  78. keyParts := strings.Split(info.ApiKey, "|")
  79. if len(keyParts) == 2 {
  80. a.accessKey = strings.TrimSpace(keyParts[0])
  81. a.secretKey = strings.TrimSpace(keyParts[1])
  82. }
  83. }
  84. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
  85. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  86. return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
  87. }
  88. // BuildRequestURL constructs the upstream URL.
  89. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
  90. if isNewAPIRelay(info.ApiKey) {
  91. return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
  92. }
  93. return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
  94. }
  95. // BuildRequestHeader sets required headers.
  96. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
  97. req.Header.Set("Content-Type", "application/json")
  98. req.Header.Set("Accept", "application/json")
  99. if isNewAPIRelay(info.ApiKey) {
  100. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  101. } else {
  102. return a.signRequest(req, a.accessKey, a.secretKey)
  103. }
  104. return nil
  105. }
  106. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
  107. v, exists := c.Get("task_request")
  108. if !exists {
  109. return nil, fmt.Errorf("request not found in context")
  110. }
  111. req, ok := v.(relaycommon.TaskSubmitReq)
  112. if !ok {
  113. return nil, fmt.Errorf("invalid request type in context")
  114. }
  115. // 支持openai sdk的图片上传方式
  116. if mf, err := c.MultipartForm(); err == nil {
  117. if files, exists := mf.File["input_reference"]; exists && len(files) > 0 {
  118. if len(files) == 1 {
  119. info.Action = constant.TaskActionGenerate
  120. } else if len(files) > 1 {
  121. info.Action = constant.TaskActionFirstTailGenerate
  122. }
  123. // 将上传的文件转换为base64格式
  124. var images []string
  125. for _, fileHeader := range files {
  126. // 检查文件大小
  127. if fileHeader.Size > MaxFileSize {
  128. return nil, fmt.Errorf("文件 %s 大小超过限制,最大允许 %d MB", fileHeader.Filename, MaxFileSize/(1024*1024))
  129. }
  130. file, err := fileHeader.Open()
  131. if err != nil {
  132. continue
  133. }
  134. fileBytes, err := io.ReadAll(file)
  135. file.Close()
  136. if err != nil {
  137. continue
  138. }
  139. // 将文件内容转换为base64
  140. base64Str := base64.StdEncoding.EncodeToString(fileBytes)
  141. images = append(images, base64Str)
  142. }
  143. req.Images = images
  144. }
  145. }
  146. body, err := a.convertToRequestPayload(&req)
  147. if err != nil {
  148. return nil, errors.Wrap(err, "convert request payload failed")
  149. }
  150. data, err := json.Marshal(body)
  151. if err != nil {
  152. return nil, err
  153. }
  154. return bytes.NewReader(data), nil
  155. }
  156. // DoRequest delegates to common helper.
  157. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  158. return channel.DoTaskApiRequest(a, c, info, requestBody)
  159. }
  160. // DoResponse handles upstream response, returns taskID etc.
  161. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
  162. responseBody, err := io.ReadAll(resp.Body)
  163. if err != nil {
  164. taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  165. return
  166. }
  167. _ = resp.Body.Close()
  168. // Parse Jimeng response
  169. var jResp responsePayload
  170. if err := json.Unmarshal(responseBody, &jResp); err != nil {
  171. taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
  172. return
  173. }
  174. if jResp.Code != 10000 {
  175. taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
  176. return
  177. }
  178. ov := dto.NewOpenAIVideo()
  179. ov.ID = jResp.Data.TaskID
  180. ov.TaskID = jResp.Data.TaskID
  181. ov.CreatedAt = time.Now().Unix()
  182. ov.Model = info.OriginModelName
  183. c.JSON(http.StatusOK, ov)
  184. return jResp.Data.TaskID, responseBody, nil
  185. }
  186. // FetchTask fetch task status
  187. func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
  188. taskID, ok := body["task_id"].(string)
  189. if !ok {
  190. return nil, fmt.Errorf("invalid task_id")
  191. }
  192. uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
  193. if isNewAPIRelay(key) {
  194. uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
  195. }
  196. payload := map[string]string{
  197. "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
  198. "task_id": taskID,
  199. }
  200. payloadBytes, err := json.Marshal(payload)
  201. if err != nil {
  202. return nil, errors.Wrap(err, "marshal fetch task payload failed")
  203. }
  204. req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
  205. if err != nil {
  206. return nil, err
  207. }
  208. req.Header.Set("Accept", "application/json")
  209. req.Header.Set("Content-Type", "application/json")
  210. if isNewAPIRelay(key) {
  211. req.Header.Set("Authorization", "Bearer "+key)
  212. } else {
  213. keyParts := strings.Split(key, "|")
  214. if len(keyParts) != 2 {
  215. return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
  216. }
  217. accessKey := strings.TrimSpace(keyParts[0])
  218. secretKey := strings.TrimSpace(keyParts[1])
  219. if err := a.signRequest(req, accessKey, secretKey); err != nil {
  220. return nil, errors.Wrap(err, "sign request failed")
  221. }
  222. }
  223. client, err := service.GetHttpClientWithProxy(proxy)
  224. if err != nil {
  225. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  226. }
  227. return client.Do(req)
  228. }
  229. func (a *TaskAdaptor) GetModelList() []string {
  230. return []string{"jimeng_vgfm_t2v_l20"}
  231. }
  232. func (a *TaskAdaptor) GetChannelName() string {
  233. return "jimeng"
  234. }
  235. func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
  236. var bodyBytes []byte
  237. var err error
  238. if req.Body != nil {
  239. bodyBytes, err = io.ReadAll(req.Body)
  240. if err != nil {
  241. return errors.Wrap(err, "read request body failed")
  242. }
  243. _ = req.Body.Close()
  244. req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
  245. } else {
  246. bodyBytes = []byte{}
  247. }
  248. payloadHash := sha256.Sum256(bodyBytes)
  249. hexPayloadHash := hex.EncodeToString(payloadHash[:])
  250. t := time.Now().UTC()
  251. xDate := t.Format("20060102T150405Z")
  252. shortDate := t.Format("20060102")
  253. req.Header.Set("Host", req.URL.Host)
  254. req.Header.Set("X-Date", xDate)
  255. req.Header.Set("X-Content-Sha256", hexPayloadHash)
  256. // Sort and encode query parameters to create canonical query string
  257. queryParams := req.URL.Query()
  258. sortedKeys := make([]string, 0, len(queryParams))
  259. for k := range queryParams {
  260. sortedKeys = append(sortedKeys, k)
  261. }
  262. sort.Strings(sortedKeys)
  263. var queryParts []string
  264. for _, k := range sortedKeys {
  265. values := queryParams[k]
  266. sort.Strings(values)
  267. for _, v := range values {
  268. queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
  269. }
  270. }
  271. canonicalQueryString := strings.Join(queryParts, "&")
  272. headersToSign := map[string]string{
  273. "host": req.URL.Host,
  274. "x-date": xDate,
  275. "x-content-sha256": hexPayloadHash,
  276. }
  277. if req.Header.Get("Content-Type") != "" {
  278. headersToSign["content-type"] = req.Header.Get("Content-Type")
  279. }
  280. var signedHeaderKeys []string
  281. for k := range headersToSign {
  282. signedHeaderKeys = append(signedHeaderKeys, k)
  283. }
  284. sort.Strings(signedHeaderKeys)
  285. var canonicalHeaders strings.Builder
  286. for _, k := range signedHeaderKeys {
  287. canonicalHeaders.WriteString(k)
  288. canonicalHeaders.WriteString(":")
  289. canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
  290. canonicalHeaders.WriteString("\n")
  291. }
  292. signedHeaders := strings.Join(signedHeaderKeys, ";")
  293. canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
  294. req.Method,
  295. req.URL.Path,
  296. canonicalQueryString,
  297. canonicalHeaders.String(),
  298. signedHeaders,
  299. hexPayloadHash,
  300. )
  301. hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
  302. hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
  303. region := "cn-north-1"
  304. serviceName := "cv"
  305. credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
  306. stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
  307. xDate,
  308. credentialScope,
  309. hexHashedCanonicalRequest,
  310. )
  311. kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
  312. kRegion := hmacSHA256(kDate, []byte(region))
  313. kService := hmacSHA256(kRegion, []byte(serviceName))
  314. kSigning := hmacSHA256(kService, []byte("request"))
  315. signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
  316. authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
  317. accessKey,
  318. credentialScope,
  319. signedHeaders,
  320. signature,
  321. )
  322. req.Header.Set("Authorization", authorization)
  323. return nil
  324. }
  // hmacSHA256 returns the 32-byte HMAC-SHA256 of data keyed with key.
  func hmacSHA256(key []byte, data []byte) []byte {
  	mac := hmac.New(sha256.New, key)
  	_, _ = mac.Write(data) // hash.Hash documents that Write never returns an error
  	digest := mac.Sum(nil)
  	return digest
  }
  330. func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
  331. r := requestPayload{
  332. ReqKey: req.Model,
  333. Prompt: req.Prompt,
  334. }
  335. switch req.Duration {
  336. case 10:
  337. r.Frames = 241 // 24*10+1 = 241
  338. default:
  339. r.Frames = 121 // 24*5+1 = 121
  340. }
  341. // Handle one-of image_urls or binary_data_base64
  342. if req.HasImage() {
  343. if strings.HasPrefix(req.Images[0], "http") {
  344. r.ImageUrls = req.Images
  345. } else {
  346. r.BinaryDataBase64 = req.Images
  347. }
  348. }
  349. metadata := req.Metadata
  350. medaBytes, err := json.Marshal(metadata)
  351. if err != nil {
  352. return nil, errors.Wrap(err, "metadata marshal metadata failed")
  353. }
  354. err = json.Unmarshal(medaBytes, &r)
  355. if err != nil {
  356. return nil, errors.Wrap(err, "unmarshal metadata failed")
  357. }
  358. // 即梦视频3.0 ReqKey转换
  359. // https://www.volcengine.com/docs/85621/1792707
  360. imageLen := lo.Max([]int{len(req.Images), len(r.BinaryDataBase64), len(r.ImageUrls)})
  361. if strings.Contains(r.ReqKey, "jimeng_v30") {
  362. if r.ReqKey == "jimeng_v30_pro" {
  363. // 3.0 pro只有固定的jimeng_ti2v_v30_pro
  364. r.ReqKey = "jimeng_ti2v_v30_pro"
  365. } else if imageLen > 1 {
  366. // 多张图片:首尾帧生成
  367. r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1), "p")
  368. } else if imageLen == 1 {
  369. // 单张图片:图生视频
  370. r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1), "p")
  371. } else {
  372. // 无图片:文生视频
  373. r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
  374. }
  375. }
  376. return &r, nil
  377. }
  378. func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
  379. resTask := responseTask{}
  380. if err := json.Unmarshal(respBody, &resTask); err != nil {
  381. return nil, errors.Wrap(err, "unmarshal task result failed")
  382. }
  383. taskResult := relaycommon.TaskInfo{}
  384. if resTask.Code == 10000 {
  385. taskResult.Code = 0
  386. } else {
  387. taskResult.Code = resTask.Code // todo uni code
  388. taskResult.Reason = resTask.Message
  389. taskResult.Status = model.TaskStatusFailure
  390. taskResult.Progress = "100%"
  391. }
  392. switch resTask.Data.Status {
  393. case "in_queue":
  394. taskResult.Status = model.TaskStatusQueued
  395. taskResult.Progress = "10%"
  396. case "done":
  397. taskResult.Status = model.TaskStatusSuccess
  398. taskResult.Progress = "100%"
  399. }
  400. taskResult.Url = resTask.Data.VideoUrl
  401. return &taskResult, nil
  402. }
  403. func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
  404. var jimengResp responseTask
  405. if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
  406. return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
  407. }
  408. openAIVideo := dto.NewOpenAIVideo()
  409. openAIVideo.ID = originTask.TaskID
  410. openAIVideo.Status = originTask.Status.ToVideoStatus()
  411. openAIVideo.SetProgressStr(originTask.Progress)
  412. openAIVideo.SetMetadata("url", jimengResp.Data.VideoUrl)
  413. openAIVideo.CreatedAt = originTask.CreatedAt
  414. openAIVideo.CompletedAt = originTask.UpdatedAt
  415. if jimengResp.Code != 10000 {
  416. openAIVideo.Error = &dto.OpenAIVideoError{
  417. Message: jimengResp.Message,
  418. Code: fmt.Sprintf("%d", jimengResp.Code),
  419. }
  420. }
  421. jsonData, _ := common.Marshal(openAIVideo)
  422. return jsonData, nil
  423. }
  // isNewAPIRelay reports whether apiKey begins with "sk-". The adaptor sends
  // such keys as a bearer token to the "/jimeng" relay path instead of treating
  // them as an "ak|sk" pair to sign with.
  func isNewAPIRelay(apiKey string) bool {
  	const relayKeyPrefix = "sk-"
  	if len(apiKey) < len(relayKeyPrefix) {
  		return false
  	}
  	return apiKey[:len(relayKeyPrefix)] == relayKeyPrefix
  }