Просмотр исходного кода

Merge pull request #2004 from feitianbubu/pr/openai-sdk-kling

支持可灵使用openai sdk生成视频
Calcium-Ion 4 месяцев назад
Родитель
Коммит
1031f1ddf0

+ 16 - 8
middleware/distributor.go

@@ -174,14 +174,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		relayMode := relayconstant.RelayModeUnknown
 		if c.Request.Method == http.MethodPost {
 			relayMode = relayconstant.RelayModeVideoSubmit
-			form, err := common.ParseMultipartFormReusable(c)
-			if err != nil {
-				return nil, false, errors.New("无效的video请求, " + err.Error())
-			}
-			defer form.RemoveAll()
-			if form != nil {
-				if values, ok := form.Value["model"]; ok && len(values) > 0 {
-					modelRequest.Model = values[0]
+			contentType := c.Request.Header.Get("Content-Type")
+			if strings.HasPrefix(contentType, "multipart/form-data") {
+				form, err := common.ParseMultipartFormReusable(c)
+				if err != nil {
+					return nil, false, errors.New("无效的video请求, " + err.Error())
+				}
+				defer form.RemoveAll()
+				if form != nil {
+					if values, ok := form.Value["model"]; ok && len(values) > 0 {
+						modelRequest.Model = values[0]
+					}
+				}
+			} else if strings.HasPrefix(contentType, "application/json") {
+				err = common.UnmarshalBodyReusable(c, &modelRequest)
+				if err != nil {
+					return nil, false, errors.New("无效的video请求, " + err.Error())
 				}
 			}
 		} else if c.Request.Method == http.MethodGet {

+ 5 - 0
relay/channel/adapter.go

@@ -4,6 +4,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/dto"
+	"one-api/model"
 	relaycommon "one-api/relay/common"
 	"one-api/types"
 
@@ -49,3 +50,7 @@ type TaskAdaptor interface {
 
 	ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
 }
+
+type OpenAIVideoConverter interface { // optional capability: the /v1/videos/ fetch path type-asserts a task adaptor to this interface
+	ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) // maps a stored task to an OpenAIVideo, or returns an error
+}

+ 49 - 8
relay/channel/task/kling/adaptor.go

@@ -7,9 +7,11 @@ import (
 	"io"
 	"net/http"
 	"one-api/model"
+	"strconv"
 	"strings"
 	"time"
 
+	"github.com/bytedance/gopkg/util/logger"
 	"github.com/samber/lo"
 
 	"github.com/gin-gonic/gin"
@@ -303,14 +305,6 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
 	return a.createJWTTokenWithKey(a.apiKey)
 }
 
-//func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
-//	parts := strings.Split(apiKey, "|")
-//	if len(parts) != 2 {
-//		return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
-//	}
-//	return a.createJWTTokenWithKey(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
-//}
-
 func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
 	if isNewAPIRelay(apiKey) {
 		return apiKey, nil // new api relay
@@ -369,3 +363,50 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 // isNewAPIRelay reports whether apiKey begins with "sk-".
 // createJWTTokenWithKey returns such keys as-is.
 func isNewAPIRelay(apiKey string) bool {
 	const relayKeyPrefix = "sk-"
 	return strings.HasPrefix(apiKey, relayKeyPrefix)
 }
+
+// ConvertToOpenAIVideo decodes the Kling response stored in originTask.Data and
+// maps it onto an OpenAIVideo value.
+//
+// The ID and timestamps come from the decoded Kling payload, while Status and
+// Progress come from the stored task itself. The first video of the task
+// result, when present, supplies Metadata["url"] and Seconds. Error is filled
+// in only when the payload carries both a non-zero code and a non-empty
+// message. An error is returned only if originTask.Data cannot be unmarshaled.
+func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
+	var payload responsePayload
+	if err := json.Unmarshal(originTask.Data, &payload); err != nil {
+		return nil, errors.Wrap(err, "unmarshal kling task data failed")
+	}
+
+	result := &relaycommon.OpenAIVideo{
+		Object:   "video",
+		Metadata: map[string]any{},
+	}
+	result.ID = payload.Data.TaskId
+	// TODO: fill in Model once the model name is saved with the task.
+	result.Status = string(originTask.Status)
+	result.CreatedAt = payload.Data.CreatedAt
+	result.CompletedAt = payload.Data.UpdatedAt
+
+	// Progress is stored as text that may end in "%". A value that does not
+	// parse is logged and whatever strconv.Atoi returned is kept.
+	trimmed := strings.TrimSuffix(originTask.Progress, "%")
+	percent, err := strconv.Atoi(trimmed)
+	if err != nil {
+		logger.Warnf("convert progress failed, progress: %s, err: %v", trimmed, err)
+	}
+	result.Progress = percent
+
+	// Only the first generated video is exposed.
+	if videos := payload.Data.TaskResult.Videos; len(videos) > 0 {
+		first := videos[0]
+		if first.Url != "" {
+			result.Metadata["url"] = first.Url
+		}
+		if first.Duration != "" {
+			result.Seconds = first.Duration
+		}
+	}
+
+	if payload.Code != 0 && payload.Message != "" {
+		result.Error = &relaycommon.OpenAIVideoError{
+			Message: payload.Message,
+			Code:    fmt.Sprintf("%d", payload.Code),
+		}
+	}
+
+	return result, nil
+}

+ 9 - 0
relay/channel/task/sora/adaptor.go

@@ -184,3 +184,12 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 	return &taskResult, nil
 }
+
+func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) (*relaycommon.OpenAIVideo, error) {
+	openAIVideo := &relaycommon.OpenAIVideo{}
+	err := json.Unmarshal(task.Data, openAIVideo)
+	if err != nil {
+		return nil, errors.Wrap(err, "unmarshal to OpenAIVideo failed")
+	}
+	return openAIVideo, nil
+}

+ 21 - 0
relay/common/relay_info.go

@@ -550,3 +550,24 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
 	}
 	return jsonDataAfter, nil
 }
+
+type OpenAIVideo struct { // JSON shape marshaled as the response body for /v1/videos/ task fetches
+	ID                 string            `json:"id"`
+	TaskID             string            `json:"task_id,omitempty"` // kept for compatibility with the old API; to be deprecated
+	Object             string            `json:"object"`
+	Model              string            `json:"model"`
+	Status             string            `json:"status"`
+	Progress           int               `json:"progress"` // the Kling converter parses this from text with an optional trailing "%"
+	CreatedAt          int64             `json:"created_at"`
+	CompletedAt        int64             `json:"completed_at,omitempty"`
+	ExpiresAt          int64             `json:"expires_at,omitempty"`
+	Seconds            string            `json:"seconds,omitempty"` // the Kling converter copies the first result video's Duration here
+	Size               string            `json:"size,omitempty"`
+	RemixedFromVideoID string            `json:"remixed_from_video_id,omitempty"`
+	Error              *OpenAIVideoError `json:"error,omitempty"`
+	Metadata           map[string]any    `json:"metadata,omitempty"` // the Kling converter stores the first result video's URL under "url"
+}
+type OpenAIVideoError struct { // error detail carried in OpenAIVideo.Error
+	Message string `json:"message"`
+	Code    string `json:"code"` // the Kling converter formats its numeric response code with %d
+}

+ 41 - 13
relay/common/relay_utils.go

@@ -106,25 +106,53 @@ func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string
 }
 
 func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
-	form, err := common.ParseMultipartFormReusable(c)
-	if err != nil {
-		return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
-	}
-	defer form.RemoveAll()
+	contentType := c.GetHeader("Content-Type")
+	var prompt string
+	var hasInputReference bool
+
+	if strings.HasPrefix(contentType, "multipart/form-data") {
+		form, err := common.ParseMultipartFormReusable(c)
+		if err != nil {
+			return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
+		}
+		defer form.RemoveAll()
+
+		prompts, ok := form.Value["prompt"]
+		if !ok || len(prompts) == 0 {
+			return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
+		}
+		prompt = prompts[0]
+
+		if _, ok := form.Value["model"]; !ok {
+			return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
+		}
+
+		if _, ok := form.File["input_reference"]; ok {
+			hasInputReference = true
+		}
+	} else {
+		var req TaskSubmitReq
+		if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+			return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
+		}
+
+		prompt = req.Prompt
 
-	prompts, ok := form.Value["prompt"]
-	if !ok || len(prompts) == 0 {
-		return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
+		if strings.TrimSpace(req.Model) == "" {
+			return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
+		}
+
+		if req.HasImage() {
+			hasInputReference = true
+		}
 	}
-	if taskErr := validatePrompt(prompts[0]); taskErr != nil {
+
+	if taskErr := validatePrompt(prompt); taskErr != nil {
 		return taskErr
 	}
 
-	if _, ok := form.Value["model"]; !ok {
-		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
-	}
 	action := constant.TaskActionTextGenerate
-	if _, ok := form.File["input_reference"]; ok {
+	if hasInputReference {
 		action = constant.TaskActionGenerate
 	}
 	info.Action = action

+ 16 - 1
relay/relay_task.go

@@ -11,6 +11,7 @@ import (
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
+	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
@@ -367,7 +368,21 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 	}
 
 	if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
-		respBody = originTask.Data
+		adaptor := GetTaskAdaptor(originTask.Platform)
+		if adaptor == nil {
+			taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
+			return
+		}
+		if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
+			openAIVideo, err := converter.ConvertToOpenAIVideo(originTask)
+			if err != nil {
+				taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
+				return
+			}
+			respBody, _ = json.Marshal(openAIVideo)
+			return
+		}
+		taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
 		return
 	}
 	respBody, err = json.Marshal(dto.TaskResponse[any]{