  1. #!/usr/bin/env python3
  2. """
  3. Pipeline Runner:批量执行完整分析 + SFT 数据生成流程
  4. 功能:
  5. - 自动按 window_size 切分小说,串行调用 step1_analyze.py
  6. - 每个窗口的分析完成后自动传给下一窗口(保持人物/线索连贯)
  7. - 并行调用 step2_build_sft.py 生成三类 SFT 数据
  8. - 所有窗口完成后合并 JSONL 到 merged/ 目录
  9. - **支持断点续跑**:已存在的输出文件自动跳过,直接从中断处继续
  10. 用法:
  11. cd examples/analyze_story/sft
  12. python run_pipeline.py --novel ../input/大奉打更人.txt
  13. # 指定输出目录(默认在 sft/ 下以文件名命名)
  14. python run_pipeline.py --novel ../input/大奉打更人.txt --output-dir runs/大奉/
  15. # 跳过某个任务,调整并发数
  16. python run_pipeline.py --novel ../input/大奉打更人.txt --skip-task 3 --concurrency 8
  17. # 只重新跑 step2(分析已完成的情况下)
  18. python run_pipeline.py --novel ../input/大奉打更人.txt --only-step 2
  19. # 强制重新跑(忽略已有文件)
  20. python run_pipeline.py --novel ../input/大奉打更人.txt --force
  21. 输出结构:
  22. {output_dir}/
  23. analysis/
  24. w0.json ← 第一个窗口分析
  25. w1.json ← 第二个窗口分析(如有)
  26. ...
  27. sft_raw/
  28. w0/
  29. task1_structure_planning.jsonl
  30. task2_scene_continuation.jsonl
  31. task3_shuang_injection.jsonl
  32. stats.json
  33. w1/
  34. ...
  35. merged/ ← 所有窗口合并后的最终数据
  36. task1_structure_planning.jsonl
  37. task2_scene_continuation.jsonl
  38. task3_shuang_injection.jsonl
  39. stats.json ← 汇总统计
  40. pipeline.log ← 运行日志(追加写入)
  41. """
import argparse
import datetime
import json
import math
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional
# Paths to the two pipeline stages, resolved relative to this script's location
# so the runner works regardless of the current working directory.
SCRIPT_DIR = Path(__file__).parent
STEP1 = SCRIPT_DIR / "step1_analyze.py"
STEP2 = SCRIPT_DIR / "step2_build_sft.py"

# Per-window SFT output files produced by step2 and concatenated by merge_jsonl.
SFT_TASKS = [
    "task1_structure_planning.jsonl",
    "task2_scene_continuation.jsonl",
    "task3_shuang_injection.jsonl",
]
  57. # ──────────────────────────────────────────────────────────────
  58. # 工具
  59. # ──────────────────────────────────────────────────────────────
  60. def load_text_size(path: str) -> int:
  61. """粗略估算文件字符数(不完整解码,用字节数 / 1.5 估算中文字符数)"""
  62. for enc in ["utf-8", "gbk", "gb2312", "gb18030"]:
  63. try:
  64. return len(Path(path).read_text(encoding=enc))
  65. except UnicodeDecodeError:
  66. continue
  67. raise ValueError(f"无法解码文件: {path}")
  68. def count_jsonl_lines(path: Path) -> int:
  69. if not path.exists():
  70. return 0
  71. return sum(1 for line in path.read_text(encoding="utf-8").splitlines() if line.strip())
  72. class Logger:
  73. def __init__(self, log_path: Path):
  74. self.log_path = log_path
  75. log_path.parent.mkdir(parents=True, exist_ok=True)
  76. def log(self, msg: str):
  77. ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  78. line = f"[{ts}] {msg}"
  79. print(line)
  80. with open(self.log_path, "a", encoding="utf-8") as f:
  81. f.write(line + "\n")
  82. def run_cmd(cmd: List[str], logger: Logger) -> bool:
  83. """执行子进程,实时打印输出,返回是否成功"""
  84. logger.log(f"运行: {' '.join(str(c) for c in cmd)}")
  85. proc = subprocess.Popen(
  86. [str(c) for c in cmd],
  87. stdout=subprocess.PIPE,
  88. stderr=subprocess.STDOUT,
  89. text=True,
  90. encoding="utf-8",
  91. )
  92. for line in proc.stdout:
  93. print(line, end="", flush=True)
  94. proc.wait()
  95. if proc.returncode != 0:
  96. logger.log(f"失败(返回码 {proc.returncode})")
  97. return False
  98. return True
  99. # ──────────────────────────────────────────────────────────────
  100. # Step 1:逐窗口分析
  101. # ──────────────────────────────────────────────────────────────
  102. def run_step1_all(
  103. novel: str,
  104. analysis_dir: Path,
  105. n_windows: int,
  106. window_size: int,
  107. model: str,
  108. force: bool,
  109. logger: Logger,
  110. only_step: Optional[int],
  111. ) -> List[Path]:
  112. """串行分析所有窗口,返回成功生成的 analysis 文件路径列表"""
  113. if only_step == 2:
  114. # 只跑 step2,直接读已有的分析文件
  115. files = sorted(analysis_dir.glob("w*.json"))
  116. logger.log(f"[Step1 跳过] 使用已有分析文件 {len(files)} 个")
  117. return files
  118. analysis_dir.mkdir(parents=True, exist_ok=True)
  119. analysis_files: List[Path] = []
  120. prev_analysis: Optional[Path] = None
  121. for i in range(n_windows):
  122. out = analysis_dir / f"w{i}.json"
  123. if out.exists() and not force:
  124. logger.log(f"[Step1 w{i}] 已存在,跳过 → {out}")
  125. analysis_files.append(out)
  126. prev_analysis = out
  127. continue
  128. logger.log(f"[Step1 w{i}/{n_windows-1}] 开始分析")
  129. cmd = [
  130. sys.executable, STEP1,
  131. "--novel", novel,
  132. "--window-index", str(i),
  133. "--window-size", str(window_size),
  134. "--output", str(out),
  135. "--model", model,
  136. ]
  137. if prev_analysis:
  138. cmd += ["--prev-analysis", str(prev_analysis)]
  139. ok = run_cmd(cmd, logger)
  140. if not ok:
  141. logger.log(f"[Step1 w{i}] 失败,跳过后续窗口")
  142. break
  143. analysis_files.append(out)
  144. prev_analysis = out
  145. return analysis_files
  146. # ──────────────────────────────────────────────────────────────
  147. # Step 2:为每个分析文件生成 SFT 数据
  148. # ──────────────────────────────────────────────────────────────
  149. def run_step2_all(
  150. novel: str,
  151. analysis_files: List[Path],
  152. sft_raw_dir: Path,
  153. context_chars: int,
  154. concurrency: int,
  155. skip_tasks: List[int],
  156. model: str,
  157. force: bool,
  158. logger: Logger,
  159. only_step: Optional[int],
  160. ) -> List[Path]:
  161. """为每个 analysis 文件生成 SFT 数据,返回成功的 sft 子目录列表"""
  162. if only_step == 1:
  163. logger.log("[Step2 跳过] --only-step 1")
  164. return []
  165. sft_dirs: List[Path] = []
  166. for analysis_path in analysis_files:
  167. window_name = analysis_path.stem # e.g. "w0"
  168. sft_dir = sft_raw_dir / window_name
  169. done_flag = sft_dir / "stats.json"
  170. if done_flag.exists() and not force:
  171. logger.log(f"[Step2 {window_name}] 已存在,跳过 → {sft_dir}")
  172. sft_dirs.append(sft_dir)
  173. continue
  174. logger.log(f"[Step2 {window_name}] 开始生成 SFT 数据")
  175. cmd = [
  176. sys.executable, STEP2,
  177. "--analysis", str(analysis_path),
  178. "--novel", novel,
  179. "--output-dir", str(sft_dir),
  180. "--context-chars", str(context_chars),
  181. "--concurrency", str(concurrency),
  182. "--model", model,
  183. ]
  184. for t in skip_tasks:
  185. cmd += ["--skip-task", str(t)]
  186. ok = run_cmd(cmd, logger)
  187. if ok:
  188. sft_dirs.append(sft_dir)
  189. else:
  190. logger.log(f"[Step2 {window_name}] 失败,继续处理其他窗口")
  191. return sft_dirs
  192. # ──────────────────────────────────────────────────────────────
  193. # 合并
  194. # ──────────────────────────────────────────────────────────────
  195. def merge_jsonl(sft_dirs: List[Path], merged_dir: Path, logger: Logger):
  196. """合并所有窗口的 JSONL 文件到 merged/ 目录"""
  197. if not sft_dirs:
  198. logger.log("[Merge] 无 SFT 数据可合并")
  199. return
  200. merged_dir.mkdir(parents=True, exist_ok=True)
  201. total_stats: Dict[str, int] = {}
  202. for task_file in SFT_TASKS:
  203. out_path = merged_dir / task_file
  204. count = 0
  205. with open(out_path, "w", encoding="utf-8") as out_f:
  206. for sft_dir in sft_dirs:
  207. src = sft_dir / task_file
  208. if src.exists():
  209. text = src.read_text(encoding="utf-8")
  210. lines = [l for l in text.splitlines() if l.strip()]
  211. for line in lines:
  212. out_f.write(line + "\n")
  213. count += len(lines)
  214. total_stats[task_file] = count
  215. logger.log(f"[Merge] {task_file}: {count} 条")
  216. # 汇总统计
  217. stats_path = merged_dir / "stats.json"
  218. total = sum(total_stats.values())
  219. stats = {
  220. "total_samples": total,
  221. "by_task": total_stats,
  222. "windows": len(sft_dirs),
  223. "merged_at": datetime.datetime.now().isoformat(),
  224. }
  225. stats_path.write_text(json.dumps(stats, ensure_ascii=False, indent=2), encoding="utf-8")
  226. logger.log(f"[Merge] 完成,总计 {total} 条样本 → {merged_dir}")
  227. # 打印汇总表
  228. print(f"\n{'='*50}")
  229. print("合并结果汇总")
  230. print(f"{'='*50}")
  231. for task_file, count in total_stats.items():
  232. name = task_file.replace(".jsonl", "")
  233. print(f" {name:<40} {count:>6} 条")
  234. print(f" {'总计':<40} {total:>6} 条")
  235. print(f"{'='*50}\n")
  236. # ──────────────────────────────────────────────────────────────
  237. # 主入口
  238. # ──────────────────────────────────────────────────────────────
def main():
    """CLI entry point: parse arguments, run step1/step2 per window, then merge.

    Orchestrates the whole pipeline: computes the window count from the
    novel's character count, runs the analysis and SFT-generation stages
    (with resume support), and merges per-window JSONL outputs.
    """
    parser = argparse.ArgumentParser(
        description="Pipeline Runner:批量分析小说并生成 SFT 训练数据",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--novel", required=True, help="小说 txt 文件路径")
    parser.add_argument(
        "--output-dir", default=None,
        help="输出根目录(默认:sft/ 目录下以文件名命名,如 runs/大奉打更人/)",
    )
    parser.add_argument(
        "--window-size", type=int, default=500_000,
        help="每个分析窗口的字符数(默认 500000)",
    )
    parser.add_argument("--model", default="qwen-plus", help="使用的模型名称")
    parser.add_argument(
        "--context-chars", type=int, default=800,
        help="Step2 中提取上文的字符数(默认 800)",
    )
    parser.add_argument(
        "--concurrency", type=int, default=5,
        help="Step2 并发 LLM 调用数(默认 5)",
    )
    parser.add_argument(
        "--skip-task", type=int, action="append", default=[],
        metavar="N", help="跳过 Step2 的某个任务(1/2/3),可多次指定",
    )
    parser.add_argument(
        "--only-step", type=int, choices=[1, 2], default=None,
        help="只执行某个步骤(1=只分析,2=只生成SFT,需要analysis已存在)",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="强制重新运行,忽略已有输出文件",
    )
    args = parser.parse_args()

    novel_path = Path(args.novel).resolve()
    if not novel_path.exists():
        print(f"错误:文件不存在 {novel_path}")
        sys.exit(1)

    # Output directory: explicit --output-dir, or runs/<novel stem>/ next to
    # this script.
    if args.output_dir:
        output_dir = Path(args.output_dir).resolve()
    else:
        output_dir = SCRIPT_DIR / "runs" / novel_path.stem
    analysis_dir = output_dir / "analysis"
    sft_raw_dir = output_dir / "sft_raw"
    merged_dir = output_dir / "merged"
    log_path = output_dir / "pipeline.log"
    output_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(log_path)

    # Window count = ceil(total characters / window size).
    total_chars = load_text_size(str(novel_path))
    n_windows = math.ceil(total_chars / args.window_size)
    logger.log(f"{'='*60}")
    logger.log(f"小说:{novel_path.name} ({total_chars:,} 字符)")
    logger.log(f"窗口:{n_windows} 个(每窗口 {args.window_size:,} 字符)")
    logger.log(f"输出目录:{output_dir}")
    logger.log(f"模型:{args.model} 并发:{args.concurrency}")
    logger.log(f"跳过任务:{args.skip_task or '无'} 只执行步骤:{args.only_step or '全部'}")
    logger.log(f"强制重跑:{'是' if args.force else '否(已有文件将跳过)'}")
    logger.log(f"{'='*60}")

    # Step 1: sequential per-window analysis (each window feeds the next).
    analysis_files = run_step1_all(
        str(novel_path), analysis_dir, n_windows,
        args.window_size, args.model, args.force, logger, args.only_step,
    )
    if not analysis_files:
        logger.log("没有可用的分析文件,退出。")
        sys.exit(1)

    # Step 2: SFT data generation for every available analysis file.
    sft_dirs = run_step2_all(
        str(novel_path), analysis_files, sft_raw_dir,
        args.context_chars, args.concurrency, args.skip_task,
        args.model, args.force, logger, args.only_step,
    )

    # Merge: skipped only when the user asked for step 1 alone.
    if args.only_step != 1:
        merge_jsonl(sft_dirs, merged_dir, logger)

    logger.log("Pipeline 完成。")
    print(f"\n日志文件:{log_path}")
    print(f"最终数据:{merged_dir}")
# Script entry point.
if __name__ == "__main__":
    main()