|
|
@@ -0,0 +1,766 @@
|
|
|
+"""
|
|
|
+长篇叙事 SFT 数据集构建工具
|
|
|
+============================
|
|
|
+
|
|
|
+三个核心任务:
|
|
|
+ Task 1: structure_planning - 给定上文,规划下一个 Scene-Sequel 结构
|
|
|
+ Task 2: scene_continuation - 给定上文+规划,续写正文(CoT + 正文)
|
|
|
+ Task 3: shuang_injection - 给定平淡草稿,注入爽点
|
|
|
+
|
|
|
+用法:
|
|
|
+ # 处理单个文件,生成所有任务的训练数据
|
|
|
+ python build_dataset.py --input ../input_1/大奉打更人.txt --tasks all
|
|
|
+
|
|
|
+ # 只生成续写任务数据
|
|
|
+ python build_dataset.py --input ../input_1/大奉打更人.txt --tasks task2
|
|
|
+
|
|
|
+ # 处理多个文件
|
|
|
+ python build_dataset.py --input ../input_1/ --tasks all --max-samples 50
|
|
|
+
|
|
|
+ # 指定模型(默认 gemini-2.0-flash-001,快速便宜)
|
|
|
+ python build_dataset.py --input ../input_1/大奉打更人.txt --model google/gemini-2.5-flash-preview
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import json
|
|
|
+import re
|
|
|
+import argparse
|
|
|
+import time
|
|
|
+from pathlib import Path
|
|
|
+from typing import Iterator
|
|
|
+
|
|
|
+import requests
|
|
|
+from dotenv import load_dotenv
|
|
|
+
|
|
|
# ── Path setup ────────────────────────────────────────────────────────────────
# Resolve the repository root three levels above this file so that sibling
# packages become importable and the shared .env (API keys) can be loaded.
HERE = Path(__file__).parent
ROOT = HERE.parent.parent.parent
sys.path.insert(0, str(ROOT))
load_dotenv(ROOT / ".env")

# The OpenRouter key is mandatory for every task — fail fast at import time
# rather than after chapters have already been parsed.
OPEN_ROUTER_KEY = os.environ.get("OPEN_ROUTER_API_KEY", "")
if not OPEN_ROUTER_KEY:
    raise RuntimeError("请在 .env 中设置 OPEN_ROUTER_API_KEY")

# Default output directory for generated JSONL datasets; can be overridden
# with the --output CLI flag.
OUTPUT_DIR = HERE / "output"
OUTPUT_DIR.mkdir(exist_ok=True)
|
|
|
+
|
|
|
+# ── LLM 调用(单次,同步)────────────────────────────────────────────────────
|
|
|
+
|
|
|
def llm_call(
    messages: list[dict],
    model: str = "google/gemini-2.0-flash-001",
    temperature: float = 0.7,
    max_tokens: int = 4096,
    retry: int = 3,
) -> str:
    """Call the OpenRouter chat-completions API and return the assistant text.

    Retries up to ``retry`` times with exponential backoff (1s, 2s, 4s, ...);
    the final failure is re-raised to the caller.
    """
    # Request parts are invariant across attempts, so build them once.
    request_headers = {
        "Authorization": f"Bearer {OPEN_ROUTER_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://github.com/narrative-sft",
    }
    request_body = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    for attempt in range(retry):
        try:
            resp = requests.post(
                "https://openrouter.ai/api/v1/chat/completions",
                headers=request_headers,
                json=request_body,
                timeout=120,
            )
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"]
        except Exception as e:
            # Out of attempts: propagate the original error to the caller.
            if attempt >= retry - 1:
                raise
            wait = 2 ** attempt
            print(f"  [LLM] 重试 {attempt+1}/{retry},等待 {wait}s... ({e})")
            time.sleep(wait)
|
|
|
+
|
|
|
+
|
|
|
+# ── 文本解析:从小说文件中切分场景单元 ───────────────────────────────────────
|
|
|
+
|
|
|
def detect_encoding(path: Path) -> str:
    """Detect the text encoding of *path* (UTF-8, GBK, or GB18030).

    Returns the first candidate encoding that decodes the whole file without
    error, falling back to "utf-8" when none succeed (the caller then reads
    with errors="replace").
    """
    # Fix: the original re-read the entire file from disk once per candidate
    # encoding (up to three full reads of a multi-megabyte novel).  Read the
    # raw bytes once and try decoding them in memory instead.
    data = path.read_bytes()
    for enc in ("utf-8", "gbk", "gb18030"):
        try:
            data.decode(enc)
            return enc
        except UnicodeDecodeError:
            continue
    return "utf-8"
|
|
|
+
|
|
|
+
|
|
|
def load_novel(path: Path) -> str:
    """Read a novel file and return its plain text.

    The file is decoded with the detected encoding (undecodable bytes are
    replaced), and header blocks fenced by long runs of '=' characters
    (copyright banners etc.) are stripped out.
    """
    raw = path.read_text(encoding=detect_encoding(path), errors="replace")
    # Remove "========== ... ==========" banner blocks anywhere in the text.
    cleaned = re.sub(r"={10,}.*?={10,}", "", raw, flags=re.DOTALL)
    return cleaned.strip()
|
|
|
+
|
|
|
+
|
|
|
def split_chapters(text: str) -> list[dict]:
    """
    Split *text* into chapters.

    Supported heading formats (heading on its own line):
      - "第N章 标题"
      - "第N章" immediately followed by body text

    Returns [{"chapter": heading, "content": body, "index": N}, ...].
    When no headings are found, the whole text is returned as one pseudo
    chapter named "全文".  Chapters whose body is 300 characters or shorter
    are dropped as stubs.
    """
    # Match common chapter headings.  `\s*` already covers full-width
    # (U+3000) indentation in Python 3.  Fix: the numeral class previously
    # omitted 两/万/〇, so headings like "第两百章" or "第一万章" were missed.
    pattern = re.compile(
        r"^\s*(第[零〇一二三四五六七八九十百千万两\d]+[章节回卷][^\n]{0,40})\s*$",
        re.MULTILINE,
    )
    matches = list(pattern.finditer(text))
    if not matches:
        return [{"chapter": "全文", "content": text, "index": 0}]

    chapters = []
    for i, m in enumerate(matches):
        start = m.end()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        content = text[start:end].strip()
        if len(content) > 300:  # skip stub/empty chapters
            chapters.append({
                "chapter": m.group(1).strip(),
                "content": content,
                "index": i,
            })
    return chapters
|
|
|
+
|
|
|
+
|
|
|
def normalize_paragraphs(text: str) -> list[str]:
    """Normalize novel text into a list of non-empty paragraphs.

    Handles common web-novel conventions: paragraphs separated by single
    newlines, full-width-space (U+3000) indentation, and blank lines
    (which are dropped).
    """
    stripped = (raw.strip().lstrip("　").strip() for raw in text.split("\n"))
    return [para for para in stripped if para]
|
|
|
+
|
|
|
+
|
|
|
def extract_scene_units(
    chapter_content: str,
    min_context_chars: int = 500,
    target_context_chars: int = 1000,
    target_continuation_chars: int = 600,
) -> list[dict]:
    """
    Slide a window over the chapter and emit Scene-Sequel candidate units.

    Each unit is a (context, continuation) pair of consecutive paragraph
    runs.  Windows overlap by about half a context (step = half the context
    paragraph count) to increase sample diversity.

    Returns [{"context": str, "continuation": str,
              "context_words": int, "continuation_words": int}, ...].
    """
    paragraphs = normalize_paragraphs(chapter_content)
    total = len(paragraphs)
    if total < 4:
        return []

    def take(start: int, budget: int) -> tuple[list[str], int, int]:
        # Greedily collect paragraphs from `start` until `budget` characters
        # are accumulated (or the chapter runs out); returns the run, its
        # character count, and the index just past it.
        run, chars, pos = [], 0, start
        while pos < total and chars < budget:
            run.append(paragraphs[pos])
            chars += len(paragraphs[pos])
            pos += 1
        return run, chars, pos

    units: list[dict] = []
    i = 0
    while i < total - 3:
        ctx_paras, ctx_chars, cont_start = take(i, target_context_chars)
        # Advance by half the context window regardless of acceptance, so
        # consecutive units overlap.
        step = max(1, len(ctx_paras) // 2)
        if ctx_chars >= min_context_chars and cont_start < total:
            cont_paras, cont_chars, _ = take(cont_start, target_continuation_chars)
            if cont_chars >= 150:
                units.append({
                    "context": "\n　".join(ctx_paras),  # restore web-novel indent
                    "continuation": "\n　".join(cont_paras),
                    "context_words": ctx_chars,
                    "continuation_words": cont_chars,
                })
        i += step

    return units
|
|
|
+
|
|
|
+
|
|
|
# ── Prompt templates ──────────────────────────────────────────────────────────

# System prompt for Task 1 (structure planning): analyze the preceding text
# and plan the next Scene-Sequel unit — CoT in <think> plus a JSON plan.
SYSTEM_STRUCTURE_PLANNING = """你是一位专业的长篇小说结构分析师,精通以下叙事理论:
- Scene-Sequel 结构(Dwight V. Swain):Scene = Goal→Conflict→Disaster;Sequel = Reaction→Dilemma→Decision
- MICE Quotient(Orson Scott Card):Milieu / Idea / Character / Event 四类线程
- Save the Cat 节拍(Blake Snyder):15个关键节拍
- 网文爽点理论:打脸、升级、装逼、获得、碾压五类爽点

你的任务是:分析给定的上文,规划下一个 Scene-Sequel 单元的结构。
输出必须包含:
1. <think> 标签内的叙事分析(真实的决策推理,不是事后解释)
2. 结构化 JSON 规划"""

# System prompt for Task 2 (scene continuation): given context + plan,
# write the next passage — CoT in <think> plus the prose itself.
SYSTEM_SCENE_CONTINUATION = """你是一位专业的网文作家,擅长写节奏紧凑、爽点密集的长篇小说。
你精通 Scene-Sequel 结构,知道如何在续写中:
- 自然衔接上文的叙事状态
- 在正确位置植入爽点(铺垫→爆发→反应)
- 在章节末尾设置钩子
- 保持与原文一致的文风和节奏

你的任务是:根据上文和结构规划,续写下一段正文。
输出必须包含:
1. <think> 标签内的写法决策(真实的创作思考过程)
2. 续写正文"""

# System prompt for Task 3 (shuang injection): upgrade a flat draft by
# adding setup→payoff→reaction "shuang" beats.
SYSTEM_SHUANG_INJECTION = """你是一位专业的网文编辑,擅长识别和设计爽点。
你知道爽点的三要素:铺垫(建立期待/对比)→ 爆发(核心爽感)→ 反应(放大效果)。

你的任务是:分析给定的平淡草稿,注入爽点使其升级。
输出必须包含:
1. <think> 标签内的爽点设计分析
2. 注入爽点后的增强版正文
3. 简要的修改说明"""
|
|
|
+
|
|
|
+
|
|
|
def make_structure_planning_prompt(
    title: str,
    chapter: str,
    position_pct: float,
    context: str,
) -> list[dict]:
    """Build the [system, user] chat messages for Task 1 (structure planning).

    Args:
        title: book title (shown as context in the prompt).
        chapter: current chapter heading.
        position_pct: relative chapter position in the book, 0.0-1.0.
        context: the preceding narrative text the model analyzes.

    Returns:
        The message list ready for ``llm_call``.
    """
    # The doubled braces ({{ / }}) in the schema below are f-string escapes
    # producing literal braces in the prompt text.
    user_content = f"""## 书名
{title}

## 当前位置
{chapter},约 {position_pct:.0%} 处

## 上文(最近约 {len(context)} 字)
{context}

## 任务
请分析上文的叙事状态,规划下一个 Scene-Sequel 单元的结构。

**要求**:
1. 在 <think> 中分析:
   - 上文最后一个 Scene 的 Goal/Conflict/Disaster 是什么
   - 上文最后一个 Sequel 的 Reaction/Dilemma/Decision 是什么(如果有)
   - 当前激活的 MICE 线程(M/I/C/E)及其状态
   - 当前处于 Save the Cat 的哪个节拍
   - 下一步应该推进哪个线程,为什么
2. 输出 JSON 格式的结构规划(严格按照下面的 schema)

**JSON Schema**:
```json
{{
  "scene": {{
    "goal": "主角在这个场景想要达成什么(具体、可衡量)",
    "conflict_type": "人物冲突|环境冲突|内心冲突|信息冲突",
    "conflict_description": "具体的障碍是什么",
    "disaster": "结果比预期更糟,具体是什么",
    "pacing": "fast|medium|slow",
    "dialogue_ratio": 0.6
  }},
  "sequel": {{
    "reaction": "主角的情感反应",
    "dilemma": "面临的两难选择",
    "decision": "做出的决定(成为下一个 Scene 的 Goal)"
  }},
  "shuang_point": {{
    "has_shuang": true,
    "type": "打脸|升级|装逼|获得|碾压",
    "setup": "铺垫内容",
    "payoff": "爆发内容",
    "reaction": "旁观者/对手的反应"
  }},
  "hooks": [
    {{"type": "chapter_end", "content": "章末钩子的具体内容"}}
  ],
  "mice_advancement": "E",
  "estimated_words": 1500
}}
```"""
    return [
        {"role": "system", "content": SYSTEM_STRUCTURE_PLANNING},
        {"role": "user", "content": user_content},
    ]
|
|
|
+
|
|
|
+
|
|
|
def make_scene_continuation_prompt(
    context: str,
    structure_plan: str,
    target_words: int = 1200,
) -> list[dict]:
    """Build the [system, user] chat messages for Task 2 (scene continuation).

    Args:
        context: the preceding narrative text to continue from.
        structure_plan: the (already generated) Scene-Sequel plan, inlined
            verbatim into the prompt.
        target_words: approximate target length of the continuation.

    Returns:
        The message list ready for ``llm_call``.
    """
    user_content = f"""## 上文
{context}

## 结构规划
{structure_plan}

## 任务
请根据上文和结构规划,续写下一段正文(目标约 {target_words} 字)。

**要求**:
1. 在 <think> 中说明:
   - 如何衔接上文(直接延续/场景切换/时间跳跃)
   - 爽点在哪里植入,具体怎么写
   - 钩子如何设置
   - 对话设计(谁说什么,潜台词)
   - 节奏控制(哪里快,哪里慢)
2. 输出续写正文,风格与上文保持一致
3. 正文中不要出现任何结构标注或括号说明"""
    return [
        {"role": "system", "content": SYSTEM_SCENE_CONTINUATION},
        {"role": "user", "content": user_content},
    ]
|
|
|
+
|
|
|
+
|
|
|
def make_shuang_injection_prompt(
    draft: str,
    shuang_type: str = "智商碾压",
    intensity: str = "high",
) -> list[dict]:
    """Build the [system, user] chat messages for Task 3 (shuang injection).

    Args:
        draft: the "flat" draft text to be enhanced.
        shuang_type: which category of shuang beat to inject.
        intensity: "low" / "medium" / "high" — emotional impact level,
            explained inline in the prompt.

    Returns:
        The message list ready for ``llm_call``.
    """
    user_content = f"""## 平淡草稿
{draft}

## 注入要求
- 爽点类型:{shuang_type}
- 强度:{intensity}(low=轻微惊讶 / medium=明显震惊 / high=三观崩塌)
- 不改变核心情节走向,只增强情感冲击力

## 任务
请注入爽点,输出增强版本。

**要求**:
1. 在 <think> 中分析:
   - 草稿的问题:哪里平淡,缺少什么
   - 爽点设计:铺垫在哪里,爆发在哪里,反应怎么写
   - 关键改动:具体改了哪些句子,为什么
2. 输出增强版正文
3. 在正文后附上简要修改说明(3-5条)"""
    return [
        {"role": "system", "content": SYSTEM_SHUANG_INJECTION},
        {"role": "user", "content": user_content},
    ]
|
|
|
+
|
|
|
+
|
|
|
+# ── 数据生成函数 ──────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
def generate_task1_sample(
    unit: dict,
    title: str,
    chapter: str,
    position_pct: float,
    model: str,
) -> dict | None:
    """Generate one Task 1 (structure planning) training sample.

    Runs the planning prompt through the LLM and packages the exchange as a
    messages+metadata record.  Returns None when the call fails or the
    reply lacks the required <think> block / JSON plan.
    """
    prompt = make_structure_planning_prompt(
        title=title,
        chapter=chapter,
        position_pct=position_pct,
        context=unit["context"],
    )
    try:
        reply = llm_call(prompt, model=model, temperature=0.3, max_tokens=2048)
    except Exception as e:
        print(f"  [Task1] LLM 调用失败: {e}")
        return None

    # A usable sample needs both the CoT block and (the start of) JSON.
    if not ("<think>" in reply and "{" in reply):
        print(f"  [Task1] 输出格式不符,跳过")
        return None

    record = {
        "messages": prompt + [{"role": "assistant", "content": reply}],
        "metadata": {
            "task_type": "structure_planning",
            "source_file": title,
            "chapter": chapter,
            "position_percent": round(position_pct, 3),
            "context_words": unit["context_words"],
            "model": model,
        },
    }
    return record
|
|
|
+
|
|
|
+
|
|
|
def generate_task2_sample(
    unit: dict,
    title: str,
    chapter: str,
    position_pct: float,
    model: str,
    use_original_as_output: bool = True,
) -> dict | None:
    """
    Generate a Task 2 (scene continuation) training sample.

    use_original_as_output=True:
        First generate a structure plan with the LLM, then have the LLM
        explain "why the original author wrote it this way" (reverse CoT);
        the final output = CoT + the original continuation text (gold standard).

    use_original_as_output=False:
        Have the LLM continue the scene directly; output = CoT + LLM text.

    Returns the sample dict (messages + metadata), or None when any LLM
    step fails or the output fails validation.
    """
    # Step 1: generate the structure plan (used to build the user prompt).
    plan_messages = make_structure_planning_prompt(
        title=title,
        chapter=chapter,
        position_pct=position_pct,
        context=unit["context"],
    )
    try:
        plan_response = llm_call(plan_messages, model=model, temperature=0.3, max_tokens=1500)
    except Exception as e:
        print(f"  [Task2] 规划生成失败: {e}")
        return None

    if use_original_as_output:
        # Step 2a: ask the LLM to explain the original author's choices
        # (reverse CoT) and then reproduce the original continuation verbatim.
        explain_messages = [
            {"role": "system", "content": SYSTEM_SCENE_CONTINUATION},
            {"role": "user", "content": f"""## 上文
{unit["context"]}

## 结构规划(已分析)
{plan_response}

## 原著续写
{unit["continuation"]}

## 任务
请分析:原著作者在续写这段时,做了哪些写法决策?
请用 <think> 标签写出你的分析,然后直接输出原著续写文本(不要修改)。

格式:
<think>
[分析原著的写法决策:如何衔接、爽点设计、钩子设置、节奏控制等]
</think>

[原著续写文本,原文照抄]"""},
        ]
        try:
            cot_response = llm_call(explain_messages, model=model, temperature=0.3, max_tokens=3000)
        except Exception as e:
            print(f"  [Task2] CoT 生成失败: {e}")
            return None

        # Sanity check: the reply should contain the original text (probe
        # with its first 50 characters).
        # NOTE(review): with `and`, a long reply that omits the original text
        # still passes — confirm whether `or` was intended here.
        original_snippet = unit["continuation"][:50].strip()
        if original_snippet not in cot_response and len(cot_response) < 200:
            print(f"  [Task2] 输出未包含原著文本,跳过")
            return None

        assistant_content = cot_response
    else:
        # Step 2b: have the LLM write the continuation itself.
        cont_messages = make_scene_continuation_prompt(
            context=unit["context"],
            structure_plan=plan_response,
            target_words=unit["continuation_words"],
        )
        try:
            assistant_content = llm_call(cont_messages, model=model, temperature=0.7, max_tokens=3000)
        except Exception as e:
            print(f"  [Task2] 续写生成失败: {e}")
            return None

    # Build the final user prompt (context + plan) that the trained model
    # will actually see; the explain/continue prompts above are scaffolding.
    final_user_content = f"""## 上文
{unit["context"]}

## 结构规划
{plan_response}

## 任务
请根据上文和结构规划,续写下一段正文(目标约 {unit["continuation_words"]} 字)。
在 <think> 中说明写法决策,然后输出续写正文。"""

    return {
        "messages": [
            {"role": "system", "content": SYSTEM_SCENE_CONTINUATION},
            {"role": "user", "content": final_user_content},
            {"role": "assistant", "content": assistant_content},
        ],
        "metadata": {
            "task_type": "scene_continuation",
            "source_file": title,
            "chapter": chapter,
            "position_percent": round(position_pct, 3),
            "context_words": unit["context_words"],
            "continuation_words": unit["continuation_words"],
            "use_original": use_original_as_output,
            "model": model,
        },
    }
|
|
|
+
|
|
|
+
|
|
|
def generate_task3_sample(
    unit: dict,
    title: str,
    chapter: str,
    model: str,
) -> dict | None:
    """
    Generate a Task 3 (shuang injection) training sample.

    Strategy: first have the LLM produce a "flat" version of the original
    continuation (shuang beats removed), then inject shuang beats back in;
    the original text is stored in metadata for comparison.

    Returns the sample dict (messages + metadata), or None when any LLM
    step fails or the output fails validation.
    """
    # Step 1: have the LLM flatten the original (strip the shuang beats).
    flatten_messages = [
        {"role": "system", "content": "你是一位文字编辑,擅长识别和移除文本中的爽点元素。"},
        {"role": "user", "content": f"""请将以下文本改写成"平淡版":
- 去掉所有让读者感到爽快的元素(打脸、碾压、震惊反应等)
- 保留核心情节和信息
- 改写后应该是一个"能用但不精彩"的版本
- 字数可以减少,但不要少于原文的 60%

## 原文
{unit["continuation"]}

## 要求
直接输出平淡版文本,不要任何解释。"""},
    ]
    try:
        draft = llm_call(flatten_messages, model=model, temperature=0.3, max_tokens=2000)
    except Exception as e:
        print(f"  [Task3] 平淡版生成失败: {e}")
        return None

    # Discard degenerate drafts (empty or near-empty replies).
    if len(draft) < 100:
        return None

    # Step 2: inject shuang beats (aiming to recover something close to
    # the original).
    inject_messages = make_shuang_injection_prompt(
        draft=draft,
        shuang_type="智商碾压",  # could be inferred from content instead
        intensity="high",
    )
    try:
        enhanced = llm_call(inject_messages, model=model, temperature=0.7, max_tokens=3000)
    except Exception as e:
        print(f"  [Task3] 爽点注入失败: {e}")
        return None

    # The enhanced output must include the CoT analysis block.
    if "<think>" not in enhanced:
        return None

    return {
        "messages": inject_messages + [{"role": "assistant", "content": enhanced}],
        "metadata": {
            "task_type": "shuang_injection",
            "source_file": title,
            "chapter": chapter,
            "original_text": unit["continuation"],  # keep the original for comparison
            "draft_text": draft,
            "model": model,
        },
    }
|
|
|
+
|
|
|
+
|
|
|
+# ── 主流程 ────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
def process_file(
    input_path: Path,
    tasks: list[str],
    model: str,
    max_samples: int,
    output_dir: Path,
) -> dict[str, list]:
    """Process a single novel/script file and generate training data.

    Iterates chapters, then scene units within each chapter, generating up
    to ``max_samples`` samples per requested task.  Returns a mapping of
    task name -> list of sample dicts; saving is done by the caller.

    NOTE(review): ``output_dir`` is accepted but never used in this
    function — writing happens in ``save_results`` at the call site;
    confirm and either use or drop the parameter.
    """
    print(f"\n{'='*60}")
    print(f"处理文件: {input_path.name}")
    print(f"{'='*60}")

    title = input_path.stem
    text = load_novel(input_path)
    chapters = split_chapters(text)
    print(f"  检测到 {len(chapters)} 个章节,总字数约 {len(text):,}")

    # Per-task accumulators; generation stops once every task hits its cap.
    results: dict[str, list] = {t: [] for t in tasks}
    sample_count = {t: 0 for t in tasks}
    total_chapters = len(chapters)

    for ch_idx, chapter in enumerate(chapters):
        if all(sample_count[t] >= max_samples for t in tasks):
            break

        # Relative position of the chapter in the book (0.0-1.0), fed to
        # the planning prompt for beat estimation.
        position_pct = ch_idx / max(total_chapters - 1, 1)
        units = extract_scene_units(chapter["content"])

        print(f"\n  章节 [{ch_idx+1}/{total_chapters}] {chapter['chapter'][:30]} "
              f"({len(units)} 个场景单元)")

        for unit_idx, unit in enumerate(units):
            if all(sample_count[t] >= max_samples for t in tasks):
                break

            print(f"    单元 {unit_idx+1}: 上文 {unit['context_words']}字 "
                  f"/ 续写 {unit['continuation_words']}字")

            if "task1" in tasks and sample_count["task1"] < max_samples:
                print(f"      → Task1 结构规划...", end="", flush=True)
                sample = generate_task1_sample(
                    unit, title, chapter["chapter"], position_pct, model
                )
                if sample:
                    results["task1"].append(sample)
                    sample_count["task1"] += 1
                    print(f" ✓ (共 {sample_count['task1']})")
                else:
                    print(f" ✗")

            if "task2" in tasks and sample_count["task2"] < max_samples:
                print(f"      → Task2 场景续写...", end="", flush=True)
                sample = generate_task2_sample(
                    unit, title, chapter["chapter"], position_pct, model,
                    use_original_as_output=True,
                )
                if sample:
                    results["task2"].append(sample)
                    sample_count["task2"] += 1
                    print(f" ✓ (共 {sample_count['task2']})")
                else:
                    print(f" ✗")

            if "task3" in tasks and sample_count["task3"] < max_samples:
                print(f"      → Task3 爽点注入...", end="", flush=True)
                sample = generate_task3_sample(
                    unit, title, chapter["chapter"], model
                )
                if sample:
                    results["task3"].append(sample)
                    sample_count["task3"] += 1
                    print(f" ✓ (共 {sample_count['task3']})")
                else:
                    print(f" ✗")

            # Throttle to stay under the API rate limit.
            time.sleep(0.5)

    return results
|
|
|
+
|
|
|
+
|
|
|
def save_results(results: dict[str, list], output_dir: Path, prefix: str = ""):
    """Append generated samples to per-task JSONL files.

    Files are named ``{prefix}_{task}.jsonl`` (or ``{task}.jsonl`` with no
    prefix) and opened in append mode, so repeated runs accumulate lines.
    Returns a mapping of task name -> number of samples written; tasks
    with no samples are skipped entirely.
    """
    saved = {}
    for task_name, samples in results.items():
        if not samples:
            continue
        stem = f"{prefix}_{task_name}" if prefix else task_name
        out_path = output_dir / f"{stem}.jsonl"
        lines = [json.dumps(sample, ensure_ascii=False) + "\n" for sample in samples]
        with open(out_path, "a", encoding="utf-8") as fh:
            fh.writelines(lines)
        saved[task_name] = len(samples)
        print(f"  ✓ {task_name}: {len(samples)} 条 → {out_path}")
    return saved
|
|
|
+
|
|
|
+
|
|
|
def main():
    """CLI entry point: parse arguments, generate samples, save JSONL files."""
    parser = argparse.ArgumentParser(description="长篇叙事 SFT 数据集构建工具")
    parser.add_argument(
        "--input", "-i", required=True,
        help="输入文件或目录(支持 .txt / .pdf / .docx)",
    )
    parser.add_argument(
        "--tasks", "-t", default="all",
        help="要生成的任务(all / task1 / task2 / task3 / task1,task2)",
    )
    parser.add_argument(
        "--model", "-m", default="google/gemini-2.0-flash-001",
        help="使用的模型(默认 google/gemini-2.0-flash-001)",
    )
    parser.add_argument(
        "--max-samples", "-n", type=int, default=10,
        help="每个任务每个文件最多生成多少条(默认 10)",
    )
    parser.add_argument(
        "--output", "-o", default=str(OUTPUT_DIR),
        help="输出目录(默认 sft_v2/output/)",
    )
    args = parser.parse_args()

    # Resolve the task list ("all" expands to the three known tasks).
    if args.tasks == "all":
        tasks = ["task1", "task2", "task3"]
    else:
        tasks = [t.strip() for t in args.tasks.split(",")]

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collect input files: a single file, or every candidate in a directory.
    input_path = Path(args.input)
    if input_path.is_dir():
        files = list(input_path.glob("*.txt")) + list(input_path.glob("*.pdf"))
    elif input_path.is_file():
        files = [input_path]
    else:
        print(f"错误:找不到输入文件或目录: {input_path}")
        sys.exit(1)

    # Only .txt is supported for now (.pdf is collected above but filtered
    # out here until a PDF extractor is wired in).
    supported = [f for f in files if f.suffix.lower() == ".txt"]
    if not supported:
        print(f"错误:没有找到支持的文件(.txt)")
        sys.exit(1)

    print(f"\n{'='*60}")
    print(f"长篇叙事 SFT 数据集构建工具")
    print(f"{'='*60}")
    print(f"  输入文件: {len(supported)} 个")
    print(f"  任务类型: {tasks}")
    print(f"  模型: {args.model}")
    print(f"  每任务最大样本数: {args.max_samples}")
    print(f"  输出目录: {output_dir}")
    print(f"{'='*60}")

    total_saved = {t: 0 for t in tasks}

    for file_path in supported:
        results = process_file(
            input_path=file_path,
            tasks=tasks,
            model=args.model,
            max_samples=args.max_samples,
            output_dir=output_dir,
        )
        print(f"\n保存结果...")
        # save_results returns counts only for tasks that produced samples,
        # so the keys are always a subset of `tasks`.
        saved = save_results(results, output_dir, prefix=file_path.stem)
        for t, n in saved.items():
            total_saved[t] += n

    print(f"\n{'='*60}")
    print(f"完成!总计生成:")
    for t, n in total_saved.items():
        print(f"  {t}: {n} 条")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
|