Преглед изворног кода

feat: jimeng use openai sdk input_reference i2v

feitianbubu пре 4 месеци
родитељ
комит
dfca9681c8
1 измењен фајл са 37 додатих и 5 уклоњених линија
  1. 37 5
      relay/channel/task/jimeng/adaptor.go

+ 37 - 5
relay/channel/task/jimeng/adaptor.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"bytes"
 	"crypto/hmac"
 	"crypto/hmac"
 	"crypto/sha256"
 	"crypto/sha256"
+	"encoding/base64"
 	"encoding/hex"
 	"encoding/hex"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
@@ -89,7 +90,6 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 
 
 // ValidateRequestAndSetAction parses body, validates fields and sets default action.
 // ValidateRequestAndSetAction parses body, validates fields and sets default action.
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	// Accept only POST /v1/video/generations as "generate" action.
 	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
 	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
 }
 }
 
 
@@ -113,13 +113,45 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 	return nil
 	return nil
 }
 }
 
 
-// BuildRequestBody converts request into Jimeng specific format.
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 	v, exists := c.Get("task_request")
 	v, exists := c.Get("task_request")
 	if !exists {
 	if !exists {
 		return nil, fmt.Errorf("request not found in context")
 		return nil, fmt.Errorf("request not found in context")
 	}
 	}
-	req := v.(relaycommon.TaskSubmitReq)
+	req, ok := v.(relaycommon.TaskSubmitReq)
+	if !ok {
+		return nil, fmt.Errorf("invalid request type in context")
+	}
+	// 支持openai sdk的图片上传方式
+	if mf, err := c.MultipartForm(); err == nil {
+		if files, exists := mf.File["input_reference"]; exists && len(files) > 0 {
+			if len(files) == 1 {
+				info.Action = constant.TaskActionGenerate
+			} else if len(files) == 2 {
+				info.Action = constant.TaskActionFirstTailGenerate
+			} else if len(files) > 2 {
+				info.Action = constant.TaskActionReferenceGenerate
+			}
+
+			// 将上传的文件转换为base64格式
+			var images []string
+			for _, fileHeader := range files {
+				file, err := fileHeader.Open()
+				if err != nil {
+					continue
+				}
+				fileBytes, err := io.ReadAll(file)
+				file.Close()
+				if err != nil {
+					continue
+				}
+				// 将文件内容转换为base64
+				base64Str := base64.StdEncoding.EncodeToString(fileBytes)
+				images = append(images, base64Str)
+			}
+			req.Images = images
+		}
+	}
 
 
 	body, err := a.convertToRequestPayload(&req)
 	body, err := a.convertToRequestPayload(&req)
 	if err != nil {
 	if err != nil {
@@ -364,10 +396,10 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	// 即梦视频3.0 ReqKey转换
 	// 即梦视频3.0 ReqKey转换
 	// https://www.volcengine.com/docs/85621/1792707
 	// https://www.volcengine.com/docs/85621/1792707
 	if strings.Contains(r.ReqKey, "jimeng_v30") {
 	if strings.Contains(r.ReqKey, "jimeng_v30") {
-		if len(r.ImageUrls) > 1 {
+		if len(req.Images) > 1 {
 			// 多张图片:首尾帧生成
 			// 多张图片:首尾帧生成
 			r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
 			r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
-		} else if len(r.ImageUrls) == 1 {
+		} else if len(req.Images) == 1 {
 			// 单张图片:图生视频
 			// 单张图片:图生视频
 			r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
 			r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
 		} else {
 		} else {