# run_single.py (13 KB)
  1. """
  2. 测试脚本:运行 待解构帖子.json(带历史帖子)
  3. 功能:
  4. 1. 加载最近3篇历史帖子(从早到晚排序)
  5. 2. 加载待解构帖子
  6. 3. 运行 WhatDeconstructionWorkflow
  7. """
  8. import json
  9. import sys
  10. import os
  11. import argparse
  12. from pathlib import Path
  13. from datetime import datetime
# Add the project root directory to sys.path so `src.*` imports resolve
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
  17. # 手动加载.env文件
  18. def load_env_file(env_path):
  19. """手动加载.env文件"""
  20. if not env_path.exists():
  21. return False
  22. with open(env_path, 'r') as f:
  23. for line in f:
  24. line = line.strip()
  25. # 跳过注释和空行
  26. if not line or line.startswith('#'):
  27. continue
  28. # 解析KEY=VALUE
  29. if '=' in line:
  30. key, value = line.split('=', 1)
  31. os.environ[key.strip()] = value.strip()
  32. return True
  33. env_path = project_root / ".env"
  34. if load_env_file(env_path):
  35. print(f"✅ 已加载环境变量从: {env_path}")
  36. # 验证API密钥
  37. api_key = os.environ.get("GEMINI_API_KEY", "")
  38. if api_key:
  39. print(f" GEMINI_API_KEY: {api_key[:10]}...")
  40. else:
  41. print(f"⚠️ 未找到.env文件: {env_path}")
  42. from src.workflows.what_deconstruction_workflow import WhatDeconstructionWorkflow
  43. from src.utils.logger import get_logger
  44. logger = get_logger(__name__)
  45. def load_historical_posts(history_dir, target_timestamp=None, target_post_id=None, max_count=10):
  46. """
  47. 加载历史帖子(根据publish_timestamp从新到旧排序)
  48. 选择比目标帖子早发布,并且是最近发布的帖子,排除目标帖子本身
  49. Args:
  50. history_dir: 历史帖子目录
  51. target_timestamp: 目标帖子的发布时间戳(可选)
  52. target_post_id: 目标帖子的ID(用于过滤重复,可选)
  53. max_count: 最多加载的帖子数量
  54. Returns:
  55. list: 历史帖子列表(从新到旧排序)
  56. """
  57. history_path = Path(history_dir)
  58. if not history_path.exists():
  59. print(f"⚠️ 历史帖子目录不存在: {history_path}")
  60. return []
  61. # 获取所有JSON文件
  62. json_files = list(history_path.glob("*.json"))
  63. if not json_files:
  64. print(f"⚠️ 未找到历史帖子文件")
  65. return []
  66. print(f"\n📁 找到 {len(json_files)} 个历史帖子文件")
  67. # 读取所有帖子并提取publish_timestamp
  68. posts_with_timestamp = []
  69. for file_path in json_files:
  70. try:
  71. with open(file_path, 'r', encoding='utf-8') as f:
  72. post_data = json.load(f)
  73. # 获取发布时间戳,如果不存在则使用0
  74. timestamp = post_data.get("publish_timestamp", 0)
  75. post_id = post_data.get("channel_content_id", "")
  76. posts_with_timestamp.append({
  77. "file_path": file_path,
  78. "post_data": post_data,
  79. "timestamp": timestamp,
  80. "post_id": post_id
  81. })
  82. except Exception as e:
  83. print(f" ⚠️ 读取文件失败 {file_path.name}: {e}")
  84. continue
  85. if not posts_with_timestamp:
  86. print(f"⚠️ 没有成功读取到任何帖子")
  87. return []
  88. # 过滤掉目标帖子本身
  89. if target_post_id is not None:
  90. original_count = len(posts_with_timestamp)
  91. posts_with_timestamp = [
  92. post for post in posts_with_timestamp
  93. if post["post_id"] != target_post_id
  94. ]
  95. filtered_count = original_count - len(posts_with_timestamp)
  96. if filtered_count > 0:
  97. print(f"🔍 过滤掉 {filtered_count} 个重复帖子(目标帖子本身)")
  98. # 如果提供了目标时间戳,只保留比目标帖子早的帖子
  99. if target_timestamp is not None:
  100. posts_with_timestamp = [
  101. post for post in posts_with_timestamp
  102. if post["timestamp"] < target_timestamp
  103. ]
  104. print(f"📊 筛选出 {len(posts_with_timestamp)} 个比目标帖子早的历史帖子")
  105. if not posts_with_timestamp:
  106. print(f"⚠️ 没有找到比目标帖子早的历史帖子")
  107. return []
  108. # 按照publish_timestamp排序(从新到旧)
  109. posts_with_timestamp.sort(key=lambda x: x["timestamp"], reverse=True)
  110. # 选择最近的N篇(从新到旧)
  111. selected_posts = posts_with_timestamp[:max_count] if len(posts_with_timestamp) > max_count else posts_with_timestamp
  112. print(f"📋 选择最近 {len(selected_posts)} 篇历史帖子(按发布时间从新到旧):")
  113. historical_posts = []
  114. for idx, post_info in enumerate(selected_posts, 1):
  115. post_data = post_info["post_data"]
  116. file_path = post_info["file_path"]
  117. timestamp = post_info["timestamp"]
  118. # 转换为需要的格式
  119. historical_post = {
  120. "text": {
  121. "title": post_data.get("title", ""),
  122. "body": post_data.get("body_text", ""),
  123. "hashtags": ""
  124. },
  125. "images": post_data.get("images", [])
  126. }
  127. historical_posts.append(historical_post)
  128. # 格式化时间显示
  129. publish_time = post_data.get("publish_time", "未知时间")
  130. print(f" {idx}. {file_path.name}")
  131. print(f" 标题: {post_data.get('title', '无标题')}")
  132. print(f" 发布时间: {publish_time}")
  133. print(f" 图片数: {len(post_data.get('images', []))}")
  134. return historical_posts
  135. def load_test_data(directory):
  136. """
  137. 加载测试数据
  138. Args:
  139. directory: 帖子目录名(如"阿里多多酱"或"G88818")
  140. """
  141. test_data_path = Path(__file__).parent / directory / "待解构帖子.json"
  142. with open(test_data_path, "r", encoding="utf-8") as f:
  143. data = json.load(f)
  144. return data
  145. def convert_to_workflow_input(raw_data, historical_posts=None):
  146. """
  147. 将原始数据转换为工作流输入格式
  148. Args:
  149. raw_data: 原始帖子数据
  150. historical_posts: 历史帖子列表(可选)
  151. """
  152. images = raw_data.get("images", [])
  153. input_data = {
  154. "multimedia_content": {
  155. "images": images,
  156. "video": raw_data.get("video", {}),
  157. "text": {
  158. "title": raw_data.get("title", ""),
  159. "body": raw_data.get("body_text", ""),
  160. "hashtags": ""
  161. }
  162. },
  163. "comments": raw_data.get("comments", []),
  164. "creator_info": {
  165. "nickname": raw_data.get("channel_account_name", ""),
  166. "account_id": raw_data.get("channel_account_id", "")
  167. }
  168. }
  169. # 如果有历史帖子,添加到输入数据中
  170. if historical_posts:
  171. input_data["historical_posts"] = historical_posts
  172. return input_data
def main():
    """Entry point: run the What-deconstruction workflow for one post directory.

    Pipeline: parse CLI args -> load target post -> load history -> convert to
    workflow input -> run WhatDeconstructionWorkflow -> save JSON result ->
    print a summary. Each stage prints its own status and aborts on failure.
    """
    # Parse command-line arguments (a single positional: the post directory name)
    parser = argparse.ArgumentParser(description='运行单个帖子的What解构工作流')
    parser.add_argument('directory', type=str, help='帖子目录名(如"阿里多多酱"或"G88818")')
    args = parser.parse_args()
    directory = args.directory
    print("=" * 80)
    print(f"开始测试 What 解构工作流(带历史帖子)- 目录: {directory}")
    print("=" * 80)
    # 1. Load the test data (the target post)
    print("\n[1] 加载测试数据(目标帖子)...")
    try:
        raw_data = load_test_data(directory)
        # Timestamp and id are used below to filter the history listing
        target_timestamp = raw_data.get('publish_timestamp')
        target_post_id = raw_data.get('channel_content_id')
        target_publish_time = raw_data.get('publish_time', '未知时间')
        print(f"✅ 成功加载测试数据")
        print(f" - 标题: {raw_data.get('title')}")
        print(f" - 帖子ID: {target_post_id}")
        print(f" - 发布时间: {target_publish_time}")
        print(f" - 图片数: {len(raw_data.get('images', []))}")
        print(f" - 点赞数: {raw_data.get('like_count')}")
        print(f" - 评论数: {raw_data.get('comment_count')}")
    except Exception as e:
        print(f"❌ 加载测试数据失败: {e}")
        return
    # 2. Load historical posts (published before the target, excluding the target itself)
    print("\n[2] 加载历史帖子...")
    history_dir = Path(__file__).parent / directory / "作者历史帖子"
    historical_posts = load_historical_posts(
        history_dir,
        target_timestamp=target_timestamp,
        target_post_id=target_post_id,
        max_count=15
    )
    if historical_posts:
        print(f"✅ 成功加载 {len(historical_posts)} 篇历史帖子")
    else:
        # Missing history is not fatal — the workflow falls back to regular analysis
        print(f"⚠️ 未加载到历史帖子,将使用常规分析模式")
    # 3. Convert the raw data into the workflow input format
    print("\n[3] 转换数据格式...")
    try:
        input_data = convert_to_workflow_input(raw_data, historical_posts)
        print(f"✅ 数据格式转换成功")
        print(f" - 话题标签: {input_data['multimedia_content']['text']['hashtags']}")
        print(f" - 历史帖子数: {len(input_data.get('historical_posts', []))}")
    except Exception as e:
        print(f"❌ 数据格式转换失败: {e}")
        return
    # 4. Initialize the workflow
    print("\n[4] 初始化工作流...")
    try:
        workflow = WhatDeconstructionWorkflow(
            model_provider="google_genai",
            max_depth=10
        )
        print(f"✅ 工作流初始化成功")
    except Exception as e:
        print(f"❌ 工作流初始化失败: {e}")
        import traceback
        traceback.print_exc()
        return
    # 5. Execute the workflow (this calls out to the LLM and may take minutes)
    print("\n[5] 执行工作流...")
    print(" 注意:这可能需要几分钟时间...")
    try:
        result = workflow.invoke(input_data)
        print(f"✅ 工作流执行成功")
    except Exception as e:
        print(f"❌ 工作流执行失败: {e}")
        import traceback
        traceback.print_exc()
        return
    # 6. Save the result under <directory>/output/
    print("\n[6] 保存结果...")
    try:
        # Timestamped filename so repeated runs never overwrite each other
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"result_{timestamp}.json"
        output_path = Path(__file__).parent / directory / "output" / output_filename
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"✅ 结果已保存到: {output_path}")
        print(f" 文件名: {output_filename}")
    except Exception as e:
        print(f"❌ 保存结果失败: {e}")
        return
    # 7. Generate HTML visualization (currently disabled; kept for reference)
    # print("\n[7] 生成HTML可视化...")
    # try:
    # visualize_script = Path(__file__).parent / "visualize_result.py"
    # if visualize_script.exists():
    # import subprocess
    # result_viz = subprocess.run(
    # [sys.executable, str(visualize_script), str(output_path)],
    # capture_output=True,
    # text=True
    # )
    # if result_viz.returncode == 0:
    # print(f"✅ HTML可视化生成成功")
    # # Look up the generated HTML file
    # html_file = output_path.parent / f"{output_path.stem}_visualization.html"
    # if html_file.exists():
    # print(f" 可视化文件: {html_file}")
    # else:
    # print(f"⚠️ HTML可视化生成失败: {result_viz.stderr}")
    # else:
    # print(f"⚠️ 未找到可视化脚本: {visualize_script}")
    # except Exception as e:
    # print(f"⚠️ 生成HTML可视化失败: {e}")
    # 8. Print a result summary
    print("\n" + "=" * 80)
    print("结果摘要")
    print("=" * 80)
    if result:
        # NOTE(review): result schema (Chinese top-level keys) is produced by
        # WhatDeconstructionWorkflow — not visible here; keys taken on faith.
        three_points = result.get("三点解构", {})
        inspiration_data = three_points.get("灵感点", {})
        keypoints_data = three_points.get("关键点", {})
        comments = result.get("评论分析", {}).get("解构维度", [])
        print(f"\n三点解构:")
        print(f" - 灵感点数量: {inspiration_data.get('total_count', 0)}")
        print(f" - 灵感点分析模式: {inspiration_data.get('analysis_mode', '未知')}")
        print(f" - 目的点数量: 1")
        print(f" - 关键点数量: {keypoints_data.get('total_count', 0)}")
        # Show inspiration-point details when present
        if inspiration_data.get('points'):
            print(f"\n灵感点列表:")
            for idx, point in enumerate(inspiration_data['points'], 1):
                print(f" {idx}. {point.get('灵感点', '')}")
        print(f"\n评论分析:")
        print(f" - 解构维度数: {len(comments)}")
        topic_understanding = result.get("选题理解", {})
        if topic_understanding:
            topic_theme = topic_understanding.get("topic_theme", "")
            print(f"\n选题理解:")
            print(f" - 选题主题: {topic_theme}")
    print("\n" + "=" * 80)
    print("测试完成!")
    print("=" * 80)


if __name__ == "__main__":
    main()