Browse Source

feat: add vidu video channel

feitianbubu 7 months ago
parent
commit
352da66bd1

+ 2 - 0
constant/channel.go

@@ -49,6 +49,7 @@ const (
 	ChannelTypeCoze           = 49
 	ChannelTypeKling          = 50
 	ChannelTypeJimeng         = 51
+	ChannelTypeVidu           = 52
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 
 )
@@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{
 	"https://api.coze.cn",                       //49
 	"https://api.klingai.com",                   //50
 	"https://visual.volcengineapi.com",          //51
+	"https://api.vidu.cn",                       //52
 }

+ 6 - 0
controller/channel-test.go

@@ -69,6 +69,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 			newAPIError: nil,
 		}
 	}
+	if channel.Type == constant.ChannelTypeVidu {
+		return testResult{
+			localErr:    errors.New("vidu channel test is not supported"),
+			newAPIError: nil,
+		}
+	}
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
 

+ 1 - 1
controller/task_video.go

@@ -83,7 +83,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 	taskResult := &relaycommon.TaskInfo{}
 	// try parse as New API response format
 	var responseItems dto.TaskResponse[model.Task]
-	if err = json.Unmarshal(responseBody, &responseItems); err == nil {
+	if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
 		t := responseItems.Data
 		taskResult.TaskID = t.TaskID
 		taskResult.Status = string(t.Status)

+ 285 - 0
relay/channel/task/vidu/adaptor.go

@@ -0,0 +1,285 @@
+package vidu
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/model"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/pkg/errors"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type SubmitReq struct {
+	Prompt   string                 `json:"prompt"`
+	Model    string                 `json:"model,omitempty"`
+	Mode     string                 `json:"mode,omitempty"`
+	Image    string                 `json:"image,omitempty"`
+	Size     string                 `json:"size,omitempty"`
+	Duration int                    `json:"duration,omitempty"`
+	Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+type requestPayload struct {
+	Model             string   `json:"model"`
+	Images            []string `json:"images"`
+	Prompt            string   `json:"prompt,omitempty"`
+	Duration          int      `json:"duration,omitempty"`
+	Seed              int      `json:"seed,omitempty"`
+	Resolution        string   `json:"resolution,omitempty"`
+	MovementAmplitude string   `json:"movement_amplitude,omitempty"`
+	Bgm               bool     `json:"bgm,omitempty"`
+	Payload           string   `json:"payload,omitempty"`
+	CallbackUrl       string   `json:"callback_url,omitempty"`
+}
+
+type responsePayload struct {
+	TaskId            string   `json:"task_id"`
+	State             string   `json:"state"`
+	Model             string   `json:"model"`
+	Images            []string `json:"images"`
+	Prompt            string   `json:"prompt"`
+	Duration          int      `json:"duration"`
+	Seed              int      `json:"seed"`
+	Resolution        string   `json:"resolution"`
+	Bgm               bool     `json:"bgm"`
+	MovementAmplitude string   `json:"movement_amplitude"`
+	Payload           string   `json:"payload"`
+	CreatedAt         string   `json:"created_at"`
+}
+
+type taskResultResponse struct {
+	State     string     `json:"state"`
+	ErrCode   string     `json:"err_code"`
+	Credits   int        `json:"credits"`
+	Payload   string     `json:"payload"`
+	Creations []creation `json:"creations"`
+}
+
+type creation struct {
+	ID       string `json:"id"`
+	URL      string `json:"url"`
+	CoverURL string `json:"cover_url"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+	ChannelType int
+	baseURL     string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.BaseUrl
+}
+
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
+	var req SubmitReq
+	if err := c.ShouldBindJSON(&req); err != nil {
+		return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
+	}
+
+	if req.Prompt == "" {
+		return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
+	}
+
+	if req.Image != "" {
+		info.Action = constant.TaskActionGenerate
+	} else {
+		info.Action = constant.TaskActionTextGenerate
+	}
+
+	c.Set("task_request", req)
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
+	v, exists := c.Get("task_request")
+	if !exists {
+		return nil, fmt.Errorf("request not found in context")
+	}
+	req := v.(SubmitReq)
+
+	body, err := a.convertToRequestPayload(&req)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(body.Images) == 0 {
+		c.Set("action", constant.TaskActionTextGenerate)
+	}
+
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(data), nil
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+	var path string
+	switch info.Action {
+	case constant.TaskActionGenerate:
+		path = "/img2video"
+	default:
+		path = "/text2video"
+	}
+	return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
+}
+
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Token "+info.ApiKey)
+	return nil
+}
+
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	if action := c.GetString("action"); action != "" {
+		info.Action = action
+	}
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	var vResp responsePayload
+	err = json.Unmarshal(responseBody, &vResp)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
+		return
+	}
+
+	if vResp.State == "failed" {
+		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
+		return
+	}
+
+	c.JSON(http.StatusOK, vResp)
+	return vResp.TaskId, responseBody, nil
+}
+
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+
+	url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
+
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Token "+key)
+
+	return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+	return []string{"viduq1", "vidu2.0", "vidu1.5"}
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+	return "vidu"
+}
+
+// ============================
+// helpers
+// ============================
+
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+	var images []string
+	if req.Image != "" {
+		images = []string{req.Image}
+	}
+
+	r := requestPayload{
+		Model:             defaultString(req.Model, "viduq1"),
+		Images:            images,
+		Prompt:            req.Prompt,
+		Duration:          defaultInt(req.Duration, 5),
+		Resolution:        defaultString(req.Size, "1080p"),
+		MovementAmplitude: "auto",
+		Bgm:               false,
+	}
+	metadata := req.Metadata
+	medaBytes, err := json.Marshal(metadata)
+	if err != nil {
+		return nil, errors.Wrap(err, "metadata marshal metadata failed")
+	}
+	err = json.Unmarshal(medaBytes, &r)
+	if err != nil {
+		return nil, errors.Wrap(err, "unmarshal metadata failed")
+	}
+	return &r, nil
+}
+
+func defaultString(value, defaultValue string) string {
+	if value == "" {
+		return defaultValue
+	}
+	return value
+}
+
+func defaultInt(value, defaultValue int) int {
+	if value == 0 {
+		return defaultValue
+	}
+	return value
+}
+
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	taskInfo := &relaycommon.TaskInfo{}
+
+	var taskResp taskResultResponse
+	err := json.Unmarshal(respBody, &taskResp)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to unmarshal response body")
+	}
+
+	state := taskResp.State
+	switch state {
+	case "created", "queueing":
+		taskInfo.Status = model.TaskStatusSubmitted
+	case "processing":
+		taskInfo.Status = model.TaskStatusInProgress
+	case "success":
+		taskInfo.Status = model.TaskStatusSuccess
+		if len(taskResp.Creations) > 0 {
+			taskInfo.Url = taskResp.Creations[0].URL
+		}
+	case "failed":
+		taskInfo.Status = model.TaskStatusFailure
+		if taskResp.ErrCode != "" {
+			taskInfo.Reason = taskResp.ErrCode
+		}
+	default:
+		return nil, fmt.Errorf("unknown task state: %s", state)
+	}
+
+	return taskInfo, nil
+}

+ 3 - 0
relay/relay_adaptor.go

@@ -27,6 +27,7 @@ import (
 	taskjimeng "one-api/relay/channel/task/jimeng"
 	"one-api/relay/channel/task/kling"
 	"one-api/relay/channel/task/suno"
+	taskVidu "one-api/relay/channel/task/vidu"
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/vertex"
 	"one-api/relay/channel/volcengine"
@@ -122,6 +123,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
 			return &kling.TaskAdaptor{}
 		case constant.ChannelTypeJimeng:
 			return &taskjimeng.TaskAdaptor{}
+		case constant.ChannelTypeVidu:
+			return &taskVidu.TaskAdaptor{}
 		}
 	}
 	return nil

+ 5 - 0
web/src/constants/channel.constants.js

@@ -154,6 +154,11 @@ export const CHANNEL_OPTIONS = [
     color: 'blue',
     label: '即梦',
   },
+  {
+    value: 52,
+    color: 'purple',
+    label: 'Vidu',
+  },
 ];
 
 export const MODEL_TABLE_PAGE_SIZE = 10;