run_decode_script.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Batch-run the decode workflow (DecodeWorkflow).

Reads the video list from examples/demo.json, runs each entry through
DecodeWorkflow, and writes the results to
examples/output_decode_result.json.
"""
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List

# Make the project root importable so the `src.*` packages below resolve
# when this script is executed directly (it lives one level under the root).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.workflows.decode_workflow import DecodeWorkflow
from src.utils.logger import get_logger

logger = get_logger(__name__)
  20. def load_json(path: Path) -> List[Dict[str, Any]]:
  21. """加载JSON文件"""
  22. if not path.exists():
  23. return []
  24. with path.open("r", encoding="utf-8") as f:
  25. data = json.load(f)
  26. # 如果是字典且有 results 字段,提取 results
  27. if isinstance(data, dict) and "results" in data:
  28. return data["results"]
  29. # 如果是列表,直接返回
  30. elif isinstance(data, list):
  31. return data
  32. else:
  33. return []
  34. def save_json(path: Path, data: Dict[str, Any]) -> None:
  35. """保存JSON文件(使用临时文件确保原子性)"""
  36. tmp_path = path.with_suffix(".tmp")
  37. with tmp_path.open("w", encoding="utf-8") as f:
  38. json.dump(data, f, ensure_ascii=False, indent=2)
  39. tmp_path.replace(path)
  40. def build_decode_input(video_data: Dict[str, Any]) -> Dict[str, Any]:
  41. """根据视频数据构造 DecodeWorkflow 的输入结构"""
  42. return {
  43. "video": video_data.get("video", ""),
  44. "channel_content_id": video_data.get("channel_content_id", ""),
  45. "title": video_data.get("title", ""),
  46. "body_text": video_data.get("body_text", ""),
  47. }
  48. def main() -> None:
  49. """主函数"""
  50. base_dir = Path(__file__).parent
  51. input_path = base_dir / "demo.json"
  52. output_path = base_dir / "output_decode_result.json"
  53. if not input_path.exists():
  54. raise FileNotFoundError(f"找不到输入文件: {input_path}")
  55. # 读取视频列表
  56. video_list = load_json(input_path)
  57. if not video_list:
  58. logger.warning(f"输入文件 {input_path} 中没有视频数据")
  59. return
  60. logger.info(f"共读取到 {len(video_list)} 个视频")
  61. # 读取已有的输出结果,支持增量追加
  62. output_data = {}
  63. if output_path.exists():
  64. try:
  65. with output_path.open("r", encoding="utf-8") as f:
  66. output_data = json.load(f)
  67. except Exception as e:
  68. logger.warning(f"读取已有输出文件失败,将创建新文件: {e}")
  69. output_data = {}
  70. if not output_data:
  71. output_data = {
  72. "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
  73. "total": 0,
  74. "success_count": 0,
  75. "fail_count": 0,
  76. "results": [],
  77. }
  78. existing_results: List[Dict[str, Any]] = output_data.get("results", []) or []
  79. # 用 channel_content_id + video URL 去重,避免重复处理
  80. processed_keys = {
  81. f"{item.get('video_data', {}).get('channel_content_id', '')}|"
  82. f"{item.get('video_data', {}).get('video', '')}"
  83. for item in existing_results
  84. }
  85. # 初始化工作流
  86. logger.info("初始化 DecodeWorkflow...")
  87. workflow = DecodeWorkflow()
  88. logger.info("DecodeWorkflow 初始化完成")
  89. # 处理每个视频
  90. for idx, video_data in enumerate(video_list, 1):
  91. video_url = video_data.get("video", "")
  92. channel_content_id = video_data.get("channel_content_id", "")
  93. title = video_data.get("title", "")
  94. # 生成唯一键用于去重
  95. key = f"{channel_content_id}|{video_url}"
  96. if key in processed_keys:
  97. logger.info(f"[{idx}/{len(video_list)}] 已处理过该视频,跳过: channel_content_id={channel_content_id}")
  98. continue
  99. logger.info(
  100. f"[{idx}/{len(video_list)}] 开始处理视频: "
  101. f"channel_content_id={channel_content_id}, title={title[:50]}..."
  102. )
  103. try:
  104. # 构建输入数据
  105. decode_input = build_decode_input(video_data)
  106. # 调用工作流
  107. decode_result = workflow.invoke(decode_input)
  108. # 检查workflow返回结果中是否包含错误
  109. if isinstance(decode_result, dict):
  110. # 检查是否有错误字段(支持多种错误字段名)
  111. error_msg = (
  112. decode_result.get("error") or
  113. decode_result.get("错误") or
  114. decode_result.get("workflow_error")
  115. )
  116. workflow_status = decode_result.get("workflow_status")
  117. # 如果返回了错误信息,视为失败
  118. if error_msg or workflow_status == "failed" or workflow_status == "incomplete":
  119. error_msg = error_msg or "工作流执行失败"
  120. logger.error(
  121. f"[{idx}/{len(video_list)}] 处理失败: channel_content_id={channel_content_id}, error={error_msg}"
  122. )
  123. record = {
  124. "video_data": video_data,
  125. "what_deconstruction_result": None,
  126. "script_result": None,
  127. "success": False,
  128. "error": error_msg,
  129. }
  130. output_data["fail_count"] = output_data.get("fail_count", 0) + 1
  131. output_data["results"].append(record)
  132. output_data["total"] = output_data.get("total", 0) + 1
  133. save_json(output_path, output_data)
  134. continue
  135. # 检查结果是否为空(可能表示失败)
  136. # 如果所有关键字段都为空,可能表示处理失败
  137. video_info = decode_result.get("视频信息", {})
  138. three_points = decode_result.get("三点解构", {})
  139. topic_understanding = decode_result.get("选题理解", {})
  140. script_understanding = decode_result.get("脚本理解", {})
  141. # 如果所有关键结果都为空,且没有明确的成功标志,视为失败
  142. if (not video_info and not three_points and
  143. not topic_understanding and not script_understanding):
  144. error_msg = "工作流执行完成,但所有结果都为空"
  145. logger.warning(
  146. f"[{idx}/{len(video_list)}] 处理结果为空: channel_content_id={channel_content_id}"
  147. )
  148. # 这里可以选择记录为失败或警告,根据业务需求决定
  149. # 暂时记录为失败
  150. record = {
  151. "video_data": video_data,
  152. "what_deconstruction_result": None,
  153. "script_result": None,
  154. "success": False,
  155. "error": error_msg,
  156. }
  157. output_data["fail_count"] = output_data.get("fail_count", 0) + 1
  158. output_data["results"].append(record)
  159. output_data["total"] = output_data.get("total", 0) + 1
  160. save_json(output_path, output_data)
  161. continue
  162. # 按照 output_demo_script.json 的格式组织结果
  163. # what_deconstruction_result: 包含视频信息、三点解构、选题理解
  164. what_deconstruction_result = {
  165. "视频信息": decode_result.get("视频信息", {}),
  166. "三点解构": decode_result.get("三点解构", {}),
  167. "选题理解": decode_result.get("选题理解", {}),
  168. }
  169. # script_result: 包含选题描述和脚本理解
  170. # 从选题理解中提取选题描述
  171. topic_understanding = decode_result.get("选题理解", {})
  172. selected_topic = {}
  173. if isinstance(topic_understanding, dict):
  174. if "选题" in topic_understanding:
  175. selected_topic = topic_understanding.get("选题", {})
  176. else:
  177. selected_topic = {
  178. "主题": topic_understanding.get("主题", ""),
  179. "描述": topic_understanding.get("描述", ""),
  180. }
  181. script_result = {
  182. "选题描述": selected_topic,
  183. "脚本理解": decode_result.get("脚本理解", {}),
  184. }
  185. # 构造结果记录(参考 output_demo_script.json 格式)
  186. record = {
  187. "video_data": video_data,
  188. "what_deconstruction_result": what_deconstruction_result,
  189. "script_result": script_result,
  190. "success": True,
  191. "error": None,
  192. }
  193. output_data["success_count"] = output_data.get("success_count", 0) + 1
  194. logger.info(
  195. f"[{idx}/{len(video_list)}] 处理成功: channel_content_id={channel_content_id}"
  196. )
  197. except Exception as e:
  198. logger.error(
  199. f"[{idx}/{len(video_list)}] 处理失败: channel_content_id={channel_content_id}, error={e}",
  200. exc_info=True
  201. )
  202. record = {
  203. "video_data": video_data,
  204. "what_deconstruction_result": None,
  205. "script_result": None,
  206. "success": False,
  207. "error": str(e),
  208. }
  209. output_data["fail_count"] = output_data.get("fail_count", 0) + 1
  210. output_data["results"].append(record)
  211. output_data["total"] = output_data.get("total", 0) + 1
  212. # 处理完一条就保存一次,避免长任务中途失败导致全部丢失
  213. save_json(output_path, output_data)
  214. logger.info(f"结果已保存到 {output_path}")
  215. logger.info(
  216. f"\n{'='*60}\n"
  217. f"批量解码完成:\n"
  218. f" 总计: {output_data.get('total', 0)}\n"
  219. f" 成功: {output_data.get('success_count', 0)}\n"
  220. f" 失败: {output_data.get('fail_count', 0)}\n"
  221. f" 输出文件: {output_path}\n"
  222. f"{'='*60}"
  223. )
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()