run_decode_script.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Batch-run the decode workflow (DecodeWorkflow).

Reads the video list from examples/demo.json, processes each entry with
DecodeWorkflow, and writes the results to examples/output_decode_result.json.
"""
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List

# Add the project root to sys.path so the src package can be imported
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.workflows.decode_workflow import DecodeWorkflow
from src.utils.logger import get_logger

logger = get_logger(__name__)


def load_json(path: Path) -> List[Dict[str, Any]]:
    """Load a JSON file and return the list of video entries.

    Accepts either a plain list or a dict with a "results" key; anything
    else yields an empty list.
    """
    if not path.exists():
        return []
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    # A dict with a "results" field: extract the results list
    if isinstance(data, dict) and "results" in data:
        return data["results"]
    # Already a list: return as-is
    elif isinstance(data, list):
        return data
    else:
        return []


def save_json(path: Path, data: Dict[str, Any]) -> None:
    """Save a JSON file, writing to a temp file first so the update is atomic."""
    tmp_path = path.with_suffix(".tmp")
    with tmp_path.open("w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    # Path.replace is a rename-based swap, atomic on the same filesystem
    tmp_path.replace(path)


def build_decode_input(video_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build the input structure for DecodeWorkflow from a video entry."""
    return {
        "video_url": video_data.get("video_url", ""),
        # Fall back to the legacy field name, matching the handling in main()
        "video_id": video_data.get("video_id", "") or video_data.get("channel_content_id", ""),
        "title": video_data.get("title", ""),
    }
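

# The result handling in main() below assumes DecodeWorkflow.invoke() returns a
# dict roughly shaped as sketched here. Only the keys this script actually reads
# are listed; the shape is inferred from those reads, not a documented contract
# of DecodeWorkflow:
#
#   {
#     "视频信息": {...},          # video info
#     "三点解构": {...},          # three-point deconstruction
#     "选题理解": {...},          # topic understanding (may contain "选题", or "主题"/"描述")
#     "脚本理解": {...},          # script understanding
#     "error" / "错误" / "workflow_error": "...",   # optional error message
#     "workflow_status": "failed" / "incomplete" / other
#   }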


def main() -> None:
    """Entry point: batch-process the video list, saving results incrementally."""
    base_dir = Path(__file__).parent
    input_path = base_dir / "demo.json"
    output_path = base_dir / "output_decode_result.json"
    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    # Load the video list
    video_list = load_json(input_path)
    if not video_list:
        logger.warning(f"No video data found in input file {input_path}")
        return
    logger.info(f"Loaded {len(video_list)} videos")

    # Load any existing output so new results can be appended incrementally
    output_data = {}
    if output_path.exists():
        try:
            with output_path.open("r", encoding="utf-8") as f:
                output_data = json.load(f)
        except Exception as e:
            logger.warning(f"Failed to read existing output file, creating a new one: {e}")
            output_data = {}
    if not output_data:
        output_data = {
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "total": 0,
            "success_count": 0,
            "fail_count": 0,
            "results": [],
        }
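
    # Each processed video is appended to output_data["results"] as a record of
    # this shape (a sketch with placeholder values; the Chinese keys mirror the
    # workflow output handled below):
    #   {
    #     "video_data": {...original entry...},
    #     "what_deconstruction_result": {"视频信息": {...}, "三点解构": {...}, "选题理解": {...}},
    #     "script_result": {"选题描述": {...}, "脚本理解": {...}},
    #     "success": True,
    #     "error": None      # or an error message string on failure
    #   }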
    existing_results: List[Dict[str, Any]] = output_data.get("results", []) or []
    # Deduplicate on video_id + video URL to avoid reprocessing
    # (video_id falls back to the legacy field name channel_content_id)
    processed_keys = {
        f"{item.get('video_data', {}).get('video_id', '') or item.get('video_data', {}).get('channel_content_id', '')}|"
        f"{item.get('video_data', {}).get('video_url', '')}"
        for item in existing_results
    }

    # Initialize the workflow
    logger.info("Initializing DecodeWorkflow...")
    workflow = DecodeWorkflow()
    logger.info("DecodeWorkflow initialized")

    # Process each video
    for idx, video_data in enumerate(video_list, 1):
        video_url = video_data.get("video_url", "")
        # Accept the legacy field name channel_content_id
        video_id = video_data.get("video_id", "") or video_data.get("channel_content_id", "")
        title = video_data.get("title", "")

        # Build a unique key for deduplication
        key = f"{video_id}|{video_url}"
        if key in processed_keys:
            logger.info(f"[{idx}/{len(video_list)}] Already processed, skipping: video_id={video_id}")
            continue

        logger.info(
            f"[{idx}/{len(video_list)}] Processing video: "
            f"video_id={video_id}, title={title[:50]}..."
        )
        try:
            # Build the workflow input
            decode_input = build_decode_input(video_data)
            # Invoke the workflow
            decode_result = workflow.invoke(decode_input)

            # Check whether the workflow result reports an error
            if isinstance(decode_result, dict):
                # Several error field names are supported
                error_msg = (
                    decode_result.get("error") or
                    decode_result.get("错误") or
                    decode_result.get("workflow_error")
                )
                workflow_status = decode_result.get("workflow_status")
                # Any reported error, or a failed/incomplete status, counts as a failure
                if error_msg or workflow_status in ("failed", "incomplete"):
                    error_msg = error_msg or "Workflow execution failed"
                    logger.error(
                        f"[{idx}/{len(video_list)}] Processing failed: video_id={video_id}, error={error_msg}"
                    )
                    record = {
                        "video_data": video_data,
                        "what_deconstruction_result": None,
                        "script_result": None,
                        "success": False,
                        "error": error_msg,
                    }
                    output_data["fail_count"] = output_data.get("fail_count", 0) + 1
                    output_data["results"].append(record)
                    output_data["total"] = output_data.get("total", 0) + 1
                    save_json(output_path, output_data)
                    continue

                # If every key result section is empty, treat the run as a failure
                video_info = decode_result.get("视频信息", {})
                three_points = decode_result.get("三点解构", {})
                topic_understanding = decode_result.get("选题理解", {})
                script_understanding = decode_result.get("脚本理解", {})
                if (not video_info and not three_points and
                        not topic_understanding and not script_understanding):
                    error_msg = "Workflow completed, but all results are empty"
                    logger.warning(
                        f"[{idx}/{len(video_list)}] Empty result: video_id={video_id}"
                    )
                    # Recorded as a failure for now; adjust if business rules require a warning instead
                    record = {
                        "video_data": video_data,
                        "what_deconstruction_result": None,
                        "script_result": None,
                        "success": False,
                        "error": error_msg,
                    }
                    output_data["fail_count"] = output_data.get("fail_count", 0) + 1
                    output_data["results"].append(record)
                    output_data["total"] = output_data.get("total", 0) + 1
                    save_json(output_path, output_data)
                    continue

            # Organize the results following the output_demo_script.json format.
            # what_deconstruction_result: video info, three-point deconstruction, topic understanding
            what_deconstruction_result = {
                "视频信息": decode_result.get("视频信息", {}),
                "三点解构": decode_result.get("三点解构", {}),
                "选题理解": decode_result.get("选题理解", {}),
            }

            # script_result: topic description plus script understanding.
            # The topic description is extracted from the topic-understanding section.
            topic_understanding = decode_result.get("选题理解", {})
            selected_topic = {}
            if isinstance(topic_understanding, dict):
                if "选题" in topic_understanding:
                    selected_topic = topic_understanding.get("选题", {})
                else:
                    selected_topic = {
                        "主题": topic_understanding.get("主题", ""),
                        "描述": topic_understanding.get("描述", ""),
                    }
            script_result = {
                "选题描述": selected_topic,
                "脚本理解": decode_result.get("脚本理解", {}),
            }

            # Build the result record (mirrors the output_demo_script.json format)
            record = {
                "video_data": video_data,
                "what_deconstruction_result": what_deconstruction_result,
                "script_result": script_result,
                "success": True,
                "error": None,
            }
            output_data["success_count"] = output_data.get("success_count", 0) + 1
            logger.info(
                f"[{idx}/{len(video_list)}] Processed successfully: video_id={video_id}"
            )
        except Exception as e:
            logger.error(
                f"[{idx}/{len(video_list)}] Processing failed: video_id={video_id}, error={e}",
                exc_info=True
            )
            record = {
                "video_data": video_data,
                "what_deconstruction_result": None,
                "script_result": None,
                "success": False,
                "error": str(e),
            }
            output_data["fail_count"] = output_data.get("fail_count", 0) + 1

        # Append the record (success or exception path) and save after every item,
        # so a long run that dies midway does not lose all of its progress
        output_data["results"].append(record)
        output_data["total"] = output_data.get("total", 0) + 1
        save_json(output_path, output_data)
        logger.info(f"Results saved to {output_path}")

    logger.info(
        f"\n{'='*60}\n"
        f"Batch decode finished:\n"
        f"  Total:   {output_data.get('total', 0)}\n"
        f"  Success: {output_data.get('success_count', 0)}\n"
        f"  Failed:  {output_data.get('fail_count', 0)}\n"
        f"  Output file: {output_path}\n"
        f"{'='*60}"
    )


if __name__ == "__main__":
    main()