run_script_single.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
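"""
Script-understanding test script.

What it does:
1. Reads 视频详情.json and the newest output/result_*.json from the given directory
2. Extracts the topic description and the post content (title, body, video URL)
3. Runs the two script-understanding agents
   (step 1: ScriptSectionDivisionAgent, step 2: ScriptElementExtractionAgent)
4. Saves the result to output/script_result_<YYYYMMDD>_<HHMMSS>.json
"""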
  9. import json
  10. import sys
  11. import os
  12. import argparse
  13. import time
  14. from pathlib import Path
  15. from datetime import datetime
# Put the project root (the parent of this script's directory) at the front of
# sys.path so the `src.` imports further down can be resolved.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
  19. # 手动加载.env文件
  20. def load_env_file(env_path):
  21. """手动加载.env文件"""
  22. if not env_path.exists():
  23. return False
  24. with open(env_path, 'r') as f:
  25. for line in f:
  26. line = line.strip()
  27. # 跳过注释和空行
  28. if not line or line.startswith('#'):
  29. continue
  30. # 解析KEY=VALUE
  31. if '=' in line:
  32. key, value = line.split('=', 1)
  33. os.environ[key.strip()] = value.strip()
  34. return True
  35. env_path = project_root / ".env"
  36. if load_env_file(env_path):
  37. print(f"✅ 已加载环境变量从: {env_path}")
  38. # 验证API密钥
  39. api_key = os.environ.get("GEMINI_API_KEY", "")
  40. if api_key:
  41. print(f" GEMINI_API_KEY: {api_key[:10]}...")
  42. else:
  43. print(f"⚠️ 未找到.env文件: {env_path}")
  44. from src.components.agents.script_section_division_agent import ScriptSectionDivisionAgent
  45. from src.components.agents.script_element_extraction_agent import ScriptElementExtractionAgent
  46. from src.utils.logger import get_logger
  47. from src.utils.llm_invoker import LLMInvoker
  48. import requests
  49. import tempfile
  50. import os
  51. from urllib3.exceptions import IncompleteRead
  52. logger = get_logger(__name__)
  53. def find_latest_result_file(directory):
  54. """
  55. 查找指定目录中最新的 result_XXX.json 文件
  56. Args:
  57. directory: 帖子目录名(如"阿里多多酱"或"G88818")
  58. Returns:
  59. Path: 最新result文件的路径,如果找不到则返回None
  60. """
  61. output_dir = Path(__file__).parent / directory / "output"
  62. if not output_dir.exists():
  63. print(f"⚠️ 输出目录不存在: {output_dir}")
  64. return None
  65. # 查找所有result_*.json文件
  66. result_files = list(output_dir.glob("result_*.json"))
  67. if not result_files:
  68. print(f"⚠️ 未找到result_*.json文件")
  69. return None
  70. # 按修改时间排序,取最新的
  71. latest_file = max(result_files, key=lambda p: p.stat().st_mtime)
  72. return latest_file
  73. def find_post_file(directory):
  74. """
  75. 查找指定目录中的视频详情.json文件
  76. Args:
  77. directory: 视频目录名(如"56898272")
  78. Returns:
  79. Path: 视频详情文件的路径,如果找不到则返回None
  80. """
  81. post_file = Path(__file__).parent / directory / "视频详情.json"
  82. if not post_file.exists():
  83. print(f"⚠️ 视频详情文件不存在: {post_file}")
  84. return None
  85. return post_file
  86. def load_result_file(file_path):
  87. """
  88. 加载result文件
  89. Args:
  90. file_path: result文件路径
  91. Returns:
  92. dict: 解析后的JSON数据
  93. """
  94. with open(file_path, 'r', encoding='utf-8') as f:
  95. data = json.load(f)
  96. return data
  97. def extract_topic_description(result_data):
  98. """
  99. 从result数据中提取选题描述
  100. Args:
  101. result_data: result.json的数据
  102. Returns:
  103. dict: 选题描述字典
  104. """
  105. topic_understanding = result_data.get("选题理解", {})
  106. # 返回结构化的选题描述
  107. return {
  108. "主题": topic_understanding.get("主题", ""),
  109. "描述": topic_understanding.get("描述", "")
  110. }
  111. def infer_content_category(result_data, post_data):
  112. """
  113. 从result数据和帖子数据中推断内容品类
  114. Args:
  115. result_data: result.json的数据
  116. post_data: 待解构帖子.json的数据
  117. Returns:
  118. str: 内容品类
  119. """
  120. # 尝试从选题理解中推断
  121. topic_understanding = result_data.get("选题理解", {})
  122. theme = topic_understanding.get("主题", "")
  123. description = topic_understanding.get("描述", "")
  124. # 基于关键词推断品类
  125. content = f"{theme} {description} {post_data.get('title', '')} {post_data.get('body_text', '')}"
  126. content_lower = content.lower()
  127. # 常见品类关键词映射
  128. category_keywords = {
  129. "美妆教程": ["化妆", "眼妆", "底妆", "口红", "粉底"],
  130. "美甲分享": ["美甲", "指甲", "甲油", "美甲设计"],
  131. "美食教程": ["食谱", "做菜", "烹饪", "美食", "制作"],
  132. "穿搭分享": ["穿搭", "搭配", "outfit", "服装", "衣服"],
  133. "旅行vlog": ["旅行", "旅游", "打卡", "游玩", "景点"],
  134. "健身教程": ["健身", "运动", "锻炼", "瑜伽", "训练"],
  135. "手工DIY": ["手工", "diy", "制作", "手作"],
  136. "护肤分享": ["护肤", "面膜", "精华", "皮肤"],
  137. "摄影分享": ["摄影", "拍照", "相机", "照片"],
  138. }
  139. # 匹配品类
  140. for category, keywords in category_keywords.items():
  141. for keyword in keywords:
  142. if keyword in content_lower or keyword in content:
  143. return category
  144. # 如果没有匹配到,使用通用描述
  145. return "创意分享"
  146. def extract_post_content(post_data):
  147. """
  148. 从视频详情数据中提取视频内容,并移除所有话题标签
  149. Args:
  150. post_data: 视频详情.json的数据
  151. Returns:
  152. tuple: (text_data, video_url)
  153. """
  154. import re
  155. # 提取原始数据
  156. title = post_data.get("title", "")
  157. body = post_data.get("body_text", "")
  158. # 移除body中的所有话题标签(格式:#xxx[话题]# 或 #xxx#)
  159. # 匹配模式:# 开头,后面是任意字符,可能包含[话题],以 # 结尾
  160. body_cleaned = re.sub(r'#[^#]+?(?:\[话题\])?\s*#', '', body)
  161. # 清理多余的空白字符
  162. body_cleaned = re.sub(r'\s+', ' ', body_cleaned).strip()
  163. text_data = {
  164. "title": title,
  165. "body": body_cleaned
  166. }
  167. video_url = post_data.get("video", "")
  168. return text_data, video_url
  169. def download_and_upload_video(video_url: str, directory: str):
  170. """
  171. 下载视频并上传到Gemini
  172. Args:
  173. video_url: 视频URL
  174. directory: 目录名(用于查找本地文件)
  175. Returns:
  176. Gemini文件对象,失败返回 None
  177. """
  178. if not video_url:
  179. print("⚠️ 未提供视频URL,跳过上传")
  180. return None
  181. try:
  182. # 1. 首先检查examples目录下是否有对应的mp4文件
  183. examples_dir = Path(__file__).parent
  184. local_video_path = examples_dir / directory / f"{directory}.mp4"
  185. if local_video_path.exists() and local_video_path.is_file():
  186. print(f"✅ 在examples目录下找到现有文件: {local_video_path.name}")
  187. video_file_path = str(local_video_path)
  188. is_temp_file = False
  189. else:
  190. # 2. 如果没有找到,则下载到临时文件
  191. print(f"📥 开始下载视频: {video_url}")
  192. # 创建临时文件
  193. temp_file = tempfile.NamedTemporaryFile(
  194. suffix=".mp4",
  195. delete=False
  196. )
  197. temp_file_path = temp_file.name
  198. temp_file.close()
  199. # 下载视频(带重试机制)
  200. max_retries = 3
  201. retry_count = 0
  202. last_exception = None
  203. video_file_path = None
  204. is_temp_file = True
  205. while retry_count < max_retries:
  206. try:
  207. if retry_count > 0:
  208. print(f"🔄 重试下载视频 (第 {retry_count}/{max_retries-1} 次)...")
  209. # 使用 Session 进行下载
  210. session = requests.Session()
  211. session.headers.update({
  212. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
  213. })
  214. # 下载视频(增加超时时间)
  215. response = session.get(
  216. video_url,
  217. timeout=(30, 120), # (连接超时, 读取超时)
  218. stream=True
  219. )
  220. response.raise_for_status()
  221. # 确保目录存在
  222. os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
  223. # 写入文件
  224. with open(temp_file_path, "wb") as f:
  225. for chunk in response.iter_content(chunk_size=8192):
  226. if chunk:
  227. f.write(chunk)
  228. # 验证文件大小
  229. file_size = os.path.getsize(temp_file_path)
  230. if file_size == 0:
  231. raise ValueError("下载的文件大小为0")
  232. print(f"✅ 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB")
  233. video_file_path = temp_file_path
  234. break # 下载成功,退出重试循环
  235. except (requests.exceptions.ChunkedEncodingError,
  236. requests.exceptions.ConnectionError,
  237. requests.exceptions.Timeout,
  238. requests.exceptions.RequestException,
  239. ConnectionError,
  240. IncompleteRead) as e:
  241. last_exception = e
  242. retry_count += 1
  243. # 清理不完整的文件
  244. if os.path.exists(temp_file_path):
  245. try:
  246. os.remove(temp_file_path)
  247. except:
  248. pass
  249. if retry_count < max_retries:
  250. wait_time = retry_count * 2 # 递增等待时间:2秒、4秒
  251. print(f"⚠️ 下载失败 (尝试 {retry_count}/{max_retries}): {e}")
  252. print(f" 等待 {wait_time} 秒后重试...")
  253. time.sleep(wait_time)
  254. else:
  255. print(f"❌ 下载失败,已重试 {max_retries} 次")
  256. raise
  257. except Exception as e:
  258. # 其他类型的异常直接抛出,不重试
  259. if os.path.exists(temp_file_path):
  260. try:
  261. os.remove(temp_file_path)
  262. except:
  263. pass
  264. raise
  265. # 如果所有重试都失败了
  266. if not video_file_path:
  267. if last_exception:
  268. raise last_exception
  269. else:
  270. raise Exception("视频下载失败")
  271. # 3. 上传视频到Gemini
  272. print(f"📤 上传视频到Gemini...")
  273. video_file = LLMInvoker.upload_video_to_gemini(video_file_path)
  274. # 4. 清理临时文件
  275. if is_temp_file:
  276. try:
  277. os.remove(video_file_path)
  278. print(f"✅ 临时文件已删除")
  279. except Exception as e:
  280. print(f"⚠️ 删除临时文件失败: {e}")
  281. if not video_file:
  282. print(f"❌ 视频上传到Gemini失败")
  283. return None
  284. # 5. 获取文件信息(用于日志)
  285. file_name = None
  286. if hasattr(video_file, 'name'):
  287. file_name = video_file.name
  288. elif hasattr(video_file, 'uri'):
  289. # 从URI中提取文件名
  290. file_uri = video_file.uri
  291. if "/files/" in file_uri:
  292. file_name = file_uri.split("/files/")[-1]
  293. print(f"✅ 视频上传成功")
  294. if file_name:
  295. print(f" 文件名称: {file_name}")
  296. # 直接返回文件对象
  297. return video_file
  298. except Exception as e:
  299. print(f"❌ 视频下载/上传失败: {e}")
  300. import traceback
  301. traceback.print_exc()
  302. return None
def main():
    """Run the script-understanding pipeline for one video directory.

    The directory name comes from the command line. The steps below run in
    order; each is wrapped in its own try/except that prints the error and
    returns early, except the video upload (step 7), whose failure only
    prints a warning. Nothing is returned; the result is written to
    ``<directory>/output/script_result_<timestamp>.json`` next to this script
    and a summary is printed.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='运行脚本理解Agent(视频分析版本)')
    parser.add_argument('directory', type=str, help='视频目录名(如"56898272"),目录下需要有"视频详情.json"文件')
    args = parser.parse_args()
    directory = args.directory
    print("=" * 80)
    print(f"开始运行脚本理解Agent - 目录: {directory}")
    print("=" * 80)

    # 1. Locate the video detail file
    print("\n[1] 查找视频详情文件...")
    try:
        post_file = find_post_file(directory)
        if not post_file:
            print(f"❌ 未找到视频详情文件")
            return
        print(f"✅ 找到视频详情文件: {post_file.name}")
        print(f" 文件路径: {post_file}")
    except Exception as e:
        print(f"❌ 查找视频详情文件失败: {e}")
        return

    # 2. Load the video detail file (load_result_file is a generic JSON loader)
    print("\n[2] 加载视频详情文件...")
    try:
        post_data = load_result_file(post_file)
        print(f"✅ 成功加载视频详情文件")
    except Exception as e:
        print(f"❌ 加载视频详情文件失败: {e}")
        return

    # 3. Extract title, cleaned body and video URL
    print("\n[3] 提取视频内容...")
    try:
        text_data, video_url = extract_post_content(post_data)
        print(f"✅ 成功提取视频内容")
        print(f" 标题: {text_data.get('title', '无')}")
        print(f" 正文长度: {len(text_data.get('body', ''))}")
        print(f" 视频URL: {'有' if video_url else '无'}")
    except Exception as e:
        print(f"❌ 提取视频内容失败: {e}")
        return

    # 4. Locate the newest result file
    print("\n[4] 查找最新的result文件...")
    try:
        result_file = find_latest_result_file(directory)
        if not result_file:
            print(f"❌ 未找到result文件")
            return
        print(f"✅ 找到最新result文件: {result_file.name}")
        print(f" 文件路径: {result_file}")
        print(f" 修改时间: {datetime.fromtimestamp(result_file.stat().st_mtime)}")
    except Exception as e:
        print(f"❌ 查找result文件失败: {e}")
        return

    # 5. Load the result file
    print("\n[5] 加载result文件...")
    try:
        result_data = load_result_file(result_file)
        print(f"✅ 成功加载result文件")
    except Exception as e:
        print(f"❌ 加载result文件失败: {e}")
        return

    # 6. Extract the topic description
    print("\n[6] 提取选题描述...")
    try:
        topic_description = extract_topic_description(result_data)
        print(f"✅ 成功提取选题描述")
        print(f" 选题描述:")
        if topic_description.get("主题"):
            print(f" 主题: {topic_description['主题']}")
        if topic_description.get("描述"):
            print(f" 描述: {topic_description['描述']}")
    except Exception as e:
        print(f"❌ 提取选题描述失败: {e}")
        return

    # 7. Download the video and upload it to Gemini.
    # A failure here is not fatal: the run continues without a video file.
    print("\n[7] 下载并上传视频到Gemini...")
    video_file = None
    if video_url:
        try:
            video_file = download_and_upload_video(video_url, directory)
            if not video_file:
                print(f"⚠️ 视频上传失败,但继续执行(可能影响视频分析功能)")
        except Exception as e:
            print(f"⚠️ 视频上传失败: {e},但继续执行(可能影响视频分析功能)")
            import traceback
            traceback.print_exc()
    else:
        print(f"⚠️ 未提供视频URL,跳过上传")

    # 8. Create both agents
    print("\n[8] 初始化ScriptSectionDivisionAgent和ScriptElementExtractionAgent...")
    try:
        section_agent = ScriptSectionDivisionAgent(
            model_provider="google_genai"
        )
        element_agent = ScriptElementExtractionAgent(
            model_provider="google_genai"
        )
        print(f"✅ Agent初始化成功")
    except Exception as e:
        print(f"❌ Agent初始化失败: {e}")
        import traceback
        traceback.print_exc()
        return

    # 9. Assemble the state object passed to the agents
    print("\n[9] 组装state对象...")
    try:
        # Topic understanding in the shape the workflow uses
        # (per the original author's note; the workflow is not visible here)
        topic_understanding = result_data.get("选题理解", {})
        state = {
            "text": text_data,
            "video": video_url,
            "topic_selection_understanding": topic_understanding
        }
        # Attach the uploaded file object only when the upload succeeded
        if video_file:
            state["video_file"] = video_file
        print(f"✅ State对象组装成功")
        print(f" - 文本: {bool(text_data)}")
        print(f" - 视频URL: {'有' if video_url else '无'}")
        print(f" - 视频文件对象: {'有' if video_file else '无'}")
        print(f" - 选题理解: {bool(topic_understanding)}")
    except Exception as e:
        print(f"❌ 组装state对象失败: {e}")
        return

    # 10. Run the section-division agent
    print("\n[10] 执行脚本段落划分Agent...")
    try:
        section_result = section_agent.process(state)
        sections = section_result.get("段落列表", [])
        content_category = section_result.get("内容品类", "未知品类")
        print(f"✅ 段落划分执行成功")
        print(f" 内容品类: {content_category}")
        print(f" 划分出 {len(sections)} 个Section")
    except Exception as e:
        print(f"❌ 段落划分执行失败: {e}")
        import traceback
        traceback.print_exc()
        return

    # 11. Run the element-extraction agent
    print("\n[11] 执行脚本元素提取Agent...")
    try:
        # Extend the state with the section result and the other inputs set below
        state["section_division"] = {"段落列表": sections}
        # Inspiration / purpose / key points come from "三点解构" in the result data
        three_points = result_data.get("三点解构", {})
        state["inspiration_points"] = three_points.get("灵感点", {})
        state["purpose_points"] = three_points.get("目的点", {})
        state["key_points"] = three_points.get("关键点", {})
        element_result = element_agent.process(state)
        elements = element_result.get("元素列表", [])
        tendency_judgment = element_result.get("视频倾向判断", {})
        print(f"✅ 元素提取执行成功")
        print(f" 识别出 {len(elements)} 个元素")
        if tendency_judgment:
            print(f" 视频倾向: {tendency_judgment.get('判断结果', '未知')}")
    except Exception as e:
        print(f"❌ 元素提取执行失败: {e}")
        import traceback
        traceback.print_exc()
        return

    # 12. Assemble the final result
    print("\n[12] 组装最终结果...")
    try:
        # Recursive count of every node (parents and their "子项" children)
        def count_items(items_list):
            count = len(items_list)
            for item in items_list:
                if item.get('子项'):
                    count += count_items(item['子项'])
            return count

        total_sections = count_items(sections)
        total_elements = count_items(elements)

        # Script-understanding payload
        script_understanding = {
            "内容品类": content_category,
            "段落列表": sections,
            "元素列表": elements,
            "视频URL": video_url,
            "视频倾向判断": tendency_judgment  # include the video tendency judgment
        }
        final_result = {
            "选题描述": topic_description,
            "脚本理解": script_understanding,
            "元信息": {
                "段落总数": total_sections,
                "元素总数": total_elements,
                "来源帖子文件": post_file.name,
                "来源结果文件": result_file.name,
                "执行时间": datetime.now().isoformat()
            }
        }
        print(f"✅ 结果组装成功")
    except Exception as e:
        print(f"❌ 结果组装失败: {e}")
        return

    # 13. Save the result
    print("\n[13] 保存结果...")
    try:
        # File name carries a timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"script_result_{timestamp}.json"
        output_path = Path(__file__).parent / directory / "output" / output_filename
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(final_result, f, ensure_ascii=False, indent=2)
        print(f"✅ 结果已保存到: {output_path}")
        print(f" 文件名: {output_filename}")
    except Exception as e:
        print(f"❌ 保存结果失败: {e}")
        return

    # 14. Print a summary of the result
    print("\n" + "=" * 80)
    print("结果摘要")
    print("=" * 80)
    print(f"\n选题描述:")
    if topic_description.get("主题"):
        print(f" 主题: {topic_description['主题']}")
    if topic_description.get("描述"):
        print(f" 描述: {topic_description['描述']}")

    # Print the section tree recursively, one indent step per nesting level
    def print_sections(sections_list, indent=0):
        for idx, section in enumerate(sections_list, 1):
            prefix = " " + " " * indent
            print(f"{prefix}{idx}. {section.get('描述', 'N/A')}")
            if section.get('子项'):
                print_sections(section['子项'], indent + 1)

    # Count leaf sections only; note this differs from count_items above,
    # which also counts parent nodes (that total went into "元信息").
    def count_sections(sections_list):
        count = 0
        for section in sections_list:
            if section.get('子项'):
                # Has children: count them recursively instead of this node
                count += count_sections(section['子项'])
            else:
                # No children: this is a leaf
                count += 1
        return count

    total_sections = count_sections(sections)
    print(f"\nSection列表 ({total_sections} 个):")
    print_sections(sections)

    # Print the element list flat: type, name and classification path only
    def print_elements(elements_list):
        for element in elements_list:
            name = element.get('名称', 'N/A')
            elem_type = element.get('类型', 'N/A')
            classification = element.get('分类', {})
            # Join the non-empty classification values into a path
            if classification:
                class_path = " > ".join([v for v in classification.values() if v])
                print(f" - [{elem_type}] {name} ({class_path})")
            else:
                print(f" - [{elem_type}] {name}")

    # No recursive count here: the summary uses the top-level list length
    total_elements = len(elements)
    print(f"\n元素列表 ({total_elements} 个):")
    if elements:
        print_elements(elements)
    else:
        print(" (无)")
    print("\n" + "=" * 80)
    print("测试完成!")
    print("=" * 80)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()