Просмотр исходного кода

Merge pull request #2315 from feitianbubu/pr/gemini-veo3.1-i2v

Gemini Veo3.1[AI Studio]增加图生视频支持
IcedTangerine 3 месяца назад
Родитель
Коммит
36cf515617
1 изменённый файл: 62 добавлено и 28 удалено
  1. 62 28
      relay/channel/task/gemini/adaptor.go

+ 62 - 28
relay/channel/task/gemini/adaptor.go

@@ -24,13 +24,9 @@ import (
 	"github.com/pkg/errors"
 )
 
-// ============================
-// Request / Response structures
-// ============================
-
-// GeminiVideoGenerationConfig represents the video generation configuration
+// VideoGenerationConfig represents the video generation configuration
 // Based on: https://ai.google.dev/gemini-api/docs/video
-type GeminiVideoGenerationConfig struct {
+type VideoGenerationConfig struct {
 	AspectRatio      string  `json:"aspectRatio,omitempty"`      // "16:9" or "9:16"
 	DurationSeconds  float64 `json:"durationSeconds,omitempty"`  // 4, 6, or 8 (as number)
 	NegativePrompt   string  `json:"negativePrompt,omitempty"`   // unwanted elements
@@ -38,15 +34,21 @@ type GeminiVideoGenerationConfig struct {
 	Resolution       string  `json:"resolution,omitempty"`       // video resolution
 }
 
-// GeminiVideoRequest represents a single video generation instance
-type GeminiVideoRequest struct {
-	Prompt string `json:"prompt"`
+type Image struct {
+	BytesBase64Encoded string `json:"bytesBase64Encoded,omitempty"`
+	MimeType           string `json:"mimeType,omitempty"`
 }
 
-// GeminiVideoPayload represents the complete video generation request payload
-type GeminiVideoPayload struct {
-	Instances  []GeminiVideoRequest        `json:"instances"`
-	Parameters GeminiVideoGenerationConfig `json:"parameters,omitempty"`
+type VideoRequest struct {
+	Prompt    string `json:"prompt"`
+	Image     *Image `json:"image,omitempty"`
+	LastFrame *Image `json:"lastFrame,omitempty"`
+}
+
+// VideoPayload represents the complete video generation request payload
+type VideoPayload struct {
+	Instances  []VideoRequest        `json:"instances"`
+	Parameters VideoGenerationConfig `json:"parameters,omitempty"`
 }
 
 type submitResponse struct {
@@ -75,6 +77,8 @@ type operationResponse struct {
 					URI string `json:"uri"`
 				} `json:"video"`
 			} `json:"generatedSamples"`
+			RaiMediaFilteredCount   int      `json:"raiMediaFilteredCount"`
+			RaiMediaFilteredReasons []string `json:"raiMediaFilteredReasons"`
 		} `json:"generateVideoResponse"`
 	} `json:"response"`
 	Error struct {
@@ -100,8 +104,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 
 // ValidateRequestAndSetAction parses body, validates fields and sets default action.
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	// Use the standard validation method for TaskSubmitReq
-	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
+	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
 }
 
 // BuildRequestURL constructs the upstream URL.
@@ -137,13 +140,21 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 
 	// Create structured video generation request
-	body := GeminiVideoPayload{
-		Instances: []GeminiVideoRequest{
+	body := VideoPayload{
+		Instances: []VideoRequest{
 			{Prompt: req.Prompt},
 		},
-		Parameters: GeminiVideoGenerationConfig{},
+		Parameters: VideoGenerationConfig{},
+	}
+
+	if len(req.Images) > 0 {
+		body.Instances[0].Image = a.convertImage(req.Images[0])
+	}
+	if len(req.Images) > 1 {
+		body.Instances[0].LastFrame = a.convertImage(req.Images[1])
 	}
 
+	// Parse metadata for additional configuration
 	metadata := req.Metadata
 	medaBytes, err := json.Marshal(metadata)
 	if err != nil {
@@ -247,20 +258,19 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 		return ti, nil
 	}
 
-	ti.Status = model.TaskStatusSuccess
-	ti.Progress = "100%"
-
-	taskID := encodeLocalTaskID(op.Name)
-	ti.TaskID = taskID
-	ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
-
-	// Extract URL from generateVideoResponse if available
-	if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
+	if len(op.Response.GenerateVideoResponse.GeneratedSamples) == 0 {
+		ti.Status = model.TaskStatusFailure
+		ti.Reason = fmt.Sprintf("no generated video url found: %s", strings.Join(op.Response.GenerateVideoResponse.RaiMediaFilteredReasons, "; "))
+	} else {
 		if uri := op.Response.GenerateVideoResponse.GeneratedSamples[0].Video.URI; uri != "" {
 			ti.RemoteUrl = uri
 		}
+		ti.Status = model.TaskStatusSuccess
 	}
-
+	ti.Progress = "100%"
+	taskID := encodeLocalTaskID(op.Name)
+	ti.TaskID = taskID
+	ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
 	return ti, nil
 }
 
@@ -289,6 +299,30 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 	return common.Marshal(video)
 }
 
+func (a *TaskAdaptor) convertImage(imageStr string) *Image {
+	if strings.TrimSpace(imageStr) == "" {
+		return nil
+	}
+	img := &Image{
+		MimeType:           "image/png",
+		BytesBase64Encoded: imageStr,
+	}
+	if strings.HasPrefix(imageStr, "data:image/") {
+		parts := strings.Split(imageStr, ";base64,")
+		if len(parts) == 2 {
+			img.MimeType = strings.TrimPrefix(parts[0], "data:")
+			img.BytesBase64Encoded = parts[1]
+		}
+	} else if strings.HasPrefix(imageStr, "http") {
+		mimeType, data, err := service.GetImageFromUrl(imageStr)
+		if err == nil {
+			img.MimeType = mimeType
+			img.BytesBase64Encoded = data
+		}
+	}
+	return img
+}
+
 // ============================
 // helpers
 // ============================