# video_upload_function.py (12 KB)
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 视频上传Function
  5. 功能: 下载视频到本地并上传至Gemini,保存上传链接到state
  6. """
  7. import os
  8. import time
  9. import tempfile
  10. import requests
  11. from pathlib import Path
  12. from urllib.parse import urlparse
  13. from typing import Dict, Any, Optional, Tuple
  14. from google import genai
  15. from src.components.functions.base import BaseFunction
  16. from src.states.what_deconstruction_state import WhatDeconstructionState
  17. from src.utils.logger import get_logger
  18. from src.utils.llm_invoker import LLMInvoker
  19. logger = get_logger(__name__)
  20. class VideoUploadFunction(BaseFunction[Dict[str, Any], Dict[str, Any]]):
  21. """视频上传函数
  22. 功能:
  23. - 从URL下载视频到本地
  24. - 上传视频到Gemini File API
  25. - 保存上传后的文件URI到state
  26. """
  27. def __init__(
  28. self,
  29. name: str = "video_upload_function",
  30. description: str = "下载视频并上传至Gemini,保存上传链接"
  31. ):
  32. super().__init__(name, description)
  33. def execute(
  34. self,
  35. input_data: Dict[str, Any],
  36. context: Optional[Dict[str, Any]] = None
  37. ) -> Dict[str, Any]:
  38. """执行视频上传
  39. Args:
  40. input_data: 包含video字段的状态字典
  41. context: 上下文信息
  42. Returns:
  43. 更新后的状态字典,包含video_uploaded_uri字段
  44. """
  45. try:
  46. video_url = input_data.get("video", "")
  47. if not video_url:
  48. logger.warning("未提供视频URL,跳过上传")
  49. return {
  50. **input_data,
  51. "video_uploaded_uri": None,
  52. "video_upload_error": "未提供视频URL"
  53. }
  54. logger.info(f"开始下载视频: {video_url}")
  55. # 1. 下载视频到本地(或使用examples/videos目录下的现有文件)
  56. # 从input_data中获取channel_content_id,用于查找examples/videos目录下的文件
  57. channel_content_id = input_data.get("channel_content_id", "")
  58. local_video_path, is_temp_file = self._download_video(video_url, channel_content_id)
  59. if not local_video_path:
  60. return {
  61. **input_data,
  62. "video_uploaded_uri": None,
  63. "video_upload_error": "视频下载失败"
  64. }
  65. logger.info(f"视频文件路径: {local_video_path}")
  66. # 2. 上传视频到Gemini(使用新的API,带state校验)
  67. video_file = LLMInvoker.upload_video_to_gemini(local_video_path)
  68. # # 3. 清理临时文件(只有临时文件才需要清理)
  69. # if is_temp_file:
  70. # try:
  71. # os.remove(local_video_path)
  72. # logger.info(f"临时文件已删除: {local_video_path}")
  73. # except Exception as e:
  74. # logger.warning(f"删除临时文件失败: {e}")
  75. # else:
  76. # logger.info(f"使用examples目录下的文件,不删除: {local_video_path}")
  77. if not video_file:
  78. return {
  79. **input_data,
  80. "video_uploaded_uri": None,
  81. "video_file_name": None,
  82. "video_upload_error": "视频上传到Gemini失败"
  83. }
  84. # 获取文件URI和名称
  85. file_uri = None
  86. file_name = None
  87. if hasattr(video_file, 'uri'):
  88. file_uri = video_file.uri
  89. elif hasattr(video_file, 'name'):
  90. file_name = video_file.name
  91. file_uri = f"https://generativelanguage.googleapis.com/v1beta/files/{file_name}"
  92. logger.info(f"视频上传成功,文件名称: {file_name}")
  93. # 4. 更新state
  94. return {
  95. **input_data,
  96. "video_uploaded_uri": file_uri, # 兼容旧版本
  97. "video_file_name": file_name, # 新字段,用于获取文件对象
  98. "video_upload_error": None
  99. }
  100. except Exception as e:
  101. logger.error(f"视频上传失败: {e}", exc_info=True)
  102. return {
  103. **input_data,
  104. "video_uploaded_uri": None,
  105. "video_file_name": None,
  106. "video_upload_error": str(e)
  107. }
  108. def _download_video(self, video_url: str, channel_content_id: str = "") -> Tuple[Optional[str], bool]:
  109. """下载视频到本地,或使用examples/videos目录下的现有文件
  110. Args:
  111. video_url: 视频URL
  112. channel_content_id: 频道内容ID,用于查找examples/videos目录下的文件
  113. Returns:
  114. (本地文件路径, 是否为临时文件) 的元组,失败返回 (None, True)
  115. 如果使用examples/videos目录下的文件,返回 (文件路径, False)
  116. 如果下载到examples/videos目录,返回 (文件路径, False)
  117. """
  118. try:
  119. # 1. 首先检查examples/videos目录下是否有对应的mp4文件
  120. existing_file = self._check_examples_directory(channel_content_id)
  121. if existing_file:
  122. logger.info(f"在examples/videos目录下找到现有文件,直接使用: {existing_file}")
  123. return existing_file, False
  124. # 2. 如果没有找到,则下载到examples/videos目录
  125. if not channel_content_id:
  126. logger.warning("未提供channel_content_id,无法保存到examples/videos目录")
  127. return None, True
  128. logger.info("未在examples/videos目录下找到同名文件,开始下载...")
  129. # 获取项目根目录
  130. project_root = Path(__file__).parent.parent.parent.parent
  131. videos_dir = project_root / "examples" / "videos"
  132. # 确保目录存在
  133. videos_dir.mkdir(parents=True, exist_ok=True)
  134. # 构建文件路径:examples/videos/{channel_content_id}.mp4
  135. target_path = videos_dir / f"{channel_content_id}.mp4"
  136. # 如果文件已存在(并发情况),直接返回
  137. if target_path.exists():
  138. logger.info(f"文件已存在: {target_path}")
  139. return str(target_path), False
  140. # 下载视频(带重试机制)
  141. max_retries = 3
  142. retry_count = 0
  143. last_exception = None
  144. while retry_count < max_retries:
  145. try:
  146. if retry_count > 0:
  147. logger.info(f"重试下载视频 (第 {retry_count}/{max_retries-1} 次)...")
  148. # 使用 Session 进行下载
  149. session = requests.Session()
  150. session.headers.update({
  151. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
  152. })
  153. # 下载视频(增加超时时间)
  154. response = session.get(
  155. video_url,
  156. timeout=120, # (连接超时, 读取超时)
  157. stream=True
  158. )
  159. response.raise_for_status()
  160. # 写入文件
  161. with open(target_path, "wb") as f:
  162. for chunk in response.iter_content(chunk_size=8192):
  163. if chunk:
  164. f.write(chunk)
  165. # 验证文件大小
  166. file_size = os.path.getsize(target_path)
  167. if file_size == 0:
  168. raise ValueError("下载的文件大小为0")
  169. logger.info(f"视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB,保存到: {target_path}")
  170. return str(target_path), False
  171. except (requests.exceptions.ChunkedEncodingError,
  172. requests.exceptions.ConnectionError,
  173. requests.exceptions.Timeout,
  174. requests.exceptions.RequestException) as e:
  175. last_exception = e
  176. retry_count += 1
  177. # 清理不完整的文件
  178. if target_path.exists():
  179. try:
  180. os.remove(target_path)
  181. except:
  182. pass
  183. if retry_count < max_retries:
  184. wait_time = retry_count * 2 # 递增等待时间:2秒、4秒
  185. logger.warning(f"下载失败 (尝试 {retry_count}/{max_retries}): {e}")
  186. logger.info(f"等待 {wait_time} 秒后重试...")
  187. time.sleep(wait_time)
  188. else:
  189. logger.error(f"下载失败,已重试 {max_retries} 次: {e}")
  190. raise
  191. except Exception as e:
  192. # 其他类型的异常直接抛出,不重试
  193. if target_path.exists():
  194. try:
  195. os.remove(target_path)
  196. except:
  197. pass
  198. raise
  199. # 如果所有重试都失败了
  200. if last_exception:
  201. raise last_exception
  202. except Exception as e:
  203. logger.error(f"下载视频失败: {e}", exc_info=True)
  204. return None, True
  205. def _check_examples_directory(self, channel_content_id: str) -> Optional[str]:
  206. """检查examples/videos目录下是否有对应的mp4文件
  207. 文件路径格式:examples/videos/{channel_content_id}.mp4
  208. Args:
  209. channel_content_id: 频道内容ID
  210. Returns:
  211. 如果找到文件,返回文件路径;否则返回None
  212. """
  213. try:
  214. # 如果没有提供channel_content_id,无法查找
  215. if not channel_content_id:
  216. logger.info("未提供channel_content_id,跳过examples/videos目录检查")
  217. return None
  218. # 获取项目根目录
  219. # __file__ 是 src/components/functions/video_upload_function.py
  220. # 需要往上4层才能到项目根目录
  221. project_root = Path(__file__).parent.parent.parent.parent
  222. videos_dir = project_root / "examples" / "videos"
  223. if not videos_dir.exists():
  224. logger.info(f"examples/videos目录不存在: {videos_dir}")
  225. return None
  226. # 构建文件路径:examples/videos/{channel_content_id}.mp4
  227. mp4_file = videos_dir / f"{channel_content_id}.mp4"
  228. logger.info(f"构建文件路径: {mp4_file}")
  229. # 检查文件是否存在
  230. if mp4_file.exists() and mp4_file.is_file():
  231. logger.info(f"在examples/videos目录下找到文件: {mp4_file}")
  232. return str(mp4_file)
  233. logger.debug(f"在examples/videos目录下未找到文件: {mp4_file}")
  234. return None
  235. except Exception as e:
  236. logger.warning(f"检查examples/videos目录时出错: {e}", exc_info=True)
  237. return None