|
@@ -0,0 +1,379 @@
|
|
|
|
|
+package jimeng
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "bytes"
|
|
|
|
|
+ "crypto/hmac"
|
|
|
|
|
+ "crypto/sha256"
|
|
|
|
|
+ "encoding/hex"
|
|
|
|
|
+ "encoding/json"
|
|
|
|
|
+ "fmt"
|
|
|
|
|
+ "io"
|
|
|
|
|
+ "net/http"
|
|
|
|
|
+ "net/url"
|
|
|
|
|
+ "one-api/model"
|
|
|
|
|
+ "sort"
|
|
|
|
|
+ "strings"
|
|
|
|
|
+ "time"
|
|
|
|
|
+
|
|
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
|
|
+ "github.com/pkg/errors"
|
|
|
|
|
+
|
|
|
|
|
+ "one-api/common"
|
|
|
|
|
+ "one-api/dto"
|
|
|
|
|
+ "one-api/relay/channel"
|
|
|
|
|
+ relaycommon "one-api/relay/common"
|
|
|
|
|
+ "one-api/service"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+// ============================
|
|
|
|
|
+// Request / Response structures
|
|
|
|
|
+// ============================
|
|
|
|
|
+
|
|
|
|
|
+type requestPayload struct {
|
|
|
|
|
+ ReqKey string `json:"req_key"`
|
|
|
|
|
+ BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
|
|
|
|
+ ImageUrls []string `json:"image_urls,omitempty"`
|
|
|
|
|
+ Prompt string `json:"prompt,omitempty"`
|
|
|
|
|
+ Seed int64 `json:"seed"`
|
|
|
|
|
+ AspectRatio string `json:"aspect_ratio"`
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+type responsePayload struct {
|
|
|
|
|
+ Code int `json:"code"`
|
|
|
|
|
+ Message string `json:"message"`
|
|
|
|
|
+ RequestId string `json:"request_id"`
|
|
|
|
|
+ Data struct {
|
|
|
|
|
+ TaskID string `json:"task_id"`
|
|
|
|
|
+ } `json:"data"`
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+type responseTask struct {
|
|
|
|
|
+ Code int `json:"code"`
|
|
|
|
|
+ Data struct {
|
|
|
|
|
+ BinaryDataBase64 []interface{} `json:"binary_data_base64"`
|
|
|
|
|
+ ImageUrls interface{} `json:"image_urls"`
|
|
|
|
|
+ RespData string `json:"resp_data"`
|
|
|
|
|
+ Status string `json:"status"`
|
|
|
|
|
+ VideoUrl string `json:"video_url"`
|
|
|
|
|
+ } `json:"data"`
|
|
|
|
|
+ Message string `json:"message"`
|
|
|
|
|
+ RequestId string `json:"request_id"`
|
|
|
|
|
+ Status int `json:"status"`
|
|
|
|
|
+ TimeElapsed string `json:"time_elapsed"`
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ============================
|
|
|
|
|
+// Adaptor implementation
|
|
|
|
|
+// ============================
|
|
|
|
|
+
|
|
|
|
|
+type TaskAdaptor struct {
|
|
|
|
|
+ ChannelType int
|
|
|
|
|
+ accessKey string
|
|
|
|
|
+ secretKey string
|
|
|
|
|
+ baseURL string
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
|
|
|
|
+ a.ChannelType = info.ChannelType
|
|
|
|
|
+ a.baseURL = info.BaseUrl
|
|
|
|
|
+
|
|
|
|
|
+ // apiKey format: "access_key,secret_key"
|
|
|
|
|
+ keyParts := strings.Split(info.ApiKey, ",")
|
|
|
|
|
+ if len(keyParts) == 2 {
|
|
|
|
|
+ a.accessKey = strings.TrimSpace(keyParts[0])
|
|
|
|
|
+ a.secretKey = strings.TrimSpace(keyParts[1])
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
|
|
|
|
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
|
|
|
|
+ // Accept only POST /v1/video/generations as "generate" action.
|
|
|
|
|
+ action := "generate"
|
|
|
|
|
+ info.Action = action
|
|
|
|
|
+
|
|
|
|
|
+ req := relaycommon.TaskSubmitReq{}
|
|
|
|
|
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ if strings.TrimSpace(req.Prompt) == "" {
|
|
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Store into context for later usage
|
|
|
|
|
+ c.Set("task_request", req)
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// BuildRequestURL constructs the upstream URL.
|
|
|
|
|
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
|
|
|
|
+ return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// BuildRequestHeader sets required headers.
|
|
|
|
|
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
|
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
+ req.Header.Set("Accept", "application/json")
|
|
|
|
|
+ return a.signRequest(req, a.accessKey, a.secretKey)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// BuildRequestBody converts request into Jimeng specific format.
|
|
|
|
|
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
|
|
|
|
+ v, exists := c.Get("task_request")
|
|
|
|
|
+ if !exists {
|
|
|
|
|
+ return nil, fmt.Errorf("request not found in context")
|
|
|
|
|
+ }
|
|
|
|
|
+ req := v.(relaycommon.TaskSubmitReq)
|
|
|
|
|
+
|
|
|
|
|
+ body, err := a.convertToRequestPayload(&req)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, errors.Wrap(err, "convert request payload failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ data, err := json.Marshal(body)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ return bytes.NewReader(data), nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// DoRequest delegates to common helper.
|
|
|
|
|
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
|
|
|
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// DoResponse handles upstream response, returns taskID etc.
|
|
|
|
|
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
|
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ _ = resp.Body.Close()
|
|
|
|
|
+
|
|
|
|
|
+ // Parse Jimeng response
|
|
|
|
|
+ var jResp responsePayload
|
|
|
|
|
+ if err := json.Unmarshal(responseBody, &jResp); err != nil {
|
|
|
|
|
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if jResp.Code != 10000 {
|
|
|
|
|
+ taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
|
|
|
|
|
+ return jResp.Data.TaskID, responseBody, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// FetchTask fetch task status
|
|
|
|
|
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
|
|
|
|
+ taskID, ok := body["task_id"].(string)
|
|
|
|
|
+ if !ok {
|
|
|
|
|
+ return nil, fmt.Errorf("invalid task_id")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
|
|
|
|
+ payload := map[string]string{
|
|
|
|
|
+ "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
|
|
|
|
+ "task_id": taskID,
|
|
|
|
|
+ }
|
|
|
|
|
+ payloadBytes, err := json.Marshal(payload)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, errors.Wrap(err, "marshal fetch task payload failed")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ req.Header.Set("Accept", "application/json")
|
|
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
+
|
|
|
|
|
+ keyParts := strings.Split(key, ",")
|
|
|
|
|
+ if len(keyParts) != 2 {
|
|
|
|
|
+ return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak,sk'")
|
|
|
|
|
+ }
|
|
|
|
|
+ accessKey := strings.TrimSpace(keyParts[0])
|
|
|
|
|
+ secretKey := strings.TrimSpace(keyParts[1])
|
|
|
|
|
+
|
|
|
|
|
+ if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
|
|
|
|
+ return nil, errors.Wrap(err, "sign request failed")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return service.GetHttpClient().Do(req)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *TaskAdaptor) GetModelList() []string {
|
|
|
|
|
+ return []string{"jimeng_vgfm_t2v_l20"}
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *TaskAdaptor) GetChannelName() string {
|
|
|
|
|
+ return "jimeng"
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
|
|
|
|
|
+ var bodyBytes []byte
|
|
|
|
|
+ var err error
|
|
|
|
|
+
|
|
|
|
|
+ if req.Body != nil {
|
|
|
|
|
+ bodyBytes, err = io.ReadAll(req.Body)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return errors.Wrap(err, "read request body failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ _ = req.Body.Close()
|
|
|
|
|
+ req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
|
|
|
|
|
+ } else {
|
|
|
|
|
+ bodyBytes = []byte{}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ payloadHash := sha256.Sum256(bodyBytes)
|
|
|
|
|
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
|
|
|
|
|
+
|
|
|
|
|
+ t := time.Now().UTC()
|
|
|
|
|
+ xDate := t.Format("20060102T150405Z")
|
|
|
|
|
+ shortDate := t.Format("20060102")
|
|
|
|
|
+
|
|
|
|
|
+ req.Header.Set("Host", req.URL.Host)
|
|
|
|
|
+ req.Header.Set("X-Date", xDate)
|
|
|
|
|
+ req.Header.Set("X-Content-Sha256", hexPayloadHash)
|
|
|
|
|
+
|
|
|
|
|
+ // Sort and encode query parameters to create canonical query string
|
|
|
|
|
+ queryParams := req.URL.Query()
|
|
|
|
|
+ sortedKeys := make([]string, 0, len(queryParams))
|
|
|
|
|
+ for k := range queryParams {
|
|
|
|
|
+ sortedKeys = append(sortedKeys, k)
|
|
|
|
|
+ }
|
|
|
|
|
+ sort.Strings(sortedKeys)
|
|
|
|
|
+ var queryParts []string
|
|
|
|
|
+ for _, k := range sortedKeys {
|
|
|
|
|
+ values := queryParams[k]
|
|
|
|
|
+ sort.Strings(values)
|
|
|
|
|
+ for _, v := range values {
|
|
|
|
|
+ queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ canonicalQueryString := strings.Join(queryParts, "&")
|
|
|
|
|
+
|
|
|
|
|
+ headersToSign := map[string]string{
|
|
|
|
|
+ "host": req.URL.Host,
|
|
|
|
|
+ "x-date": xDate,
|
|
|
|
|
+ "x-content-sha256": hexPayloadHash,
|
|
|
|
|
+ }
|
|
|
|
|
+ if req.Header.Get("Content-Type") != "" {
|
|
|
|
|
+ headersToSign["content-type"] = req.Header.Get("Content-Type")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ var signedHeaderKeys []string
|
|
|
|
|
+ for k := range headersToSign {
|
|
|
|
|
+ signedHeaderKeys = append(signedHeaderKeys, k)
|
|
|
|
|
+ }
|
|
|
|
|
+ sort.Strings(signedHeaderKeys)
|
|
|
|
|
+
|
|
|
|
|
+ var canonicalHeaders strings.Builder
|
|
|
|
|
+ for _, k := range signedHeaderKeys {
|
|
|
|
|
+ canonicalHeaders.WriteString(k)
|
|
|
|
|
+ canonicalHeaders.WriteString(":")
|
|
|
|
|
+ canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
|
|
|
|
|
+ canonicalHeaders.WriteString("\n")
|
|
|
|
|
+ }
|
|
|
|
|
+ signedHeaders := strings.Join(signedHeaderKeys, ";")
|
|
|
|
|
+
|
|
|
|
|
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
|
|
|
|
+ req.Method,
|
|
|
|
|
+ req.URL.Path,
|
|
|
|
|
+ canonicalQueryString,
|
|
|
|
|
+ canonicalHeaders.String(),
|
|
|
|
|
+ signedHeaders,
|
|
|
|
|
+ hexPayloadHash,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
|
|
|
|
|
+ hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
|
|
|
|
|
+
|
|
|
|
|
+ region := "cn-north-1"
|
|
|
|
|
+ serviceName := "cv"
|
|
|
|
|
+ credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
|
|
|
|
|
+ stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
|
|
|
|
|
+ xDate,
|
|
|
|
|
+ credentialScope,
|
|
|
|
|
+ hexHashedCanonicalRequest,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
|
|
|
|
|
+ kRegion := hmacSHA256(kDate, []byte(region))
|
|
|
|
|
+ kService := hmacSHA256(kRegion, []byte(serviceName))
|
|
|
|
|
+ kSigning := hmacSHA256(kService, []byte("request"))
|
|
|
|
|
+ signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
|
|
|
|
|
+
|
|
|
|
|
+ authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
|
|
|
|
+ accessKey,
|
|
|
|
|
+ credentialScope,
|
|
|
|
|
+ signedHeaders,
|
|
|
|
|
+ signature,
|
|
|
|
|
+ )
|
|
|
|
|
+ req.Header.Set("Authorization", authorization)
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func hmacSHA256(key []byte, data []byte) []byte {
|
|
|
|
|
+ h := hmac.New(sha256.New, key)
|
|
|
|
|
+ h.Write(data)
|
|
|
|
|
+ return h.Sum(nil)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
|
|
|
|
+ r := requestPayload{
|
|
|
|
|
+ ReqKey: "jimeng_vgfm_i2v_l20",
|
|
|
|
|
+ Prompt: req.Prompt,
|
|
|
|
|
+ AspectRatio: "16:9", // Default aspect ratio
|
|
|
|
|
+ Seed: -1, // Default to random
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Handle one-of image_urls or binary_data_base64
|
|
|
|
|
+ if req.Image != "" {
|
|
|
|
|
+ if strings.HasPrefix(req.Image, "http") {
|
|
|
|
|
+ r.ImageUrls = []string{req.Image}
|
|
|
|
|
+ } else {
|
|
|
|
|
+ r.BinaryDataBase64 = []string{req.Image}
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ metadata := req.Metadata
|
|
|
|
|
+ medaBytes, err := json.Marshal(metadata)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ err = json.Unmarshal(medaBytes, &r)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, errors.Wrap(err, "unmarshal metadata failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ return &r, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
|
|
|
|
+ resTask := responseTask{}
|
|
|
|
|
+ if err := json.Unmarshal(respBody, &resTask); err != nil {
|
|
|
|
|
+ return nil, errors.Wrap(err, "unmarshal task result failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ taskResult := relaycommon.TaskInfo{}
|
|
|
|
|
+ if resTask.Code == 10000 {
|
|
|
|
|
+ taskResult.Code = 0
|
|
|
|
|
+ } else {
|
|
|
|
|
+ taskResult.Code = resTask.Code // todo uni code
|
|
|
|
|
+ taskResult.Reason = resTask.Message
|
|
|
|
|
+ taskResult.Status = model.TaskStatusFailure
|
|
|
|
|
+ taskResult.Progress = "100%"
|
|
|
|
|
+ }
|
|
|
|
|
+ switch resTask.Data.Status {
|
|
|
|
|
+ case "in_queue":
|
|
|
|
|
+ taskResult.Status = model.TaskStatusQueued
|
|
|
|
|
+ taskResult.Progress = "10%"
|
|
|
|
|
+ case "done":
|
|
|
|
|
+ taskResult.Status = model.TaskStatusSuccess
|
|
|
|
|
+ taskResult.Progress = "100%"
|
|
|
|
|
+ }
|
|
|
|
|
+ taskResult.Url = resTask.Data.VideoUrl
|
|
|
|
|
+ return &taskResult, nil
|
|
|
|
|
+}
|