Jelajahi Sumber

Merge pull request #2412 from seefs001/pr-2372

feat: add openai video remix endpoint
Seefs 2 bulan lalu
induk
melakukan
4e69c98b42

+ 1 - 0
constant/task.go

@@ -15,6 +15,7 @@ const (
 	TaskActionTextGenerate      = "textGenerate"
 	TaskActionFirstTailGenerate = "firstTailGenerate"
 	TaskActionReferenceGenerate = "referenceGenerate"
+	TaskActionRemix             = "remixGenerate"
 )
 
 var SunoModel2Action = map[string]string{

+ 4 - 0
middleware/distributor.go

@@ -181,6 +181,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
+	} else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") {
+		relayMode := relayconstant.RelayModeVideoSubmit
+		c.Set("relay_mode", relayMode)
+		shouldSelectChannel = false
 	} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
 		//curl https://api.openai.com/v1/videos \
 		//  -H "Authorization: Bearer $OPENAI_API_KEY" \

+ 21 - 0
relay/channel/task/sora/adaptor.go

@@ -5,8 +5,10 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
@@ -67,11 +69,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 	a.apiKey = info.ApiKey
 }
 
+func validateRemixRequest(c *gin.Context) *dto.TaskError {
+	var req struct {
+		Prompt string `json:"prompt"`
+	}
+	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+		return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+	}
+	if strings.TrimSpace(req.Prompt) == "" {
+		return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
+	}
+	return nil
+}
+
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
+	if info.Action == constant.TaskActionRemix {
+		return validateRemixRequest(c)
+	}
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.Action == constant.TaskActionRemix {
+		return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
+	}
 	return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
 }
 

+ 87 - 28
relay/relay_task.go

@@ -32,7 +32,94 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 	if info.TaskRelayInfo == nil {
 		info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
 	}
+	path := c.Request.URL.Path
+	if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
+		info.Action = constant.TaskActionRemix
+	}
+
+	// 提取 remix 任务的 video_id
+	if info.Action == constant.TaskActionRemix {
+		videoID := c.Param("video_id")
+		if strings.TrimSpace(videoID) == "" {
+			return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
+		}
+		info.OriginTaskID = videoID
+	}
+
 	platform := constant.TaskPlatform(c.GetString("platform"))
+
+	// 获取原始任务信息
+	if info.OriginTaskID != "" {
+		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
+		if err != nil {
+			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+			return
+		}
+		if !exist {
+			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+			return
+		}
+		if info.OriginModelName == "" {
+			if originTask.Properties.OriginModelName != "" {
+				info.OriginModelName = originTask.Properties.OriginModelName
+			} else if originTask.Properties.UpstreamModelName != "" {
+				info.OriginModelName = originTask.Properties.UpstreamModelName
+			} else {
+				var taskData map[string]interface{}
+				_ = json.Unmarshal(originTask.Data, &taskData)
+				if m, ok := taskData["model"].(string); ok && m != "" {
+					info.OriginModelName = m
+					platform = originTask.Platform
+				}
+			}
+		}
+		if originTask.ChannelId != info.ChannelId {
+			channel, err := model.GetChannelById(originTask.ChannelId, true)
+			if err != nil {
+				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+				return
+			}
+			if channel.Status != common.ChannelStatusEnabled {
+				taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
+				return
+			}
+			key, _, newAPIError := channel.GetNextEnabledKey()
+			if newAPIError != nil {
+				taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
+				return
+			}
+			common.SetContextKey(c, constant.ContextKeyChannelKey, key)
+			common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
+			common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
+			common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
+
+			info.ChannelBaseUrl = channel.GetBaseURL()
+			info.ChannelId = originTask.ChannelId
+			info.ChannelType = channel.Type
+			info.ApiKey = key
+			platform = originTask.Platform
+		}
+
+		// 使用原始任务的参数
+		if info.Action == constant.TaskActionRemix {
+			var taskData map[string]interface{}
+			_ = json.Unmarshal(originTask.Data, &taskData)
+			secondsStr, _ := taskData["seconds"].(string)
+			seconds, _ := strconv.Atoi(secondsStr)
+			if seconds <= 0 {
+				seconds = 4
+			}
+			sizeStr, _ := taskData["size"].(string)
+			if info.PriceData.OtherRatios == nil {
+				info.PriceData.OtherRatios = map[string]float64{}
+			}
+			info.PriceData.OtherRatios["seconds"] = float64(seconds)
+			info.PriceData.OtherRatios["size"] = 1
+			if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
+				info.PriceData.OtherRatios["size"] = 1.666667
+			}
+		}
+	}
 	if platform == "" {
 		platform = GetTaskPlatform(c)
 	}
@@ -94,34 +181,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 		return
 	}
 
-	if info.OriginTaskID != "" {
-		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
-		if err != nil {
-			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
-			return
-		}
-		if !exist {
-			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
-			return
-		}
-		if originTask.ChannelId != info.ChannelId {
-			channel, err := model.GetChannelById(originTask.ChannelId, true)
-			if err != nil {
-				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
-				return
-			}
-			if channel.Status != common.ChannelStatusEnabled {
-				return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
-			}
-			c.Set("base_url", channel.GetBaseURL())
-			c.Set("channel_id", originTask.ChannelId)
-			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-
-			info.ChannelBaseUrl = channel.GetBaseURL()
-			info.ChannelId = originTask.ChannelId
-		}
-	}
-
 	// build body
 	requestBody, err := adaptor.BuildRequestBody(c, info)
 	if err != nil {

+ 1 - 0
router/video-router.go

@@ -14,6 +14,7 @@ func SetVideoRouter(router *gin.Engine) {
 		videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
 		videoV1Router.POST("/video/generations", controller.RelayTask)
 		videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+		videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
 	}
 	// openai compatible API video routes
 	// docs: https://platform.openai.com/docs/api-reference/videos/create

+ 9 - 1
web/src/components/table/task-logs/TaskLogsColumnDefs.jsx

@@ -39,6 +39,7 @@ import {
   TASK_ACTION_GENERATE,
   TASK_ACTION_REFERENCE_GENERATE,
   TASK_ACTION_TEXT_GENERATE,
+  TASK_ACTION_REMIX_GENERATE,
 } from '../../../constants/common.constant';
 import { CHANNEL_OPTIONS } from '../../../constants/channel.constants';
 
@@ -125,6 +126,12 @@ const renderType = (type, t) => {
           {t('参照生视频')}
         </Tag>
       );
+    case TASK_ACTION_REMIX_GENERATE:
+      return (
+        <Tag color='blue' shape='circle' prefixIcon={<Sparkles size={14} />}>
+          {t('视频Remix')}
+        </Tag>
+      );
     default:
       return (
         <Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
@@ -359,7 +366,8 @@ export const getTaskLogsColumns = ({
           record.action === TASK_ACTION_GENERATE ||
           record.action === TASK_ACTION_TEXT_GENERATE ||
           record.action === TASK_ACTION_FIRST_TAIL_GENERATE ||
-          record.action === TASK_ACTION_REFERENCE_GENERATE;
+          record.action === TASK_ACTION_REFERENCE_GENERATE ||
+          record.action === TASK_ACTION_REMIX_GENERATE;
         const isSuccess = record.status === 'SUCCESS';
         const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
         if (isSuccess && isVideoTask && isUrl) {

+ 1 - 0
web/src/constants/common.constant.js

@@ -42,3 +42,4 @@ export const TASK_ACTION_GENERATE = 'generate';
 export const TASK_ACTION_TEXT_GENERATE = 'textGenerate';
 export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate';
 export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate';
+export const TASK_ACTION_REMIX_GENERATE = 'remixGenerate';

+ 1 - 0
web/src/i18n/locales/en.json

@@ -548,6 +548,7 @@
     "参数值": "Parameter value",
     "参数覆盖": "Parameters override",
     "参照生视频": "Reference video generation",
+    "视频Remix": "Video remix",
     "友情链接": "Friendly links",
     "发布日期": "Publish Date",
     "发布时间": "Publish Time",

+ 1 - 0
web/src/i18n/locales/fr.json

@@ -551,6 +551,7 @@
     "参数值": "Valeur du paramètre",
     "参数覆盖": "Remplacement des paramètres",
     "参照生视频": "Générer une vidéo par référence",
+    "视频Remix": "Remix vidéo",
     "友情链接": "Liens amicaux",
     "发布日期": "Date de publication",
     "发布时间": "Heure de publication",

+ 1 - 0
web/src/i18n/locales/ja.json

@@ -510,6 +510,7 @@
     "参数值": "パラメータ値",
     "参数覆盖": "パラメータの上書き",
     "参照生视频": "参照動画生成",
+    "视频Remix": "動画リミックス",
     "友情链接": "関連リンク",
     "发布日期": "公開日",
     "发布时间": "公開日時",

+ 1 - 0
web/src/i18n/locales/ru.json

@@ -555,6 +555,7 @@
     "参数值": "Значение параметра",
     "参数覆盖": "Переопределение параметров",
     "参照生视频": "Ссылка на генерацию видео",
+    "视频Remix": "Видео ремикс",
     "友情链接": "Дружественные ссылки",
     "发布日期": "Дата публикации",
     "发布时间": "Время публикации",

+ 1 - 0
web/src/i18n/locales/vi.json

@@ -510,6 +510,7 @@
     "参数值": "Giá trị tham số",
     "参数覆盖": "Ghi đè tham số",
     "参照生视频": "Tạo video tham chiếu",
+    "视频Remix": "Remix video",
     "友情链接": "Liên kết thân thiện",
     "发布日期": "Ngày xuất bản",
     "发布时间": "Thời gian xuất bản",

+ 1 - 0
web/src/i18n/locales/zh.json

@@ -543,6 +543,7 @@
     "参数值": "参数值",
     "参数覆盖": "参数覆盖",
     "参照生视频": "参照生视频",
+    "视频Remix": "视频 Remix",
     "友情链接": "友情链接",
     "发布日期": "发布日期",
     "发布时间": "发布时间",