Просмотр исходного кода

feat: add MiniMax Hailuo video

feitianbubu 3 месяцев назад
Родитель
Сommit
850a553958

+ 332 - 0
relay/channel/task/hailuo/adaptor.go

@@ -0,0 +1,332 @@
+package hailuo
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/relay/channel"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/service"
+)
+
+type TaskAdaptor struct {
+	ChannelType int
+	apiKey      string
+	baseURL     string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.ChannelBaseUrl
+	a.apiKey = info.ApiKey
+}
+
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
+	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s%s", a.baseURL, TextToVideoEndpoint), nil
+}
+
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+a.apiKey)
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
+	v, exists := c.Get("task_request")
+	if !exists {
+		return nil, fmt.Errorf("request not found in context")
+	}
+	req, ok := v.(relaycommon.TaskSubmitReq)
+	if !ok {
+		return nil, fmt.Errorf("invalid request type in context")
+	}
+
+	body, err := a.convertToRequestPayload(&req)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request payload failed")
+	}
+
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+
+	return bytes.NewReader(data), nil
+}
+
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	_ = resp.Body.Close()
+
+	var hResp TextToVideoResponse
+	if err := json.Unmarshal(responseBody, &hResp); err != nil {
+		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	if hResp.BaseResp.StatusCode != StatusSuccess {
+		taskErr = service.TaskErrorWrapper(
+			fmt.Errorf("hailuo api error: %s", hResp.BaseResp.StatusMsg),
+			strconv.Itoa(hResp.BaseResp.StatusCode),
+			http.StatusBadRequest,
+		)
+		return
+	}
+
+	ov := dto.NewOpenAIVideo()
+	ov.ID = hResp.TaskID
+	ov.TaskID = hResp.TaskID
+	ov.CreatedAt = time.Now().Unix()
+	ov.Model = info.OriginModelName
+
+	c.JSON(http.StatusOK, ov)
+	return hResp.TaskID, responseBody, nil
+}
+
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+
+	uri := fmt.Sprintf("%s%s?task_id=%s", baseUrl, QueryTaskEndpoint, taskID)
+
+	req, err := http.NewRequest(http.MethodGet, uri, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+key)
+
+	return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+	return ChannelName
+}
+
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*TextToVideoRequest, error) {
+	modelConfig := GetModelConfig(req.Model)
+	if !contains(ModelList, req.Model) {
+		return nil, fmt.Errorf("unsupported model: %s", req.Model)
+	}
+
+	duration := DefaultDuration
+	if req.Duration > 0 {
+		duration = req.Duration
+	}
+
+	if !containsInt(modelConfig.SupportedDurations, duration) {
+		return nil, fmt.Errorf("duration %d is not supported by model %s, supported durations: %v",
+			duration, req.Model, modelConfig.SupportedDurations)
+	}
+
+	resolution := modelConfig.DefaultResolution
+	if req.Size != "" {
+		resolution = a.parseResolutionFromSize(req.Size, modelConfig)
+	}
+
+	if !contains(modelConfig.SupportedResolutions, resolution) {
+		return nil, fmt.Errorf("resolution %s is not supported by model %s, supported resolutions: %v",
+			resolution, req.Model, modelConfig.SupportedResolutions)
+	}
+
+	hailuoReq := &TextToVideoRequest{
+		Model:      req.Model,
+		Prompt:     req.Prompt,
+		Duration:   &duration,
+		Resolution: resolution,
+	}
+
+	promptOptimizer := DefaultPromptOptimizer
+	hailuoReq.PromptOptimizer = &promptOptimizer
+
+	metadata := req.Metadata
+	if metadata != nil {
+		metadataBytes, err := json.Marshal(metadata)
+		if err != nil {
+			return nil, errors.Wrap(err, "marshal metadata failed")
+		}
+
+		var metadataMap map[string]interface{}
+		if err := json.Unmarshal(metadataBytes, &metadataMap); err != nil {
+			return nil, errors.Wrap(err, "unmarshal metadata failed")
+		}
+
+		if val, exists := metadataMap["prompt_optimizer"]; exists {
+			if boolVal, ok := val.(bool); ok {
+				hailuoReq.PromptOptimizer = &boolVal
+			}
+		}
+
+		if modelConfig.HasFastPretreatment {
+			if val, exists := metadataMap["fast_pretreatment"]; exists {
+				if boolVal, ok := val.(bool); ok {
+					hailuoReq.FastPretreatment = &boolVal
+				}
+			}
+		}
+
+		if val, exists := metadataMap["callback_url"]; exists {
+			if strVal, ok := val.(string); ok {
+				hailuoReq.CallbackURL = strVal
+			}
+		}
+
+		if val, exists := metadataMap["aigc_watermark"]; exists {
+			if boolVal, ok := val.(bool); ok {
+				hailuoReq.AigcWatermark = &boolVal
+			}
+		}
+	}
+
+	if req.HasImage() {
+		return nil, fmt.Errorf("image input is not supported by hailuo video generation")
+	}
+
+	return hailuoReq, nil
+}
+
+func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConfig) string {
+	switch {
+	case strings.Contains(size, "1080"):
+		return Resolution1080P
+	case strings.Contains(size, "768"):
+		return Resolution768P
+	case strings.Contains(size, "720"):
+		return Resolution720P
+	default:
+		return modelConfig.DefaultResolution
+	}
+}
+
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	resTask := QueryTaskResponse{}
+	if err := json.Unmarshal(respBody, &resTask); err != nil {
+		return nil, errors.Wrap(err, "unmarshal task result failed")
+	}
+
+	taskResult := relaycommon.TaskInfo{}
+
+	if resTask.BaseResp.StatusCode == StatusSuccess {
+		taskResult.Code = 0
+	} else {
+		taskResult.Code = resTask.BaseResp.StatusCode
+		taskResult.Reason = resTask.BaseResp.StatusMsg
+		taskResult.Status = model.TaskStatusFailure
+		taskResult.Progress = "100%"
+	}
+
+	switch resTask.Status {
+	case TaskStatusPreparing, TaskStatusQueueing, TaskStatusProcessing:
+		taskResult.Status = model.TaskStatusInProgress
+		taskResult.Progress = "30%"
+		if resTask.Status == TaskStatusProcessing {
+			taskResult.Progress = "50%"
+		}
+	case TaskStatusSuccess:
+		taskResult.Status = model.TaskStatusSuccess
+		taskResult.Progress = "100%"
+		if resTask.VideoURL != "" {
+			taskResult.Url = resTask.VideoURL
+		} else if resTask.FileID != "" {
+			taskResult.Url = fmt.Sprintf("https://api.minimaxi.com/v1/files/download?file_id=%s", resTask.FileID)
+		}
+	case TaskStatusFailed:
+		taskResult.Status = model.TaskStatusFailure
+		taskResult.Progress = "100%"
+		if taskResult.Reason == "" {
+			taskResult.Reason = "task failed"
+		}
+	default:
+		taskResult.Status = model.TaskStatusInProgress
+		taskResult.Progress = "30%"
+	}
+
+	return &taskResult, nil
+}
+
+func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
+	var hailuoResp QueryTaskResponse
+	if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
+		return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
+	}
+
+	openAIVideo := dto.NewOpenAIVideo()
+	openAIVideo.ID = originTask.TaskID
+	openAIVideo.Status = originTask.Status.ToVideoStatus()
+	openAIVideo.SetProgressStr(originTask.Progress)
+	openAIVideo.CreatedAt = originTask.CreatedAt
+	openAIVideo.CompletedAt = originTask.UpdatedAt
+
+	if hailuoResp.VideoURL != "" {
+		openAIVideo.SetMetadata("url", hailuoResp.VideoURL)
+	} else if hailuoResp.FileID != "" {
+		openAIVideo.SetMetadata("file_id", hailuoResp.FileID)
+		openAIVideo.SetMetadata("url", fmt.Sprintf("https://api.minimaxi.com/v1/files/download?file_id=%s", hailuoResp.FileID))
+	}
+
+	if hailuoResp.BaseResp.StatusCode != StatusSuccess {
+		openAIVideo.Error = &dto.OpenAIVideoError{
+			Message: hailuoResp.BaseResp.StatusMsg,
+			Code:    strconv.Itoa(hailuoResp.BaseResp.StatusCode),
+		}
+	}
+
+	jsonData, err := common.Marshal(openAIVideo)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshal openai video failed")
+	}
+
+	return jsonData, nil
+}
+
+func contains(slice []string, item string) bool {
+	for _, s := range slice {
+		if s == item {
+			return true
+		}
+	}
+	return false
+}
+
+func containsInt(slice []int, item int) bool {
+	for _, s := range slice {
+		if s == item {
+			return true
+		}
+	}
+	return false
+}

+ 47 - 0
relay/channel/task/hailuo/constants.go

@@ -0,0 +1,47 @@
+package hailuo
+
+const (
+	ChannelName = "hailuo-video"
+)
+
+var ModelList = []string{
+	"MiniMax-Hailuo-2.3",
+	"MiniMax-Hailuo-02",
+	"T2V-01-Director",
+	"T2V-01",
+}
+
+const (
+	TextToVideoEndpoint = "/v1/video_generation"
+	QueryTaskEndpoint   = "/v1/query/video_generation"
+)
+
+const (
+	StatusSuccess    = 0
+	StatusRateLimit  = 1002
+	StatusAuthFailed = 1004
+	StatusNoBalance  = 1008
+	StatusSensitive  = 1026
+	StatusParamError = 2013
+	StatusInvalidKey = 2049
+)
+
+const (
+	TaskStatusPreparing  = "Preparing"
+	TaskStatusQueueing   = "Queueing"
+	TaskStatusProcessing = "Processing"
+	TaskStatusSuccess    = "Success"
+	TaskStatusFailed     = "Fail"
+)
+
+const (
+	Resolution720P  = "720P"
+	Resolution768P  = "768P"
+	Resolution1080P = "1080P"
+)
+
+const (
+	DefaultDuration        = 6
+	DefaultResolution      = Resolution768P
+	DefaultPromptOptimizer = true
+)

+ 107 - 0
relay/channel/task/hailuo/models.go

@@ -0,0 +1,107 @@
+package hailuo
+
+type TextToVideoRequest struct {
+	Model            string `json:"model"`
+	Prompt           string `json:"prompt"`
+	PromptOptimizer  *bool  `json:"prompt_optimizer,omitempty"`
+	FastPretreatment *bool  `json:"fast_pretreatment,omitempty"`
+	Duration         *int   `json:"duration,omitempty"`
+	Resolution       string `json:"resolution,omitempty"`
+	CallbackURL      string `json:"callback_url,omitempty"`
+	AigcWatermark    *bool  `json:"aigc_watermark,omitempty"`
+}
+
+type TextToVideoResponse struct {
+	TaskID   string   `json:"task_id"`
+	BaseResp BaseResp `json:"base_resp"`
+}
+
+type BaseResp struct {
+	StatusCode int    `json:"status_code"`
+	StatusMsg  string `json:"status_msg"`
+}
+
+type QueryTaskRequest struct {
+	TaskID string `json:"task_id"`
+}
+
+type QueryTaskResponse struct {
+	TaskID   string   `json:"task_id"`
+	Status   string   `json:"status"`
+	FileID   string   `json:"file_id,omitempty"`
+	VideoURL string   `json:"video_url,omitempty"`
+	BaseResp BaseResp `json:"base_resp"`
+}
+
+type ErrorInfo struct {
+	StatusCode int    `json:"status_code"`
+	StatusMsg  string `json:"status_msg"`
+}
+
+type TaskStatusInfo struct {
+	TaskID    string `json:"task_id"`
+	Status    string `json:"status"`
+	FileID    string `json:"file_id,omitempty"`
+	VideoURL  string `json:"video_url,omitempty"`
+	ErrorCode int    `json:"error_code,omitempty"`
+	ErrorMsg  string `json:"error_msg,omitempty"`
+}
+
+type ModelConfig struct {
+	Name                 string
+	DefaultResolution    string
+	SupportedDurations   []int
+	SupportedResolutions []string
+	HasPromptOptimizer   bool
+	HasFastPretreatment  bool
+}
+
+func GetModelConfig(model string) ModelConfig {
+	configs := map[string]ModelConfig{
+		"MiniMax-Hailuo-2.3": {
+			Name:                 "MiniMax-Hailuo-2.3",
+			DefaultResolution:    Resolution768P,
+			SupportedDurations:   []int{6, 10},
+			SupportedResolutions: []string{Resolution768P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  true,
+		},
+		"MiniMax-Hailuo-02": {
+			Name:                 "MiniMax-Hailuo-02",
+			DefaultResolution:    Resolution768P,
+			SupportedDurations:   []int{6, 10},
+			SupportedResolutions: []string{Resolution768P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  true,
+		},
+		"T2V-01-Director": {
+			Name:                 "T2V-01-Director",
+			DefaultResolution:    Resolution768P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution768P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+		"T2V-01": {
+			Name:                 "T2V-01",
+			DefaultResolution:    Resolution720P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution720P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+	}
+
+	if config, exists := configs[model]; exists {
+		return config
+	}
+
+	return ModelConfig{
+		Name:                 model,
+		DefaultResolution:    Resolution720P,
+		SupportedDurations:   []int{6},
+		SupportedResolutions: []string{Resolution720P},
+		HasPromptOptimizer:   true,
+		HasFastPretreatment:  false,
+	}
+}

+ 3 - 0
relay/relay_adaptor.go

@@ -32,6 +32,7 @@ import (
 	taskali "github.com/QuantumNous/new-api/relay/channel/task/ali"
 	taskdoubao "github.com/QuantumNous/new-api/relay/channel/task/doubao"
 	taskGemini "github.com/QuantumNous/new-api/relay/channel/task/gemini"
+	"github.com/QuantumNous/new-api/relay/channel/task/hailuo"
 	taskjimeng "github.com/QuantumNous/new-api/relay/channel/task/jimeng"
 	"github.com/QuantumNous/new-api/relay/channel/task/kling"
 	tasksora "github.com/QuantumNous/new-api/relay/channel/task/sora"
@@ -153,6 +154,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
 			return &tasksora.TaskAdaptor{}
 		case constant.ChannelTypeGemini:
 			return &taskGemini.TaskAdaptor{}
+		case constant.ChannelTypeMiniMax:
+			return &hailuo.TaskAdaptor{}
 		}
 	}
 	return nil