Bladeren bron

fix: align Vertex content fetch flow with Gemini and handle base64 payloads

Seefs 1 week geleden
bovenliggende
commit
5ed997905c
2 gewijzigde bestanden met toevoegingen van 161 en 0 verwijderingen
  1. 57 0
      controller/video_proxy.go
  2. 104 0
      controller/video_proxy_gemini.go

+ 57 - 0
controller/video_proxy.go

@@ -2,10 +2,12 @@ package controller
 
 import (
 	"context"
+	"encoding/base64"
 	"fmt"
 	"io"
 	"net/http"
 	"net/url"
+	"strings"
 	"time"
 
 	"github.com/QuantumNous/new-api/constant"
@@ -94,6 +96,13 @@ func VideoProxy(c *gin.Context) {
 			return
 		}
 		req.Header.Set("x-goog-api-key", apiKey)
+	case constant.ChannelTypeVertexAi:
+		videoURL, err = getVertexVideoURL(channel, task)
+		if err != nil {
+			logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Vertex video URL for task %s: %s", taskID, err.Error()))
+			videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Vertex video URL")
+			return
+		}
 	case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
 		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
@@ -102,6 +111,21 @@ func VideoProxy(c *gin.Context) {
 		videoURL = task.GetResultURL()
 	}
 
+	videoURL = strings.TrimSpace(videoURL)
+	if videoURL == "" {
+		logger.LogError(c.Request.Context(), fmt.Sprintf("Video URL is empty for task %s", taskID))
+		videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
+		return
+	}
+
+	if strings.HasPrefix(videoURL, "data:") {
+		if err := writeVideoDataURL(c, videoURL); err != nil {
+			logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to decode video data URL for task %s: %s", taskID, err.Error()))
+			videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
+		}
+		return
+	}
+
 	req.URL, err = url.Parse(videoURL)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
@@ -136,3 +160,36 @@ func VideoProxy(c *gin.Context) {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
 	}
 }
+
+func writeVideoDataURL(c *gin.Context, dataURL string) error {
+	parts := strings.SplitN(dataURL, ",", 2)
+	if len(parts) != 2 {
+		return fmt.Errorf("invalid data url")
+	}
+
+	header := parts[0]
+	payload := parts[1]
+	if !strings.HasPrefix(header, "data:") || !strings.Contains(header, ";base64") {
+		return fmt.Errorf("unsupported data url")
+	}
+
+	mimeType := strings.TrimPrefix(header, "data:")
+	mimeType = strings.TrimSuffix(mimeType, ";base64")
+	if mimeType == "" {
+		mimeType = "video/mp4"
+	}
+
+	videoBytes, err := base64.StdEncoding.DecodeString(payload)
+	if err != nil {
+		videoBytes, err = base64.RawStdEncoding.DecodeString(payload)
+		if err != nil {
+			return err
+		}
+	}
+
+	c.Writer.Header().Set("Content-Type", mimeType)
+	c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
+	c.Writer.WriteHeader(http.StatusOK)
+	_, err = c.Writer.Write(videoBytes)
+	return err
+}

+ 104 - 0
controller/video_proxy_gemini.go

@@ -145,6 +145,110 @@ func extractGeminiVideoURLFromGeneratedSamples(gvr map[string]any) string {
 	return ""
 }
 
+func getVertexVideoURL(channel *model.Channel, task *model.Task) (string, error) {
+	if channel == nil || task == nil {
+		return "", fmt.Errorf("invalid channel or task")
+	}
+	if url := strings.TrimSpace(task.GetResultURL()); url != "" {
+		return url, nil
+	}
+	if url := extractVertexVideoURLFromTaskData(task); url != "" {
+		return url, nil
+	}
+
+	baseURL := constant.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() != "" {
+		baseURL = channel.GetBaseURL()
+	}
+
+	adaptor := relay.GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channel.Type)))
+	if adaptor == nil {
+		return "", fmt.Errorf("vertex task adaptor not found")
+	}
+
+	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, channel.GetSetting().Proxy)
+	if err != nil {
+		return "", fmt.Errorf("fetch task failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("read task response failed: %w", err)
+	}
+
+	taskInfo, parseErr := adaptor.ParseTaskResult(body)
+	if parseErr == nil && taskInfo != nil && strings.TrimSpace(taskInfo.Url) != "" {
+		return taskInfo.Url, nil
+	}
+	if url := extractVertexVideoURLFromPayload(body); url != "" {
+		return url, nil
+	}
+	if parseErr != nil {
+		return "", fmt.Errorf("parse task result failed: %w", parseErr)
+	}
+	return "", fmt.Errorf("vertex video url not found")
+}
+
+func extractVertexVideoURLFromTaskData(task *model.Task) string {
+	if task == nil || len(task.Data) == 0 {
+		return ""
+	}
+	return extractVertexVideoURLFromPayload(task.Data)
+}
+
+func extractVertexVideoURLFromPayload(body []byte) string {
+	var payload map[string]any
+	if err := common.Unmarshal(body, &payload); err != nil {
+		return ""
+	}
+	resp, ok := payload["response"].(map[string]any)
+	if !ok || resp == nil {
+		return ""
+	}
+
+	if videos, ok := resp["videos"].([]any); ok && len(videos) > 0 {
+		if video, ok := videos[0].(map[string]any); ok && video != nil {
+			if b64, _ := video["bytesBase64Encoded"].(string); strings.TrimSpace(b64) != "" {
+				mime, _ := video["mimeType"].(string)
+				enc, _ := video["encoding"].(string)
+				return buildVideoDataURL(mime, enc, b64)
+			}
+		}
+	}
+	if b64, _ := resp["bytesBase64Encoded"].(string); strings.TrimSpace(b64) != "" {
+		enc, _ := resp["encoding"].(string)
+		return buildVideoDataURL("", enc, b64)
+	}
+	if video, _ := resp["video"].(string); strings.TrimSpace(video) != "" {
+		if strings.HasPrefix(video, "data:") || strings.HasPrefix(video, "http://") || strings.HasPrefix(video, "https://") {
+			return video
+		}
+		enc, _ := resp["encoding"].(string)
+		return buildVideoDataURL("", enc, video)
+	}
+	return ""
+}
+
+func buildVideoDataURL(mimeType string, encoding string, base64Data string) string {
+	mime := strings.TrimSpace(mimeType)
+	if mime == "" {
+		enc := strings.TrimSpace(encoding)
+		if enc == "" {
+			enc = "mp4"
+		}
+		if strings.Contains(enc, "/") {
+			mime = enc
+		} else {
+			mime = "video/" + enc
+		}
+	}
+	return "data:" + mime + ";base64," + base64Data
+}
+
 func ensureAPIKey(uri, key string) string {
 	if key == "" || uri == "" {
 		return uri