| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484 |
- package jimeng
- import (
- "bytes"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/base64"
- "encoding/hex"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "sort"
- "strings"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/model"
- "github.com/gin-gonic/gin"
- "github.com/pkg/errors"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/relay/channel"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/service"
- )
- // ============================
- // Request / Response structures
- // ============================
// requestPayload is the JSON body sent with the CVSync2AsyncSubmitTask action.
//
// ImageUrls and BinaryDataBase64 are two alternative ways of passing reference
// images; convertToRequestPayload chooses between them based on whether the
// first image starts with "http". Seed and AspectRatio have no omitempty tag,
// so they are always serialized, even when left at their zero values.
type requestPayload struct {
	ReqKey           string   `json:"req_key"`
	BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
	ImageUrls        []string `json:"image_urls,omitempty"`
	Prompt           string   `json:"prompt,omitempty"`
	Seed             int64    `json:"seed"`
	AspectRatio      string   `json:"aspect_ratio"`
	Frames           int      `json:"frames,omitempty"`
}
// responsePayload is the JSON response of the submit action.
// DoResponse treats Code 10000 as success; Data.TaskID is the upstream task
// identifier returned to the caller and later sent as "task_id" when polling.
type responsePayload struct {
	Code      int    `json:"code"`
	Message   string `json:"message"`
	RequestId string `json:"request_id"`
	Data      struct {
		TaskID string `json:"task_id"`
	} `json:"data"`
}
// responseTask is the JSON response of the CVSync2AsyncGetResult action.
//
// Code 10000 marks a successful query, Data.Status carries the upstream job
// state (e.g. "in_queue", "done"), and Data.VideoUrl is the URL that
// ParseTaskResult reports as the task result.
type responseTask struct {
	Code int `json:"code"`
	Data struct {
		BinaryDataBase64 []any  `json:"binary_data_base64"`
		ImageUrls        any    `json:"image_urls"`
		RespData         string `json:"resp_data"`
		Status           string `json:"status"`
		VideoUrl         string `json:"video_url"`
	} `json:"data"`
	Message     string `json:"message"`
	RequestId   string `json:"request_id"`
	Status      int    `json:"status"`
	TimeElapsed string `json:"time_elapsed"`
}
const (
	// MaxFileSize is the largest uploaded reference image accepted by
	// BuildRequestBody. Jimeng limits a single file to 4.7MB
	// (https://www.volcengine.com/docs/85621/1747301); the value below is
	// 4*1024*1024 + 700*1024 = 4,911,104 bytes (about 4.68 MiB).
	// NOTE(review): if the upstream limit is meant in decimal units
	// (4,700,000 bytes), this value is slightly above it; confirm.
	MaxFileSize int64 = 4*1024*1024 + 700*1024
)
- // ============================
- // Adaptor implementation
- // ============================
// TaskAdaptor relays asynchronous Jimeng (Volcengine) video generation tasks.
//
// The channel key is either a new-api relay key (prefixed with "sk-"), sent as
// a Bearer token, or an "access_key|secret_key" pair used to sign requests.
type TaskAdaptor struct {
	ChannelType int
	// accessKey and secretKey are the signing credentials parsed by Init.
	accessKey, secretKey string
	// baseURL is the upstream endpoint, copied from the channel settings by Init.
	baseURL string
}
- func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
- a.baseURL = info.ChannelBaseUrl
- // apiKey format: "access_key|secret_key"
- keyParts := strings.Split(info.ApiKey, "|")
- if len(keyParts) == 2 {
- a.accessKey = strings.TrimSpace(keyParts[0])
- a.secretKey = strings.TrimSpace(keyParts[1])
- }
- }
- // ValidateRequestAndSetAction parses body, validates fields and sets default action.
- func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
- }
- // BuildRequestURL constructs the upstream URL.
- func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if isNewAPIRelay(info.ApiKey) {
- return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
- }
- return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
- }
- // BuildRequestHeader sets required headers.
- func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- if isNewAPIRelay(info.ApiKey) {
- req.Header.Set("Authorization", "Bearer "+info.ApiKey)
- } else {
- return a.signRequest(req, a.accessKey, a.secretKey)
- }
- return nil
- }
- func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- v, exists := c.Get("task_request")
- if !exists {
- return nil, fmt.Errorf("request not found in context")
- }
- req, ok := v.(relaycommon.TaskSubmitReq)
- if !ok {
- return nil, fmt.Errorf("invalid request type in context")
- }
- // 支持openai sdk的图片上传方式
- if mf, err := c.MultipartForm(); err == nil {
- if files, exists := mf.File["input_reference"]; exists && len(files) > 0 {
- if len(files) == 1 {
- info.Action = constant.TaskActionGenerate
- } else if len(files) > 1 {
- info.Action = constant.TaskActionFirstTailGenerate
- }
- // 将上传的文件转换为base64格式
- var images []string
- for _, fileHeader := range files {
- // 检查文件大小
- if fileHeader.Size > MaxFileSize {
- return nil, fmt.Errorf("文件 %s 大小超过限制,最大允许 %d MB", fileHeader.Filename, MaxFileSize/(1024*1024))
- }
- file, err := fileHeader.Open()
- if err != nil {
- continue
- }
- fileBytes, err := io.ReadAll(file)
- file.Close()
- if err != nil {
- continue
- }
- // 将文件内容转换为base64
- base64Str := base64.StdEncoding.EncodeToString(fileBytes)
- images = append(images, base64Str)
- }
- req.Images = images
- }
- }
- body, err := a.convertToRequestPayload(&req)
- if err != nil {
- return nil, errors.Wrap(err, "convert request payload failed")
- }
- data, err := json.Marshal(body)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
- }
- // DoRequest delegates to common helper.
- func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- return channel.DoTaskApiRequest(a, c, info, requestBody)
- }
- // DoResponse handles upstream response, returns taskID etc.
- func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- return
- }
- _ = resp.Body.Close()
- // Parse Jimeng response
- var jResp responsePayload
- if err := json.Unmarshal(responseBody, &jResp); err != nil {
- taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
- return
- }
- if jResp.Code != 10000 {
- taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
- return
- }
- ov := dto.NewOpenAIVideo()
- ov.ID = jResp.Data.TaskID
- ov.TaskID = jResp.Data.TaskID
- ov.CreatedAt = time.Now().Unix()
- ov.Model = info.OriginModelName
- c.JSON(http.StatusOK, ov)
- return jResp.Data.TaskID, responseBody, nil
- }
- // FetchTask fetch task status
- func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
- taskID, ok := body["task_id"].(string)
- if !ok {
- return nil, fmt.Errorf("invalid task_id")
- }
- uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
- if isNewAPIRelay(key) {
- uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
- }
- payload := map[string]string{
- "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
- "task_id": taskID,
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- return nil, errors.Wrap(err, "marshal fetch task payload failed")
- }
- req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Content-Type", "application/json")
- if isNewAPIRelay(key) {
- req.Header.Set("Authorization", "Bearer "+key)
- } else {
- keyParts := strings.Split(key, "|")
- if len(keyParts) != 2 {
- return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
- }
- accessKey := strings.TrimSpace(keyParts[0])
- secretKey := strings.TrimSpace(keyParts[1])
- if err := a.signRequest(req, accessKey, secretKey); err != nil {
- return nil, errors.Wrap(err, "sign request failed")
- }
- }
- client, err := service.GetHttpClientWithProxy(proxy)
- if err != nil {
- return nil, fmt.Errorf("new proxy http client failed: %w", err)
- }
- return client.Do(req)
- }
- func (a *TaskAdaptor) GetModelList() []string {
- return []string{"jimeng_vgfm_t2v_l20"}
- }
- func (a *TaskAdaptor) GetChannelName() string {
- return "jimeng"
- }
- func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
- var bodyBytes []byte
- var err error
- if req.Body != nil {
- bodyBytes, err = io.ReadAll(req.Body)
- if err != nil {
- return errors.Wrap(err, "read request body failed")
- }
- _ = req.Body.Close()
- req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
- } else {
- bodyBytes = []byte{}
- }
- payloadHash := sha256.Sum256(bodyBytes)
- hexPayloadHash := hex.EncodeToString(payloadHash[:])
- t := time.Now().UTC()
- xDate := t.Format("20060102T150405Z")
- shortDate := t.Format("20060102")
- req.Header.Set("Host", req.URL.Host)
- req.Header.Set("X-Date", xDate)
- req.Header.Set("X-Content-Sha256", hexPayloadHash)
- // Sort and encode query parameters to create canonical query string
- queryParams := req.URL.Query()
- sortedKeys := make([]string, 0, len(queryParams))
- for k := range queryParams {
- sortedKeys = append(sortedKeys, k)
- }
- sort.Strings(sortedKeys)
- var queryParts []string
- for _, k := range sortedKeys {
- values := queryParams[k]
- sort.Strings(values)
- for _, v := range values {
- queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
- }
- }
- canonicalQueryString := strings.Join(queryParts, "&")
- headersToSign := map[string]string{
- "host": req.URL.Host,
- "x-date": xDate,
- "x-content-sha256": hexPayloadHash,
- }
- if req.Header.Get("Content-Type") != "" {
- headersToSign["content-type"] = req.Header.Get("Content-Type")
- }
- var signedHeaderKeys []string
- for k := range headersToSign {
- signedHeaderKeys = append(signedHeaderKeys, k)
- }
- sort.Strings(signedHeaderKeys)
- var canonicalHeaders strings.Builder
- for _, k := range signedHeaderKeys {
- canonicalHeaders.WriteString(k)
- canonicalHeaders.WriteString(":")
- canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
- canonicalHeaders.WriteString("\n")
- }
- signedHeaders := strings.Join(signedHeaderKeys, ";")
- canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
- req.Method,
- req.URL.Path,
- canonicalQueryString,
- canonicalHeaders.String(),
- signedHeaders,
- hexPayloadHash,
- )
- hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
- hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
- region := "cn-north-1"
- serviceName := "cv"
- credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
- stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
- xDate,
- credentialScope,
- hexHashedCanonicalRequest,
- )
- kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
- kRegion := hmacSHA256(kDate, []byte(region))
- kService := hmacSHA256(kRegion, []byte(serviceName))
- kSigning := hmacSHA256(kService, []byte("request"))
- signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
- authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
- accessKey,
- credentialScope,
- signedHeaders,
- signature,
- )
- req.Header.Set("Authorization", authorization)
- return nil
- }
// hmacSHA256 returns the HMAC-SHA256 digest of data under key.
func hmacSHA256(key []byte, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	_, _ = mac.Write(data) // hash.Hash's Write never returns an error
	return mac.Sum(nil)
}
- func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
- r := requestPayload{
- ReqKey: req.Model,
- Prompt: req.Prompt,
- }
- switch req.Duration {
- case 10:
- r.Frames = 241 // 24*10+1 = 241
- default:
- r.Frames = 121 // 24*5+1 = 121
- }
- // Handle one-of image_urls or binary_data_base64
- if req.HasImage() {
- if strings.HasPrefix(req.Images[0], "http") {
- r.ImageUrls = req.Images
- } else {
- r.BinaryDataBase64 = req.Images
- }
- }
- metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
- return nil, errors.Wrap(err, "unmarshal metadata failed")
- }
- // 即梦视频3.0 ReqKey转换
- // https://www.volcengine.com/docs/85621/1792707
- if strings.Contains(r.ReqKey, "jimeng_v30") {
- if r.ReqKey == "jimeng_v30_pro" {
- // 3.0 pro只有固定的jimeng_ti2v_v30_pro
- r.ReqKey = "jimeng_ti2v_v30_pro"
- } else if len(req.Images) > 1 {
- // 多张图片:首尾帧生成
- r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1), "p")
- } else if len(req.Images) == 1 {
- // 单张图片:图生视频
- r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1), "p")
- } else {
- // 无图片:文生视频
- r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
- }
- }
- return &r, nil
- }
- func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
- resTask := responseTask{}
- if err := json.Unmarshal(respBody, &resTask); err != nil {
- return nil, errors.Wrap(err, "unmarshal task result failed")
- }
- taskResult := relaycommon.TaskInfo{}
- if resTask.Code == 10000 {
- taskResult.Code = 0
- } else {
- taskResult.Code = resTask.Code // todo uni code
- taskResult.Reason = resTask.Message
- taskResult.Status = model.TaskStatusFailure
- taskResult.Progress = "100%"
- }
- switch resTask.Data.Status {
- case "in_queue":
- taskResult.Status = model.TaskStatusQueued
- taskResult.Progress = "10%"
- case "done":
- taskResult.Status = model.TaskStatusSuccess
- taskResult.Progress = "100%"
- }
- taskResult.Url = resTask.Data.VideoUrl
- return &taskResult, nil
- }
- func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
- var jimengResp responseTask
- if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
- return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
- }
- openAIVideo := dto.NewOpenAIVideo()
- openAIVideo.ID = originTask.TaskID
- openAIVideo.Status = originTask.Status.ToVideoStatus()
- openAIVideo.SetProgressStr(originTask.Progress)
- openAIVideo.SetMetadata("url", jimengResp.Data.VideoUrl)
- openAIVideo.CreatedAt = originTask.CreatedAt
- openAIVideo.CompletedAt = originTask.UpdatedAt
- if jimengResp.Code != 10000 {
- openAIVideo.Error = &dto.OpenAIVideoError{
- Message: jimengResp.Message,
- Code: fmt.Sprintf("%d", jimengResp.Code),
- }
- }
- jsonData, _ := common.Marshal(openAIVideo)
- return jsonData, nil
- }
// isNewAPIRelay reports whether apiKey is a new-api relay key (one starting
// with "sk-") rather than an "access_key|secret_key" pair.
func isNewAPIRelay(apiKey string) bool {
	const relayKeyPrefix = "sk-"
	return len(apiKey) >= len(relayKeyPrefix) && apiKey[:len(relayKeyPrefix)] == relayKeyPrefix
}
|