# run_step3_from_analysis.py
"""
Batch-run Step3 to generate new inspirations from the inspiration-matching analysis.

Selects the inspirations in 灵感匹配分析.json whose step1 score falls within the
given range, then runs step3 on the top-1 match of each selected inspiration
to generate new inspiration points.
"""
  6. import os
  7. import sys
  8. import json
  9. import asyncio
  10. import argparse
  11. from agents import trace
  12. from lib.my_trace import set_trace_smith as set_trace
  13. from lib.data_loader import load_persona_data
  14. import step3_generate_inspirations
  15. def filter_inspirations_by_score(
  16. analysis_file: str,
  17. min_score: float = 0.5,
  18. max_score: float = 0.8
  19. ) -> list:
  20. """从分析文件中筛选符合条件的灵感
  21. Args:
  22. analysis_file: 灵感匹配分析.json 文件路径
  23. min_score: step1 score 最小值(含)
  24. max_score: step1 score 最大值(含)
  25. Returns:
  26. 筛选后的灵感列表
  27. """
  28. with open(analysis_file, 'r', encoding='utf-8') as f:
  29. analysis_data = json.load(f)
  30. results = analysis_data.get("排序结果", [])
  31. filtered = []
  32. for item in results:
  33. step1_score = item["step1"]["匹配结果"].get("score", 0)
  34. if min_score <= step1_score <= max_score:
  35. filtered.append(item["灵感"])
  36. return filtered
  37. async def run_step3_for_inspiration(
  38. persona_dir: str,
  39. inspiration: str,
  40. persona_data: dict,
  41. force: bool = False
  42. ) -> dict:
  43. """为单个灵感执行 step3
  44. Args:
  45. persona_dir: 人设目录
  46. inspiration: 灵感名称
  47. persona_data: 人设数据
  48. force: 是否强制重新执行
  49. Returns:
  50. 执行结果字典
  51. """
  52. print(f"\n{'=' * 80}")
  53. print(f"处理灵感: {inspiration}")
  54. print(f"{'=' * 80}\n")
  55. # 查找 step1 结果文件
  56. model_name = "google/gemini-2.5-pro"
  57. step1_file = step3_generate_inspirations.find_step1_file(
  58. persona_dir, inspiration, model_name
  59. )
  60. # 读取 step1 结果
  61. with open(step1_file, 'r', encoding='utf-8') as f:
  62. step1_data = json.load(f)
  63. step1_results = step1_data.get("匹配结果列表", [])
  64. if not step1_results:
  65. print("❌ step1 结果为空")
  66. return {
  67. "灵感": inspiration,
  68. "status": "step1_empty",
  69. "output_file": None
  70. }
  71. # 获取 top1
  72. step1_top1 = step1_results[0]
  73. # 构建输出文件路径
  74. output_dir = os.path.join(persona_dir, "how", "灵感点", inspiration)
  75. model_name_short = model_name.replace("google/", "").replace("/", "_")
  76. step1_filename = os.path.basename(step1_file)
  77. step1_basename = os.path.splitext(step1_filename)[0]
  78. scope_prefix = step1_basename.split("_")[0]
  79. output_filename = f"{scope_prefix}_step3_top1_生成灵感_{model_name_short}.json"
  80. output_file = os.path.join(output_dir, output_filename)
  81. # 检查文件是否已存在
  82. if not force and os.path.exists(output_file):
  83. print(f"✓ 输出文件已存在,跳过: {output_file}")
  84. return {
  85. "灵感": inspiration,
  86. "status": "skipped",
  87. "output_file": output_file
  88. }
  89. # 创建独立的 trace
  90. current_time, log_url = set_trace()
  91. try:
  92. with trace(f"Step3: {inspiration}"):
  93. # 执行 step3
  94. output = await step3_generate_inspirations.process_step3_generate_inspirations(
  95. step1_top1=step1_top1,
  96. persona_data=persona_data,
  97. current_time=current_time,
  98. log_url=log_url
  99. )
  100. # 添加元数据
  101. output["元数据"]["step1_匹配索引"] = 1
  102. # 保存结果
  103. os.makedirs(output_dir, exist_ok=True)
  104. with open(output_file, 'w', encoding='utf-8') as f:
  105. json.dump(output, f, ensure_ascii=False, indent=2)
  106. # 输出预览
  107. inspirations = output.get("灵感点列表", [])
  108. print(f"✓ 生成了 {len(inspirations)} 个灵感点")
  109. if log_url:
  110. print(f" Trace: {log_url}")
  111. return {
  112. "灵感": inspiration,
  113. "status": "success",
  114. "output_file": output_file,
  115. "生成数量": len(inspirations)
  116. }
  117. except Exception as e:
  118. print(f"❌ 执行失败: {e}")
  119. return {
  120. "灵感": inspiration,
  121. "status": "error",
  122. "output_file": None,
  123. "error": str(e)
  124. }
  125. async def main():
  126. """主函数"""
  127. parser = argparse.ArgumentParser(
  128. description="基于灵感匹配分析结果,批量执行 Step3",
  129. formatter_class=argparse.RawDescriptionHelpFormatter,
  130. epilog="""
  131. 使用示例:
  132. # 测试:只处理第1个符合条件的灵感
  133. python run_step3_from_analysis.py --count 1
  134. # 使用默认参数(step1 score 在 [0.5, 0.8] 区间)
  135. python run_step3_from_analysis.py
  136. # 指定 score 范围
  137. python run_step3_from_analysis.py --min-score 0.6 --max-score 0.9
  138. # 强制重新执行,处理前3个
  139. python run_step3_from_analysis.py --force --count 3
  140. # 指定人设目录
  141. python run_step3_from_analysis.py --dir data/阿里多多酱/out/人设_1110
  142. """
  143. )
  144. parser.add_argument(
  145. "--dir",
  146. default="data/阿里多多酱/out/人设_1110",
  147. help="人设目录路径 (默认: data/阿里多多酱/out/人设_1110)"
  148. )
  149. parser.add_argument(
  150. "--analysis-file",
  151. default=None,
  152. help="灵感匹配分析文件路径 (默认: {dir}/how/灵感匹配分析.json)"
  153. )
  154. parser.add_argument(
  155. "--min-score",
  156. type=float,
  157. default=0.5,
  158. help="step1 score 最小值(含)(默认: 0.5)"
  159. )
  160. parser.add_argument(
  161. "--max-score",
  162. type=float,
  163. default=0.8,
  164. help="step1 score 最大值(含)(默认: 0.8)"
  165. )
  166. parser.add_argument(
  167. "--force",
  168. action="store_true",
  169. help="强制重新执行,覆盖已存在的文件"
  170. )
  171. parser.add_argument(
  172. "--count",
  173. type=int,
  174. default=1,
  175. help="处理的灵感数量限制(默认: 1)"
  176. )
  177. args = parser.parse_args()
  178. persona_dir = args.dir
  179. min_score = args.min_score
  180. max_score = args.max_score
  181. force = args.force
  182. count_limit = args.count
  183. # 确定分析文件路径
  184. if args.analysis_file:
  185. analysis_file = args.analysis_file
  186. else:
  187. analysis_file = os.path.join(persona_dir, "how", "灵感匹配分析.json")
  188. print(f"{'=' * 80}")
  189. print(f"基于灵感匹配分析,批量执行 Step3")
  190. print(f"{'=' * 80}")
  191. print(f"人设目录: {persona_dir}")
  192. print(f"分析文件: {analysis_file}")
  193. print(f"Score 范围: [{min_score}, {max_score}]")
  194. if count_limit:
  195. print(f"数量限制: 处理前 {count_limit} 个")
  196. if force:
  197. print(f"强制模式: 重新执行所有步骤")
  198. print()
  199. # 检查分析文件是否存在
  200. if not os.path.exists(analysis_file):
  201. print(f"❌ 分析文件不存在: {analysis_file}")
  202. print(f"请先运行 analyze_inspiration_results.py 生成分析文件")
  203. sys.exit(1)
  204. # 筛选灵感
  205. filtered_inspirations = filter_inspirations_by_score(
  206. analysis_file, min_score, max_score
  207. )
  208. if not filtered_inspirations:
  209. print(f"❌ 没有找到符合条件的灵感(step1 score 在 [{min_score}, {max_score}] 范围内)")
  210. sys.exit(0)
  211. # 应用数量限制
  212. if count_limit and count_limit < len(filtered_inspirations):
  213. filtered_inspirations = filtered_inspirations[:count_limit]
  214. print(f"找到 {len(filtered_inspirations)} 个符合条件的灵感(已应用数量限制):\n")
  215. else:
  216. print(f"找到 {len(filtered_inspirations)} 个符合条件的灵感:\n")
  217. for i, insp in enumerate(filtered_inspirations, 1):
  218. print(f" {i}. {insp}")
  219. print()
  220. # 加载人设数据(只需要加载一次)
  221. persona_data = load_persona_data(persona_dir)
  222. # 批量执行 step3
  223. results = []
  224. for i, inspiration in enumerate(filtered_inspirations, 1):
  225. print(f"\n{'#' * 80}")
  226. print(f"处理第 {i}/{len(filtered_inspirations)} 个")
  227. print(f"{'#' * 80}")
  228. result = await run_step3_for_inspiration(
  229. persona_dir=persona_dir,
  230. inspiration=inspiration,
  231. persona_data=persona_data,
  232. force=force
  233. )
  234. results.append(result)
  235. # 输出最终汇总
  236. print(f"\n{'=' * 80}")
  237. print(f"批量处理完成")
  238. print(f"{'=' * 80}\n")
  239. success_count = sum(1 for r in results if r["status"] == "success")
  240. skipped_count = sum(1 for r in results if r["status"] == "skipped")
  241. error_count = sum(1 for r in results if r["status"] == "error")
  242. print(f"统计:")
  243. print(f" 总数: {len(results)}")
  244. print(f" 成功: {success_count}")
  245. print(f" 跳过: {skipped_count}")
  246. print(f" 失败: {error_count}")
  247. print(f"\n详细结果:")
  248. for i, result in enumerate(results, 1):
  249. status_icon = {
  250. "success": "✓",
  251. "skipped": "○",
  252. "error": "✗",
  253. "step1_empty": "⚠"
  254. }.get(result["status"], "?")
  255. status_text = {
  256. "success": f"成功,生成 {result.get('生成数量', 0)} 个",
  257. "skipped": "已存在",
  258. "error": f"失败: {result.get('error', '')}",
  259. "step1_empty": "step1 结果为空"
  260. }.get(result["status"], result["status"])
  261. print(f" {status_icon} [{i}] {result['灵感']} - {status_text}")
# Script entry point: when executed directly, run the main() coroutine via asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())