# run_batch.py

  1. """
  2. 批量测试脚本:处理作者历史帖子目录下的所有帖子(并发版本,带历史帖子)
  3. 功能:
  4. 1. 加载最新的20个帖子(按照publish_timestamp从新到旧排序)
  5. 2. 为每个帖子加载历史帖子(比当前帖子早的最近15篇)
  6. 3. 并发处理所有帖子
  7. """
import json
import sys
import os
import argparse
from pathlib import Path
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

# Add the project root to sys.path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Manually load the .env file
def load_env_file(env_path):
    """Manually load a .env file into os.environ."""
    if not env_path.exists():
        return False
    with open(env_path, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip comments and blank lines
            if not line or line.startswith('#'):
                continue
            # Parse KEY=VALUE
            if '=' in line:
                key, value = line.split('=', 1)
                os.environ[key.strip()] = value.strip()
    return True

env_path = project_root / ".env"
if load_env_file(env_path):
    print(f"✅ Loaded environment variables from: {env_path}")
    # Verify the API key
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if api_key:
        print(f"   GEMINI_API_KEY: {api_key[:10]}...")
else:
    print(f"⚠️ .env file not found: {env_path}")
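
# A minimal .env for illustration (GEMINI_API_KEY is the variable checked
# above; the value here is a placeholder, not a real key):
#   GEMINI_API_KEY=your-api-key-here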

# These imports come after the .env load above.
from src.workflows.what_deconstruction_workflow import WhatDeconstructionWorkflow
from src.utils.logger import get_logger

logger = get_logger(__name__)

# Lock for thread-safe console output
print_lock = threading.Lock()

# Thread-safe success/failure counter
class ThreadSafeCounter:
    def __init__(self):
        self._lock = threading.Lock()
        self._success = 0
        self._fail = 0

    def increment_success(self):
        with self._lock:
            self._success += 1

    def increment_fail(self):
        with self._lock:
            self._fail += 1

    @property
    def success(self):
        with self._lock:
            return self._success

    @property
    def fail(self):
        with self._lock:
            return self._fail

def safe_print(*args, **kwargs):
    """Thread-safe print."""
    with print_lock:
        print(*args, **kwargs)

def load_historical_posts(history_dir, target_timestamp=None, target_post_id=None, max_count=15):
    """
    Load historical posts, sorted by publish_timestamp from newest to oldest.
    Selects the most recently published posts that were published before the
    target post, excluding the target post itself.

    Args:
        history_dir: directory containing the historical posts
        target_timestamp: publish timestamp of the target post (optional)
        target_post_id: ID of the target post, used to filter out the
            target itself (optional)
        max_count: maximum number of posts to load

    Returns:
        list: historical posts, sorted from newest to oldest
    """
    history_path = Path(history_dir)
    if not history_path.exists():
        safe_print(f"⚠️ Historical posts directory does not exist: {history_path}")
        return []

    # Collect all JSON files
    json_files = list(history_path.glob("*.json"))
    if not json_files:
        safe_print("⚠️ No historical post files found")
        return []

    # Read every post and extract its publish_timestamp
    posts_with_timestamp = []
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                post_data = json.load(f)
            # Fall back to 0 if publish_timestamp is missing
            timestamp = post_data.get("publish_timestamp", 0)
            post_id = post_data.get("channel_content_id", "")
            posts_with_timestamp.append({
                "file_path": file_path,
                "post_data": post_data,
                "timestamp": timestamp,
                "post_id": post_id
            })
        except Exception as e:
            safe_print(f"  ⚠️ Failed to read file {file_path.name}: {e}")
            continue

    if not posts_with_timestamp:
        safe_print("⚠️ No posts could be read")
        return []

    # Filter out the target post itself
    if target_post_id is not None:
        posts_with_timestamp = [
            post for post in posts_with_timestamp
            if post["post_id"] != target_post_id
        ]

    # If a target timestamp is given, keep only posts published before it
    if target_timestamp is not None:
        posts_with_timestamp = [
            post for post in posts_with_timestamp
            if post["timestamp"] < target_timestamp
        ]
    if not posts_with_timestamp:
        return []

    # Sort by publish_timestamp, newest first
    posts_with_timestamp.sort(key=lambda x: x["timestamp"], reverse=True)

    # Take the most recent N posts (slicing already handles short lists)
    selected_posts = posts_with_timestamp[:max_count]

    historical_posts = []
    for post_info in selected_posts:
        post_data = post_info["post_data"]
        # Convert to the format the workflow expects
        historical_post = {
            "text": {
                "title": post_data.get("title", ""),
                "body": post_data.get("body_text", ""),
                "hashtags": ""
            },
            "images": post_data.get("images", [])
        }
        historical_posts.append(historical_post)

    return historical_posts

def load_post_files(directory, max_count=20):
    """
    Load all JSON files in the author's post-history directory, sort them by
    publish_timestamp from newest to oldest, and take the latest max_count.

    Args:
        directory: directory containing the posts
        max_count: maximum number of posts to load (default 20)

    Returns:
        list: post file paths, sorted from newest to oldest
    """
    post_dir = Path(directory)
    if not post_dir.exists():
        raise FileNotFoundError(f"Directory does not exist: {post_dir}")

    # Collect all JSON files
    json_files = list(post_dir.glob("*.json"))
    if not json_files:
        raise FileNotFoundError(f"No JSON files found in directory: {post_dir}")

    # Read every post and extract its publish_timestamp
    posts_with_timestamp = []
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                post_data = json.load(f)
            # Fall back to 0 if publish_timestamp is missing
            timestamp = post_data.get("publish_timestamp", 0)
            posts_with_timestamp.append({
                "file_path": file_path,
                "timestamp": timestamp,
                "post_data": post_data
            })
        except Exception as e:
            print(f"⚠️ Failed to read file {file_path.name}: {e}")
            continue

    if not posts_with_timestamp:
        raise FileNotFoundError("No posts could be read")

    # Sort by publish_timestamp, newest first
    posts_with_timestamp.sort(key=lambda x: x["timestamp"], reverse=True)

    # Take the latest max_count posts
    selected_posts = posts_with_timestamp[:max_count]

    print(f"📊 Sorted by time; selected the latest {len(selected_posts)} posts:")
    for idx, post_info in enumerate(selected_posts, 1):
        post_data = post_info["post_data"]
        publish_time = post_data.get("publish_time", "unknown")
        title = post_data.get("title", "untitled")
        print(f"  {idx}. {post_info['file_path'].name}")
        print(f"     Title: {title}")
        print(f"     Published: {publish_time}")

    return [post_info["file_path"] for post_info in selected_posts]
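
# Expected directory layout (see main() below; paths are resolved relative
# to this script):
#   <directory>/作者历史帖子/*.json   <- input post files
#   <directory>/output/               <- analysis results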

def convert_to_workflow_input(raw_data, historical_posts=None):
    """
    Convert raw post data into the workflow's input format.

    Args:
        raw_data: raw post data
        historical_posts: list of historical posts (optional)
    """
    # Build the structure the workflow expects
    images = raw_data.get("images", [])
    input_data = {
        "multimedia_content": {
            "images": images,
            "video": raw_data.get("video", {}),
            "text": {
                "title": raw_data.get("title", ""),
                "body": raw_data.get("body_text", ""),
                "hashtags": ""
            }
        },
        "comments": raw_data.get("comments", []),  # include comment data
        "creator_info": {
            "nickname": raw_data.get("channel_account_name", ""),
            "account_id": raw_data.get("channel_account_id", "")
        }
    }

    # Attach historical posts to the input if any were loaded
    if historical_posts:
        input_data["historical_posts"] = historical_posts
    return input_data
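
# For reference: the fields this script reads from each raw post JSON
# (collected from the accesses above and in process_single_post; missing
# fields fall back to the defaults shown in the code):
#   title, body_text, images, video, comments, channel_account_name,
#   channel_account_id, channel_content_id, publish_timestamp,
#   publish_time, like_count, comment_count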

def process_single_post(post_file, posts_dir, output_dir, counter, total_count, current_index):
    """Process a single post file (thread-safe version, with historical posts)."""
    post_name = post_file.stem  # file name without extension
    thread_id = threading.current_thread().name

    safe_print(f"\n{'='*80}")
    safe_print(f"[Thread:{thread_id}] Processing post: {post_name}")
    safe_print(f"Progress: [{current_index}/{total_count}]")
    safe_print(f"{'='*80}")

    # 1. Load the post data
    safe_print(f"\n[Thread:{thread_id}][1] Loading post data...")
    try:
        with open(post_file, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        target_timestamp = raw_data.get('publish_timestamp')
        target_post_id = raw_data.get('channel_content_id')
        safe_print(f"✅ [{thread_id}] Post data loaded")
        safe_print(f"   - Title: {raw_data.get('title', 'N/A')}")
        safe_print(f"   - Post ID: {target_post_id}")
        safe_print(f"   - Published: {raw_data.get('publish_time', 'unknown')}")
        safe_print(f"   - Images: {len(raw_data.get('images', []))}")
        safe_print(f"   - Likes: {raw_data.get('like_count', 0)}")
        safe_print(f"   - Comments: {raw_data.get('comment_count', 0)}")
    except Exception as e:
        safe_print(f"❌ [{thread_id}] Failed to load post data: {e}")
        counter.increment_fail()
        return False

    # 2. Load historical posts
    safe_print(f"\n[Thread:{thread_id}][2] Loading historical posts...")
    historical_posts = load_historical_posts(
        posts_dir,
        target_timestamp=target_timestamp,
        target_post_id=target_post_id,
        max_count=15
    )
    if historical_posts:
        safe_print(f"✅ [{thread_id}] Loaded {len(historical_posts)} historical posts")
    else:
        safe_print(f"⚠️ [{thread_id}] No historical posts loaded; falling back to regular analysis mode")

    # 3. Convert the data format
    safe_print(f"\n[Thread:{thread_id}][3] Converting data format...")
    try:
        input_data = convert_to_workflow_input(raw_data, historical_posts)
        safe_print(f"✅ [{thread_id}] Data format converted")
        safe_print(f"   - Historical posts: {len(input_data.get('historical_posts', []))}")
    except Exception as e:
        safe_print(f"❌ [{thread_id}] Data format conversion failed: {e}")
        counter.increment_fail()
        return False

    # 4. Initialize the workflow (each thread creates its own instance for thread safety)
    safe_print(f"\n[Thread:{thread_id}][4] Initializing workflow...")
    try:
        workflow = WhatDeconstructionWorkflow(
            model_provider="google_genai",
            max_depth=10
        )
        safe_print(f"✅ [{thread_id}] Workflow initialized")
    except Exception as e:
        safe_print(f"❌ [{thread_id}] Workflow initialization failed: {e}")
        counter.increment_fail()
        return False

    # 5. Run the workflow
    safe_print(f"\n[Thread:{thread_id}][5] Running workflow...")
    safe_print("   Note: this may take a few minutes...")
    try:
        result = workflow.invoke(input_data)
        safe_print(f"✅ [{thread_id}] Workflow finished")
    except Exception as e:
        safe_print(f"❌ [{thread_id}] Workflow failed: {e}")
        import traceback
        safe_print(traceback.format_exc())
        counter.increment_fail()
        return False

    # 6. Save the result
    safe_print(f"\n[Thread:{thread_id}][6] Saving result...")
    try:
        # Use the post file name as the prefix
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"{post_name}_with_history_{timestamp}.json"
        output_path = output_dir / output_filename
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        safe_print(f"✅ [{thread_id}] Result saved to: {output_path}")
    except Exception as e:
        safe_print(f"❌ [{thread_id}] Failed to save result: {e}")
        counter.increment_fail()
        return False

    # 7. Print a result summary
    safe_print(f"\n{'='*80}")
    safe_print(f"Result summary - {post_name} [Thread:{thread_id}]")
    safe_print(f"{'='*80}")
    if result:
        # The workflow returns Chinese keys; they must be kept as-is.
        # "三点解构" = three-point deconstruction, "灵感点" = inspiration points,
        # "关键点" = key points, "评论分析" = comment analysis,
        # "解构维度" = deconstruction dimensions, "选题理解" = topic understanding.
        three_points = result.get("三点解构", {})
        inspiration_data = three_points.get("灵感点", {})
        keypoints_data = three_points.get("关键点", {})
        comments = result.get("评论分析", {}).get("解构维度", [])
        safe_print("\nThree-point deconstruction:")
        safe_print(f"   - Inspiration points: {inspiration_data.get('total_count', 0)}")
        safe_print(f"   - Inspiration analysis mode: {inspiration_data.get('analysis_mode', 'unknown')}")
        safe_print(f"   - Purpose points: 1")
        safe_print(f"   - Key points: {keypoints_data.get('total_count', 0)}")
        safe_print("\nComment analysis:")
        safe_print(f"   - Deconstruction dimensions: {len(comments)}")
        topic_understanding = result.get("选题理解", {})
        if topic_understanding:
            topic_theme = topic_understanding.get("topic_theme", "")
            safe_print("\nTopic understanding:")
            safe_print(f"   - Topic theme: {topic_theme}")

    counter.increment_success()
    return True

def main():
    """Main entry point (concurrent version, with historical posts)."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Batch "What" deconstruction workflow for posts')
    parser.add_argument('directory', type=str, help='Post directory name (e.g. "阿里多多酱" or "G88818")')
    args = parser.parse_args()
    directory = args.directory

    print("=" * 80)
    print(f"Starting batch processing of author post history (concurrent mode, with historical posts) - directory: {directory}")
    print("=" * 80)

    # Configuration
    posts_dir = Path(__file__).parent / directory / "作者历史帖子"
    output_dir = Path(__file__).parent / directory / "output"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Concurrency: maximum worker threads (tune to your CPU count and API rate limits)
    MAX_WORKERS = 4  # adjust as needed; 5 or fewer is recommended
    # Limit on the number of posts to process
    MAX_POSTS = 20  # process only the latest 20 posts

    # 1. Load all post files (sorted newest first, latest 20)
    print(f"\n[1] Scanning post files...")
    try:
        post_files = load_post_files(posts_dir, max_count=MAX_POSTS)
        print(f"✅ Selected {len(post_files)} latest posts for processing")
    except Exception as e:
        print(f"❌ Failed to scan post files: {e}")
        return

    # 2. Initialize the thread-safe counter
    print(f"\n[2] Initializing concurrent processing...")
    print(f"   - Max worker threads: {MAX_WORKERS}")
    counter = ThreadSafeCounter()

    # 3. Process all posts concurrently with a thread pool
    print(f"\n[3] Starting concurrent processing...")
    print("   Note: multiple threads process different posts at the same time; each post loads its own historical posts")
    print("=" * 80)

    start_time = datetime.now()

    # Run the posts through a ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=MAX_WORKERS, thread_name_prefix="Worker") as executor:
        # Submit all tasks
        future_to_post = {
            executor.submit(
                process_single_post,
                post_file,
                posts_dir,  # pass the post directory, used to load historical posts
                output_dir,
                counter,
                len(post_files),
                i
            ): (post_file, i)
            for i, post_file in enumerate(post_files, 1)
        }

        # Wait for all tasks to finish
        for future in as_completed(future_to_post):
            post_file, index = future_to_post[future]
            try:
                result = future.result()
                if not result:
                    safe_print(f"⚠️ Failed to process post: {post_file.name}")
            except Exception as e:
                safe_print(f"❌ Exception while processing post: {post_file.name}")
                safe_print(f"   Error: {e}")
                import traceback
                safe_print(traceback.format_exc())
                counter.increment_fail()

    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds()

    # 4. Summary
    print("\n" + "=" * 80)
    print("Batch processing finished")
    print("=" * 80)
    print(f"\nTotal: {len(post_files)} posts")
    print(f"Succeeded: {counter.success}")
    print(f"Failed: {counter.fail}")
    print(f"Elapsed: {duration:.2f} s")
    print(f"Average per post: {duration/len(post_files):.2f} s")
    print(f"\nOutput directory: {output_dir}")
    print("=" * 80)

if __name__ == "__main__":
    main()