run_step3_from_folder.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. """
  2. 从灵感点文件夹直接读取 step1 结果,排序后批量执行 Step3
  3. 直接扫描 how/灵感点/ 目录下的所有灵感,读取其 step1 结果文件,
  4. 根据 score 排序并筛选,然后执行 step3 生成新灵感点
  5. """
  6. import os
  7. import sys
  8. import json
  9. import asyncio
  10. import argparse
  11. from pathlib import Path
  12. from typing import List, Tuple
  13. from agents import trace
  14. from lib.my_trace import set_trace_smith as set_trace
  15. from lib.data_loader import load_persona_data
  16. import step3_generate_inspirations
  17. def scan_inspirations_from_folder(
  18. inspirations_dir: str,
  19. model_name: str = "google/gemini-2.5-pro"
  20. ) -> List[Tuple[str, dict]]:
  21. """扫描灵感点文件夹,读取所有 step1 结果
  22. Args:
  23. inspirations_dir: 灵感点目录路径 (如 data/.../how/灵感点)
  24. model_name: 模型名称
  25. Returns:
  26. [(灵感名称, step1数据), ...] 列表
  27. """
  28. results = []
  29. model_name_short = model_name.replace("google/", "").replace("/", "_")
  30. # 遍历所有灵感子目录
  31. for inspiration_dir in Path(inspirations_dir).iterdir():
  32. if not inspiration_dir.is_dir():
  33. continue
  34. inspiration_name = inspiration_dir.name
  35. # 查找 step1 文件
  36. step1_pattern = f"*_step1_*_{model_name_short}.json"
  37. step1_files = list(inspiration_dir.glob(step1_pattern))
  38. if not step1_files:
  39. print(f"⚠️ 跳过 {inspiration_name}: 未找到 step1 文件")
  40. continue
  41. # 读取 step1 文件
  42. step1_file = step1_files[0]
  43. try:
  44. with open(step1_file, 'r', encoding='utf-8') as f:
  45. step1_data = json.load(f)
  46. results.append((inspiration_name, step1_data))
  47. except Exception as e:
  48. print(f"⚠️ 读取失败 {inspiration_name}: {e}")
  49. continue
  50. return results
  51. def filter_and_sort_inspirations(
  52. inspirations_data: List[Tuple[str, dict]],
  53. min_score: float = 0.5,
  54. max_score: float = 0.8
  55. ) -> List[Tuple[str, float, dict]]:
  56. """筛选并排序灵感
  57. Args:
  58. inspirations_data: [(灵感名称, step1数据), ...]
  59. min_score: 最小 score
  60. max_score: 最大 score
  61. Returns:
  62. [(灵感名称, score, step1数据), ...] 按 score 降序排列
  63. """
  64. filtered = []
  65. for inspiration_name, step1_data in inspirations_data:
  66. # 获取匹配结果列表
  67. match_results = step1_data.get("匹配结果列表", [])
  68. if not match_results:
  69. continue
  70. # 取 top1 的 score
  71. top1_result = match_results[0]
  72. match_result = top1_result.get("匹配结果", {})
  73. score = match_result.get("score", 0)
  74. # 筛选
  75. if min_score <= score <= max_score:
  76. filtered.append((inspiration_name, score, step1_data))
  77. # 按 score 降序排序
  78. filtered.sort(key=lambda x: x[1], reverse=True)
  79. return filtered
  80. async def run_step3_for_inspiration(
  81. persona_dir: str,
  82. inspiration: str,
  83. step1_data: dict,
  84. persona_data: dict,
  85. force: bool = False
  86. ) -> dict:
  87. """为单个灵感执行 step3
  88. Args:
  89. persona_dir: 人设目录
  90. inspiration: 灵感名称
  91. step1_data: step1 完整数据
  92. persona_data: 人设数据
  93. force: 是否强制重新执行
  94. Returns:
  95. 执行结果字典
  96. """
  97. print(f"\n{'=' * 80}")
  98. print(f"处理灵感: {inspiration}")
  99. print(f"{'=' * 80}\n")
  100. # 获取 step1 结果
  101. step1_results = step1_data.get("匹配结果列表", [])
  102. if not step1_results:
  103. print("❌ step1 结果为空")
  104. return {
  105. "灵感": inspiration,
  106. "status": "step1_empty",
  107. "output_file": None
  108. }
  109. # 获取 top1
  110. step1_top1 = step1_results[0]
  111. # 构建输出文件路径
  112. model_name = "google/gemini-2.5-pro"
  113. output_dir = os.path.join(persona_dir, "how", "灵感点", inspiration)
  114. model_name_short = model_name.replace("google/", "").replace("/", "_")
  115. # 根据 step1 文件名确定前缀
  116. step1_file_name = f"all_step1_灵感人设匹配_{model_name_short}.json"
  117. scope_prefix = "all"
  118. output_filename = f"{scope_prefix}_step3_top1_生成灵感_{model_name_short}.json"
  119. output_file = os.path.join(output_dir, output_filename)
  120. # 检查文件是否已存在
  121. if not force and os.path.exists(output_file):
  122. print(f"✓ 输出文件已存在,跳过: {output_file}")
  123. return {
  124. "灵感": inspiration,
  125. "status": "skipped",
  126. "output_file": output_file
  127. }
  128. # 创建独立的 trace
  129. current_time, log_url = set_trace()
  130. try:
  131. with trace(f"Step3: {inspiration}"):
  132. # 执行 step3
  133. output = await step3_generate_inspirations.process_step3_generate_inspirations(
  134. step1_top1=step1_top1,
  135. persona_data=persona_data,
  136. current_time=current_time,
  137. log_url=log_url
  138. )
  139. # 添加元数据
  140. output["元数据"]["step1_匹配索引"] = 1
  141. # 保存结果
  142. os.makedirs(output_dir, exist_ok=True)
  143. with open(output_file, 'w', encoding='utf-8') as f:
  144. json.dump(output, f, ensure_ascii=False, indent=2)
  145. # 输出预览
  146. inspirations = output.get("灵感点列表", [])
  147. print(f"✓ 生成了 {len(inspirations)} 个灵感点")
  148. if log_url:
  149. print(f" Trace: {log_url}")
  150. return {
  151. "灵感": inspiration,
  152. "status": "success",
  153. "output_file": output_file,
  154. "生成数量": len(inspirations)
  155. }
  156. except Exception as e:
  157. print(f"❌ 执行失败: {e}")
  158. return {
  159. "灵感": inspiration,
  160. "status": "error",
  161. "output_file": None,
  162. "error": str(e)
  163. }
  164. async def main():
  165. """主函数"""
  166. parser = argparse.ArgumentParser(
  167. description="从灵感点文件夹直接读取 step1,排序后批量执行 Step3",
  168. formatter_class=argparse.RawDescriptionHelpFormatter,
  169. epilog="""
  170. 使用示例:
  171. # 测试:只处理第1个符合条件的灵感
  172. python run_step3_from_folder.py --count 1
  173. # 使用默认参数(step1 score 在 [0.5, 0.8] 区间)
  174. python run_step3_from_folder.py
  175. # 指定 score 范围
  176. python run_step3_from_folder.py --min-score 0.6 --max-score 0.9
  177. # 强制重新执行,处理前3个
  178. python run_step3_from_folder.py --force --count 3
  179. # 指定人设目录
  180. python run_step3_from_folder.py --dir data/阿里多多酱/out/人设_1110
  181. """
  182. )
  183. parser.add_argument(
  184. "--dir",
  185. default="data/阿里多多酱/out/人设_1110",
  186. help="人设目录路径 (默认: data/阿里多多酱/out/人设_1110)"
  187. )
  188. parser.add_argument(
  189. "--min-score",
  190. type=float,
  191. default=0.5,
  192. help="step1 score 最小值(含)(默认: 0.5)"
  193. )
  194. parser.add_argument(
  195. "--max-score",
  196. type=float,
  197. default=0.8,
  198. help="step1 score 最大值(含)(默认: 0.8)"
  199. )
  200. parser.add_argument(
  201. "--force",
  202. action="store_true",
  203. help="强制重新执行,覆盖已存在的文件"
  204. )
  205. parser.add_argument(
  206. "--count",
  207. type=int,
  208. default=None,
  209. help="处理的灵感数量限制(默认: 无限制)"
  210. )
  211. args = parser.parse_args()
  212. persona_dir = args.dir
  213. min_score = args.min_score
  214. max_score = args.max_score
  215. force = args.force
  216. count_limit = args.count
  217. # 确定灵感点目录
  218. inspirations_dir = os.path.join(persona_dir, "how", "灵感点")
  219. print(f"{'=' * 80}")
  220. print(f"从灵感点文件夹批量执行 Step3")
  221. print(f"{'=' * 80}")
  222. print(f"人设目录: {persona_dir}")
  223. print(f"灵感点目录: {inspirations_dir}")
  224. print(f"Score 范围: [{min_score}, {max_score}]")
  225. if count_limit:
  226. print(f"数量限制: 处理前 {count_limit} 个")
  227. if force:
  228. print(f"强制模式: 重新执行所有步骤")
  229. print()
  230. # 检查目录是否存在
  231. if not os.path.exists(inspirations_dir):
  232. print(f"❌ 灵感点目录不存在: {inspirations_dir}")
  233. sys.exit(1)
  234. # 扫描灵感点文件夹
  235. print("正在扫描灵感点文件夹...")
  236. inspirations_data = scan_inspirations_from_folder(inspirations_dir)
  237. print(f"找到 {len(inspirations_data)} 个灵感(有 step1 结果)\n")
  238. if not inspirations_data:
  239. print("❌ 没有找到任何有效的 step1 结果")
  240. sys.exit(0)
  241. # 筛选并排序
  242. print("正在筛选和排序...")
  243. filtered_sorted = filter_and_sort_inspirations(
  244. inspirations_data, min_score, max_score
  245. )
  246. if not filtered_sorted:
  247. print(f"❌ 没有找到符合条件的灵感(step1 score 在 [{min_score}, {max_score}] 范围内)")
  248. sys.exit(0)
  249. # 应用数量限制
  250. if count_limit and count_limit < len(filtered_sorted):
  251. filtered_sorted = filtered_sorted[:count_limit]
  252. print(f"找到 {len(filtered_sorted)} 个符合条件的灵感(已应用数量限制):\n")
  253. else:
  254. print(f"找到 {len(filtered_sorted)} 个符合条件的灵感:\n")
  255. # 打印列表
  256. for i, (insp_name, score, _) in enumerate(filtered_sorted, 1):
  257. print(f" {i}. {insp_name} (score: {score:.2f})")
  258. print()
  259. # 加载人设数据(只需要加载一次)
  260. persona_data = load_persona_data(persona_dir)
  261. # 批量执行 step3
  262. results = []
  263. for i, (inspiration_name, score, step1_data) in enumerate(filtered_sorted, 1):
  264. print(f"\n{'#' * 80}")
  265. print(f"处理第 {i}/{len(filtered_sorted)} 个 (score: {score:.2f})")
  266. print(f"{'#' * 80}")
  267. result = await run_step3_for_inspiration(
  268. persona_dir=persona_dir,
  269. inspiration=inspiration_name,
  270. step1_data=step1_data,
  271. persona_data=persona_data,
  272. force=force
  273. )
  274. results.append(result)
  275. # 输出最终汇总
  276. print(f"\n{'=' * 80}")
  277. print(f"批量处理完成")
  278. print(f"{'=' * 80}\n")
  279. success_count = sum(1 for r in results if r["status"] == "success")
  280. skipped_count = sum(1 for r in results if r["status"] == "skipped")
  281. error_count = sum(1 for r in results if r["status"] == "error")
  282. print(f"统计:")
  283. print(f" 总数: {len(results)}")
  284. print(f" 成功: {success_count}")
  285. print(f" 跳过: {skipped_count}")
  286. print(f" 失败: {error_count}")
  287. print(f"\n详细结果:")
  288. for i, result in enumerate(results, 1):
  289. status_icon = {
  290. "success": "✓",
  291. "skipped": "○",
  292. "error": "✗",
  293. "step1_empty": "⚠"
  294. }.get(result["status"], "?")
  295. status_text = {
  296. "success": f"成功,生成 {result.get('生成数量', 0)} 个",
  297. "skipped": "已存在",
  298. "error": f"失败: {result.get('error', '')}",
  299. "step1_empty": "step1 结果为空"
  300. }.get(result["status"], result["status"])
  301. print(f" {status_icon} [{i}] {result['灵感']} - {status_text}")
# Run the async CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())