| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388 |
- #!/usr/bin/env python3
- """
- Pipeline Runner:批量执行完整分析 + SFT 数据生成流程
- 功能:
- - 自动按 window_size 切分小说,串行调用 step1_analyze.py
- - 每个窗口的分析完成后自动传给下一窗口(保持人物/线索连贯)
- - 并行调用 step2_build_sft.py 生成三类 SFT 数据
- - 所有窗口完成后合并 JSONL 到 merged/ 目录
- - **支持断点续跑**:已存在的输出文件自动跳过,直接从中断处继续
- 用法:
- cd examples/analyze_story/sft
- python run_pipeline.py --novel ../input/大奉打更人.txt
- # 指定输出目录(默认在 sft/ 下以文件名命名)
- python run_pipeline.py --novel ../input/大奉打更人.txt --output-dir runs/大奉/
- # 跳过某个任务,调整并发数
- python run_pipeline.py --novel ../input/大奉打更人.txt --skip-task 3 --concurrency 8
- # 只重新跑 step2(分析已完成的情况下)
- python run_pipeline.py --novel ../input/大奉打更人.txt --only-step 2
- # 强制重新跑(忽略已有文件)
- python run_pipeline.py --novel ../input/大奉打更人.txt --force
- 输出结构:
- {output_dir}/
- analysis/
- w0.json ← 第一个窗口分析
- w1.json ← 第二个窗口分析(如有)
- ...
- sft_raw/
- w0/
- task1_structure_planning.jsonl
- task2_scene_continuation.jsonl
- task3_shuang_injection.jsonl
- stats.json
- w1/
- ...
- merged/ ← 所有窗口合并后的最终数据
- task1_structure_planning.jsonl
- task2_scene_continuation.jsonl
- task3_shuang_injection.jsonl
- stats.json ← 汇总统计
- pipeline.log ← 运行日志(追加写入)
- """
import argparse
import datetime
import json
import math
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional
# Directory containing this runner; the two step scripts live next to it.
SCRIPT_DIR = Path(__file__).parent
STEP1 = SCRIPT_DIR / "step1_analyze.py"    # per-window novel analysis
STEP2 = SCRIPT_DIR / "step2_build_sft.py"  # SFT sample generation from one analysis
# JSONL filenames produced by step2 inside each window directory;
# also the set of files concatenated by merge_jsonl().
SFT_TASKS = [
    "task1_structure_planning.jsonl",
    "task2_scene_continuation.jsonl",
    "task3_shuang_injection.jsonl",
]
- # ──────────────────────────────────────────────────────────────
- # 工具
- # ──────────────────────────────────────────────────────────────
def load_text_size(path: str) -> int:
    """Return the exact character count of the text file at *path*.

    The file is read as bytes once, then decoding is attempted with a
    sequence of encodings common for Chinese novels; the length of the
    first successful full decode is returned.  (The previous docstring
    claimed a bytes/1.5 estimate — the code has always fully decoded.)

    Raises:
        ValueError: if none of the candidate encodings can decode the file.
    """
    data = Path(path).read_bytes()  # read once instead of once per encoding
    for enc in ("utf-8", "gbk", "gb2312", "gb18030"):
        try:
            return len(data.decode(enc))
        except UnicodeDecodeError:
            continue
    raise ValueError(f"无法解码文件: {path}")
def count_jsonl_lines(path: Path) -> int:
    """Count non-blank lines in a JSONL file; returns 0 when the file is absent."""
    if not path.exists():
        return 0
    content = path.read_text(encoding="utf-8")
    return len([ln for ln in content.splitlines() if ln.strip()])
class Logger:
    """Tiny logger that mirrors each message to stdout and an append-only file."""

    def __init__(self, log_path: Path):
        self.log_path = log_path
        # Ensure the parent directory exists before the first append.
        log_path.parent.mkdir(parents=True, exist_ok=True)

    def log(self, msg: str):
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        entry = f"[{stamp}] {msg}"
        print(entry)
        with open(self.log_path, "a", encoding="utf-8") as sink:
            sink.write(entry + "\n")
def run_cmd(cmd: List[str], logger: Logger) -> bool:
    """Run a subprocess, streaming its merged stdout/stderr to the console.

    Args:
        cmd: command and arguments; each element is stringified before use,
            so Path objects (e.g. STEP1/STEP2) are accepted.
        logger: pipeline logger; records the command line and any failure.

    Returns:
        True when the process exits with return code 0, False otherwise.
    """
    logger.log(f"运行: {' '.join(str(c) for c in cmd)}")
    # Context manager guarantees the stdout pipe is closed even if the
    # streaming loop raises (the original leaked the pipe fd in that case).
    with subprocess.Popen(
        [str(c) for c in cmd],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
    ) as proc:
        # proc.stdout is never None here because stdout=PIPE was requested.
        for line in proc.stdout:
            print(line, end="", flush=True)
        returncode = proc.wait()
    if returncode != 0:
        logger.log(f"失败(返回码 {returncode})")
        return False
    return True
- # ──────────────────────────────────────────────────────────────
- # Step 1:逐窗口分析
- # ──────────────────────────────────────────────────────────────
def run_step1_all(
    novel: str,
    analysis_dir: Path,
    n_windows: int,
    window_size: int,
    model: str,
    force: bool,
    logger: Logger,
    only_step: Optional[int],
) -> List[Path]:
    """Analyze every window sequentially and return the analysis files produced.

    Each window receives the previous window's analysis via --prev-analysis so
    characters and plot threads stay consistent. Existing outputs are reused
    unless *force* is set; with only_step == 2 the step is skipped entirely and
    pre-existing files are returned.
    """
    if only_step == 2:
        # Step 1 disabled: hand back whatever analyses already exist on disk.
        existing = sorted(analysis_dir.glob("w*.json"))
        logger.log(f"[Step1 跳过] 使用已有分析文件 {len(existing)} 个")
        return existing

    analysis_dir.mkdir(parents=True, exist_ok=True)
    results: List[Path] = []
    previous: Optional[Path] = None
    for idx in range(n_windows):
        target = analysis_dir / f"w{idx}.json"
        if target.exists() and not force:
            # Resume support: keep the file and still chain it to the next window.
            logger.log(f"[Step1 w{idx}] 已存在,跳过 → {target}")
            results.append(target)
            previous = target
            continue
        logger.log(f"[Step1 w{idx}/{n_windows-1}] 开始分析")
        command = [
            sys.executable, STEP1,
            "--novel", novel,
            "--window-index", str(idx),
            "--window-size", str(window_size),
            "--output", str(target),
            "--model", model,
        ]
        if previous:
            command += ["--prev-analysis", str(previous)]
        if not run_cmd(command, logger):
            # A failed window invalidates everything after it — stop here.
            logger.log(f"[Step1 w{idx}] 失败,跳过后续窗口")
            break
        results.append(target)
        previous = target
    return results
- # ──────────────────────────────────────────────────────────────
- # Step 2:为每个分析文件生成 SFT 数据
- # ──────────────────────────────────────────────────────────────
def run_step2_all(
    novel: str,
    analysis_files: List[Path],
    sft_raw_dir: Path,
    context_chars: int,
    concurrency: int,
    skip_tasks: List[int],
    model: str,
    force: bool,
    logger: Logger,
    only_step: Optional[int],
) -> List[Path]:
    """Build SFT data for every analysis file; return the window dirs that succeeded.

    Unlike Step 1, a failing window does not abort the run — the remaining
    windows are still processed. With only_step == 1 nothing is done.
    """
    if only_step == 1:
        logger.log("[Step2 跳过] --only-step 1")
        return []

    completed: List[Path] = []
    for analysis in analysis_files:
        win = analysis.stem  # e.g. "w0"
        out_dir = sft_raw_dir / win
        # stats.json is written by step2 on completion, so it doubles as a done-marker.
        if (out_dir / "stats.json").exists() and not force:
            logger.log(f"[Step2 {win}] 已存在,跳过 → {out_dir}")
            completed.append(out_dir)
            continue
        logger.log(f"[Step2 {win}] 开始生成 SFT 数据")
        command = [
            sys.executable, STEP2,
            "--analysis", str(analysis),
            "--novel", novel,
            "--output-dir", str(out_dir),
            "--context-chars", str(context_chars),
            "--concurrency", str(concurrency),
            "--model", model,
        ]
        for task_no in skip_tasks:
            command += ["--skip-task", str(task_no)]
        if run_cmd(command, logger):
            completed.append(out_dir)
        else:
            logger.log(f"[Step2 {win}] 失败,继续处理其他窗口")
    return completed
- # ──────────────────────────────────────────────────────────────
- # 合并
- # ──────────────────────────────────────────────────────────────
def merge_jsonl(sft_dirs: List[Path], merged_dir: Path, logger: Logger):
    """Concatenate every window's JSONL task files into merged/ and write stats.

    Blank lines are dropped while merging. A summary stats.json (totals per
    task, window count, timestamp) is written next to the merged files and a
    summary table is printed to stdout.
    """
    if not sft_dirs:
        logger.log("[Merge] 无 SFT 数据可合并")
        return

    merged_dir.mkdir(parents=True, exist_ok=True)
    per_task: Dict[str, int] = {}
    for task_file in SFT_TASKS:
        n_lines = 0
        with open(merged_dir / task_file, "w", encoding="utf-8") as sink:
            for window_dir in sft_dirs:
                source = window_dir / task_file
                if not source.exists():
                    continue  # e.g. the task was skipped for this window
                rows = [ln for ln in source.read_text(encoding="utf-8").splitlines() if ln.strip()]
                sink.writelines(row + "\n" for row in rows)
                n_lines += len(rows)
        per_task[task_file] = n_lines
        logger.log(f"[Merge] {task_file}: {n_lines} 条")

    # Aggregate statistics across all tasks/windows.
    total = sum(per_task.values())
    summary = {
        "total_samples": total,
        "by_task": per_task,
        "windows": len(sft_dirs),
        "merged_at": datetime.datetime.now().isoformat(),
    }
    (merged_dir / "stats.json").write_text(
        json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    logger.log(f"[Merge] 完成,总计 {total} 条样本 → {merged_dir}")

    # Human-readable summary table.
    banner = "=" * 50
    print(f"\n{banner}")
    print("合并结果汇总")
    print(banner)
    for task_file, n_lines in per_task.items():
        print(f" {task_file.replace('.jsonl', ''):<40} {n_lines:>6} 条")
    print(f" {'总计':<40} {total:>6} 条")
    print(f"{banner}\n")
- # ──────────────────────────────────────────────────────────────
- # 主入口
- # ──────────────────────────────────────────────────────────────
def main():
    """CLI entry point: parse arguments, then run analyze → build-SFT → merge.

    Exits with status 1 when the novel file does not exist or when Step 1
    yields no analysis files to work with.
    """
    parser = argparse.ArgumentParser(
        description="Pipeline Runner:批量分析小说并生成 SFT 训练数据",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--novel", required=True, help="小说 txt 文件路径")
    parser.add_argument(
        "--output-dir", default=None,
        help="输出根目录(默认:sft/ 目录下以文件名命名,如 runs/大奉打更人/)",
    )
    parser.add_argument(
        "--window-size", type=int, default=500_000,
        help="每个分析窗口的字符数(默认 500000)",
    )
    parser.add_argument("--model", default="qwen-plus", help="使用的模型名称")
    parser.add_argument(
        "--context-chars", type=int, default=800,
        help="Step2 中提取上文的字符数(默认 800)",
    )
    parser.add_argument(
        "--concurrency", type=int, default=5,
        help="Step2 并发 LLM 调用数(默认 5)",
    )
    parser.add_argument(
        "--skip-task", type=int, action="append", default=[],
        metavar="N", help="跳过 Step2 的某个任务(1/2/3),可多次指定",
    )
    parser.add_argument(
        "--only-step", type=int, choices=[1, 2], default=None,
        help="只执行某个步骤(1=只分析,2=只生成SFT,需要analysis已存在)",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="强制重新运行,忽略已有输出文件",
    )
    args = parser.parse_args()

    novel_path = Path(args.novel).resolve()
    if not novel_path.exists():
        print(f"错误:文件不存在 {novel_path}")
        sys.exit(1)

    # Output root: explicit --output-dir, else runs/<novel stem>/ next to this script.
    if args.output_dir:
        output_dir = Path(args.output_dir).resolve()
    else:
        output_dir = SCRIPT_DIR / "runs" / novel_path.stem
    analysis_dir = output_dir / "analysis"
    sft_raw_dir = output_dir / "sft_raw"
    merged_dir = output_dir / "merged"
    log_path = output_dir / "pipeline.log"
    output_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(log_path)

    # Window count = ceil(total characters / window size).
    total_chars = load_text_size(str(novel_path))
    n_windows = math.ceil(total_chars / args.window_size)

    # Run banner, mirrored into pipeline.log.
    logger.log(f"{'='*60}")
    logger.log(f"小说:{novel_path.name} ({total_chars:,} 字符)")
    logger.log(f"窗口:{n_windows} 个(每窗口 {args.window_size:,} 字符)")
    logger.log(f"输出目录:{output_dir}")
    logger.log(f"模型:{args.model} 并发:{args.concurrency}")
    logger.log(f"跳过任务:{args.skip_task or '无'} 只执行步骤:{args.only_step or '全部'}")
    logger.log(f"强制重跑:{'是' if args.force else '否(已有文件将跳过)'}")
    logger.log(f"{'='*60}")

    # Step 1: sequential per-window analysis (each window feeds the next).
    analysis_files = run_step1_all(
        str(novel_path), analysis_dir, n_windows,
        args.window_size, args.model, args.force, logger, args.only_step,
    )
    if not analysis_files:
        logger.log("没有可用的分析文件,退出。")
        sys.exit(1)

    # Step 2: generate SFT data for every analyzed window.
    sft_dirs = run_step2_all(
        str(novel_path), analysis_files, sft_raw_dir,
        args.context_chars, args.concurrency, args.skip_task,
        args.model, args.force, logger, args.only_step,
    )

    # Merge: skipped when only Step 1 was requested.
    if args.only_step != 1:
        merge_jsonl(sft_dirs, merged_dir, logger)
    logger.log("Pipeline 完成。")
    print(f"\n日志文件:{log_path}")
    print(f"最终数据:{merged_dir}")


if __name__ == "__main__":
    main()
|