Просмотр исходного кода

feat: channel kling support New API

feitianbubu 7 месяцев назад
Родитель
Сommit
fcc006ecd3

+ 20 - 0
controller/swag_video.go

@@ -114,3 +114,23 @@ type KlingImage2VideoRequest struct {
 	CallbackURL    string              `json:"callback_url,omitempty" example:"https://your.domain/callback"`
 	ExternalTaskId string              `json:"external_task_id,omitempty" example:"custom-task-002"`
 }
+
+// KlingImage2videoTaskId godoc
+// @Summary 可灵任务查询--图生视频
+// @Description Query the status and result of a Kling video generation task by task ID
+// @Tags Origin
+// @Accept json
+// @Produce json
+// @Param task_id path string true "Task ID"
+// @Router /kling/v1/videos/image2video/{task_id} [get]
+func KlingImage2videoTaskId(c *gin.Context) {}
+
+// KlingText2videoTaskId godoc
+// @Summary 可灵任务查询--文生视频
+// @Description Query the status and result of a Kling text-to-video generation task by task ID
+// @Tags Origin
+// @Accept json
+// @Produce json
+// @Param task_id path string true "Task ID"
+// @Router /kling/v1/videos/text2video/{task_id} [get]
+func KlingText2videoTaskId(c *gin.Context) {}

+ 16 - 7
controller/task_video.go

@@ -2,13 +2,16 @@ package controller
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"io"
 	"one-api/common"
 	"one-api/constant"
+	"one-api/dto"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
 	"time"
 )
 
@@ -77,13 +80,21 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
 	}
 
-	taskResult, err := adaptor.ParseTaskResult(responseBody)
-	if err != nil {
+	taskResult := &relaycommon.TaskInfo{}
+	// try parse as New API response format
+	var responseItems dto.TaskResponse[model.Task]
+	if err = json.Unmarshal(responseBody, &responseItems); err == nil {
+		t := responseItems.Data
+		taskResult.TaskID = t.TaskID
+		taskResult.Status = string(t.Status)
+		taskResult.Url = t.FailReason
+		taskResult.Progress = t.Progress
+		taskResult.Reason = t.FailReason
+	} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
 		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+	} else {
+		task.Data = responseBody
 	}
-	//if taskResult.Code != 0 {
-	//	return fmt.Errorf("video task fetch failed for task %s", taskId)
-	//}
 
 	now := time.Now().Unix()
 	if taskResult.Status == "" {
@@ -128,8 +139,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 	if taskResult.Progress != "" {
 		task.Progress = taskResult.Progress
 	}
-
-	task.Data = responseBody
 	if err := task.Update(); err != nil {
 		common.SysError("UpdateVideoTask task error: " + err.Error())
 	}

+ 26 - 34
relay/channel/task/kling/adaptor.go

@@ -50,6 +50,7 @@ type requestPayload struct {
 type responsePayload struct {
 	Code      int    `json:"code"`
 	Message   string `json:"message"`
+	TaskId    string `json:"task_id"`
 	RequestId string `json:"request_id"`
 	Data      struct {
 		TaskId        string `json:"task_id"`
@@ -73,21 +74,16 @@ type responsePayload struct {
 
 type TaskAdaptor struct {
 	ChannelType int
-	accessKey   string
-	secretKey   string
+	apiKey      string
 	baseURL     string
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
 	a.ChannelType = info.ChannelType
 	a.baseURL = info.BaseUrl
+	a.apiKey = info.ApiKey
 
 	// apiKey format: "access_key|secret_key"
-	keyParts := strings.Split(info.ApiKey, "|")
-	if len(keyParts) == 2 {
-		a.accessKey = strings.TrimSpace(keyParts[0])
-		a.secretKey = strings.TrimSpace(keyParts[1])
-	}
 }
 
 // ValidateRequestAndSetAction parses body, validates fields and sets default action.
@@ -166,27 +162,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 
-	// Attempt Kling response parse first.
 	var kResp responsePayload
-	if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
-		c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
-		return kResp.Data.TaskId, responseBody, nil
-	}
-
-	// Fallback generic task response.
-	var generic dto.TaskResponse[string]
-	if err := json.Unmarshal(responseBody, &generic); err != nil {
-		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+	err = json.Unmarshal(responseBody, &kResp)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 		return
 	}
-
-	if !generic.IsSuccess() {
-		taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
+	if kResp.Code != 0 {
+		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
 		return
 	}
-
-	c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
-	return generic.Data, responseBody, nil
+	kResp.TaskId = kResp.Data.TaskId
+	c.JSON(http.StatusOK, kResp)
+	return kResp.Data.TaskId, responseBody, nil
 }
 
 // FetchTask fetch task status
@@ -288,21 +276,25 @@ func defaultInt(v int, def int) int {
 // ============================
 
 func (a *TaskAdaptor) createJWTToken() (string, error) {
-	return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
+	return a.createJWTTokenWithKey(a.apiKey)
 }
 
+//func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
+//	parts := strings.Split(apiKey, "|")
+//	if len(parts) != 2 {
+//		return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
+//	}
+//	return a.createJWTTokenWithKey(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
+//}
+
 func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
-	parts := strings.Split(apiKey, "|")
-	if len(parts) != 2 {
-		return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
-	}
-	return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
-}
 
-func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
-	if accessKey == "" || secretKey == "" {
-		return "", fmt.Errorf("access key and secret key are required")
+	keyParts := strings.Split(apiKey, "|")
+	accessKey := strings.TrimSpace(keyParts[0])
+	if len(keyParts) == 1 {
+		return accessKey, nil
 	}
+	secretKey := strings.TrimSpace(keyParts[1])
 	now := time.Now().Unix()
 	claims := jwt.MapClaims{
 		"iss": accessKey,
@@ -315,12 +307,12 @@ func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (strin
 }
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	taskInfo := &relaycommon.TaskInfo{}
 	resPayload := responsePayload{}
 	err := json.Unmarshal(respBody, &resPayload)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to unmarshal response body")
 	}
-	taskInfo := &relaycommon.TaskInfo{}
 	taskInfo.Code = resPayload.Code
 	taskInfo.TaskID = resPayload.Data.TaskId
 	taskInfo.Reason = resPayload.Message

+ 1 - 1
relay/constant/relay_mode.go

@@ -150,7 +150,7 @@ func Path2RelayKling(method, path string) int {
 	relayMode := RelayModeUnknown
 	if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
 		relayMode = RelayModeKlingSubmit
-	} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
+	} else if method == http.MethodGet && (strings.Contains(path, "/video/generations")) {
 		relayMode = RelayModeKlingFetchByID
 	}
 	return relayMode

+ 2 - 0
router/video-router.go

@@ -20,5 +20,7 @@ func SetVideoRouter(router *gin.Engine) {
 	{
 		klingV1Router.POST("/videos/text2video", controller.RelayTask)
 		klingV1Router.POST("/videos/image2video", controller.RelayTask)
+		klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
+		klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
 	}
 }

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -68,7 +68,7 @@ function type2secretPrompt(type) {
     case 33:
       return '按照如下格式输入:Ak|Sk|Region';
     case 50:
-      return '按照如下格式输入: AccessKey|SecretKey';
+      return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API,则直接输ApiKey';
     case 51:
       return '按照如下格式输入: Access Key ID|Secret Access Key';
     default: