浏览代码

Merge pull request #3038 from seefs001/fix/video-vertex-fetch

fix: align Vertex content fetch flow with Gemini and handle base64
Calcium-Ion 1 周之前
父节点
当前提交
d668788be2
共有 3 个文件被更改,包括 187 次插入1 次删除
  1. 57 0
      controller/video_proxy.go
  2. 128 0
      controller/video_proxy_gemini.go
  3. 2 1
      model/task.go

+ 57 - 0
controller/video_proxy.go

@@ -2,10 +2,12 @@ package controller
 
 
 import (
 import (
 	"context"
 	"context"
+	"encoding/base64"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
+	"strings"
 	"time"
 	"time"
 
 
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/constant"
@@ -94,6 +96,13 @@ func VideoProxy(c *gin.Context) {
 			return
 			return
 		}
 		}
 		req.Header.Set("x-goog-api-key", apiKey)
 		req.Header.Set("x-goog-api-key", apiKey)
+	case constant.ChannelTypeVertexAi:
+		videoURL, err = getVertexVideoURL(channel, task)
+		if err != nil {
+			logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Vertex video URL for task %s: %s", taskID, err.Error()))
+			videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Vertex video URL")
+			return
+		}
 	case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
 	case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
 		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
 		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
@@ -102,6 +111,21 @@ func VideoProxy(c *gin.Context) {
 		videoURL = task.GetResultURL()
 		videoURL = task.GetResultURL()
 	}
 	}
 
 
+	videoURL = strings.TrimSpace(videoURL)
+	if videoURL == "" {
+		logger.LogError(c.Request.Context(), fmt.Sprintf("Video URL is empty for task %s", taskID))
+		videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
+		return
+	}
+
+	if strings.HasPrefix(videoURL, "data:") {
+		if err := writeVideoDataURL(c, videoURL); err != nil {
+			logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to decode video data URL for task %s: %s", taskID, err.Error()))
+			videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
+		}
+		return
+	}
+
 	req.URL, err = url.Parse(videoURL)
 	req.URL, err = url.Parse(videoURL)
 	if err != nil {
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
@@ -136,3 +160,36 @@ func VideoProxy(c *gin.Context) {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
 	}
 	}
 }
 }
+
+func writeVideoDataURL(c *gin.Context, dataURL string) error {
+	parts := strings.SplitN(dataURL, ",", 2)
+	if len(parts) != 2 {
+		return fmt.Errorf("invalid data url")
+	}
+
+	header := parts[0]
+	payload := parts[1]
+	if !strings.HasPrefix(header, "data:") || !strings.Contains(header, ";base64") {
+		return fmt.Errorf("unsupported data url")
+	}
+
+	mimeType := strings.TrimPrefix(header, "data:")
+	mimeType = strings.TrimSuffix(mimeType, ";base64")
+	if mimeType == "" {
+		mimeType = "video/mp4"
+	}
+
+	videoBytes, err := base64.StdEncoding.DecodeString(payload)
+	if err != nil {
+		videoBytes, err = base64.RawStdEncoding.DecodeString(payload)
+		if err != nil {
+			return err
+		}
+	}
+
+	c.Writer.Header().Set("Content-Type", mimeType)
+	c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
+	c.Writer.WriteHeader(http.StatusOK)
+	_, err = c.Writer.Write(videoBytes)
+	return err
+}

+ 128 - 0
controller/video_proxy_gemini.go

@@ -145,6 +145,134 @@ func extractGeminiVideoURLFromGeneratedSamples(gvr map[string]any) string {
 	return ""
 	return ""
 }
 }
 
 
+func getVertexVideoURL(channel *model.Channel, task *model.Task) (string, error) {
+	if channel == nil || task == nil {
+		return "", fmt.Errorf("invalid channel or task")
+	}
+	if url := strings.TrimSpace(task.GetResultURL()); url != "" {
+		return url, nil
+	}
+	if url := extractVertexVideoURLFromTaskData(task); url != "" {
+		return url, nil
+	}
+
+	baseURL := constant.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() != "" {
+		baseURL = channel.GetBaseURL()
+	}
+
+	adaptor := relay.GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channel.Type)))
+	if adaptor == nil {
+		return "", fmt.Errorf("vertex task adaptor not found")
+	}
+
+	key := getVertexTaskKey(channel, task)
+	if key == "" {
+		return "", fmt.Errorf("vertex key not available for task")
+	}
+
+	resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, channel.GetSetting().Proxy)
+	if err != nil {
+		return "", fmt.Errorf("fetch task failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("read task response failed: %w", err)
+	}
+
+	taskInfo, parseErr := adaptor.ParseTaskResult(body)
+	if parseErr == nil && taskInfo != nil && strings.TrimSpace(taskInfo.Url) != "" {
+		return taskInfo.Url, nil
+	}
+	if url := extractVertexVideoURLFromPayload(body); url != "" {
+		return url, nil
+	}
+	if parseErr != nil {
+		return "", fmt.Errorf("parse task result failed: %w", parseErr)
+	}
+	return "", fmt.Errorf("vertex video url not found")
+}
+
+func getVertexTaskKey(channel *model.Channel, task *model.Task) string {
+	if task != nil {
+		if key := strings.TrimSpace(task.PrivateData.Key); key != "" {
+			return key
+		}
+	}
+	if channel == nil {
+		return ""
+	}
+	keys := channel.GetKeys()
+	for _, key := range keys {
+		key = strings.TrimSpace(key)
+		if key != "" {
+			return key
+		}
+	}
+	return strings.TrimSpace(channel.Key)
+}
+
+func extractVertexVideoURLFromTaskData(task *model.Task) string {
+	if task == nil || len(task.Data) == 0 {
+		return ""
+	}
+	return extractVertexVideoURLFromPayload(task.Data)
+}
+
+func extractVertexVideoURLFromPayload(body []byte) string {
+	var payload map[string]any
+	if err := common.Unmarshal(body, &payload); err != nil {
+		return ""
+	}
+	resp, ok := payload["response"].(map[string]any)
+	if !ok || resp == nil {
+		return ""
+	}
+
+	if videos, ok := resp["videos"].([]any); ok && len(videos) > 0 {
+		if video, ok := videos[0].(map[string]any); ok && video != nil {
+			if b64, _ := video["bytesBase64Encoded"].(string); strings.TrimSpace(b64) != "" {
+				mime, _ := video["mimeType"].(string)
+				enc, _ := video["encoding"].(string)
+				return buildVideoDataURL(mime, enc, b64)
+			}
+		}
+	}
+	if b64, _ := resp["bytesBase64Encoded"].(string); strings.TrimSpace(b64) != "" {
+		enc, _ := resp["encoding"].(string)
+		return buildVideoDataURL("", enc, b64)
+	}
+	if video, _ := resp["video"].(string); strings.TrimSpace(video) != "" {
+		if strings.HasPrefix(video, "data:") || strings.HasPrefix(video, "http://") || strings.HasPrefix(video, "https://") {
+			return video
+		}
+		enc, _ := resp["encoding"].(string)
+		return buildVideoDataURL("", enc, video)
+	}
+	return ""
+}
+
+func buildVideoDataURL(mimeType string, encoding string, base64Data string) string {
+	mime := strings.TrimSpace(mimeType)
+	if mime == "" {
+		enc := strings.TrimSpace(encoding)
+		if enc == "" {
+			enc = "mp4"
+		}
+		if strings.Contains(enc, "/") {
+			mime = enc
+		} else {
+			mime = "video/" + enc
+		}
+	}
+	return "data:" + mime + ";base64," + base64Data
+}
+
 func ensureAPIKey(uri, key string) string {
 func ensureAPIKey(uri, key string) string {
 	if key == "" || uri == "" {
 	if key == "" || uri == "" {
 		return uri
 		return uri

+ 2 - 1
model/task.go

@@ -173,7 +173,8 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo)
 	properties := Properties{}
 	properties := Properties{}
 	privateData := TaskPrivateData{}
 	privateData := TaskPrivateData{}
 	if relayInfo != nil && relayInfo.ChannelMeta != nil {
 	if relayInfo != nil && relayInfo.ChannelMeta != nil {
-		if relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini {
+		if relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini ||
+			relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeVertexAi {
 			privateData.Key = relayInfo.ChannelMeta.ApiKey
 			privateData.Key = relayInfo.ChannelMeta.ApiKey
 		}
 		}
 		if relayInfo.UpstreamModelName != "" {
 		if relayInfo.UpstreamModelName != "" {