step2_build_sft.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. #!/usr/bin/env python3
  2. """
  3. 步骤2:从分析结果生成三类 SFT 训练数据
  4. 三个任务(参考 00_task_definition.md):
  5. Task 1 - 结构规划(Structure Planning)
  6. 输入:故事状态(MICE线程、last disaster/decision、位置)+ 上文
  7. 输出:<think>叙事状态分析 + 续写决策</think> + 结构规划 JSON
  8. 目标:让模型学会规划下一个 Scene-Sequel 单元的结构
  9. Task 2 - 场景续写(Scene Continuation)
  10. 输入:上文 + 结构规划(Task 1 的输出)
  11. 输出:<think>上文理解 + 写法决策</think> + 续写正文
  12. 目标:让模型学会根据规划生成高质量正文
  13. Task 3 - 爽点注入(Shuang Point Injection)
  14. 输入:平淡草稿 + 爽点类型 + 强度要求
  15. 输出:<think>草稿分析 + 爽点设计</think> + 增强版正文 + 修改说明
  16. 目标:让模型学会识别并注入爽点
  17. 用法:
  18. python step2_build_sft.py \\
  19. --analysis analysis_w0.json \\
  20. --novel input/大奉打更人.txt \\
  21. --output-dir sft/dafeng/ \\
  22. [--context-chars 800] \\
  23. [--skip-task 3] \\
  24. [--concurrency 5] \\
  25. [--model qwen-plus]
  26. 输出文件:
  27. sft/dafeng/task1_structure_planning.jsonl
  28. sft/dafeng/task2_scene_continuation.jsonl
  29. sft/dafeng/task3_shuang_injection.jsonl
  30. sft/dafeng/stats.json
  31. """
  32. import os
  33. import re
  34. import json
  35. import asyncio
  36. import argparse
  37. from copy import deepcopy
  38. from pathlib import Path
  39. from openai import AsyncOpenAI, BadRequestError, RateLimitError, APIError
  40. from typing import Optional, List, Set
  41. from dotenv import load_dotenv
  42. load_dotenv()
  43. load_dotenv(Path(__file__).parent.parent / ".env") # 项目根目录 .env
  44. client = AsyncOpenAI(
  45. api_key=os.getenv("ALI_API_KEY"),
  46. base_url=os.getenv(
  47. "ALI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"
  48. ),
  49. )
  50. # ──────────────────────────────────────────────────────────────
  51. # 基础工具
  52. # ──────────────────────────────────────────────────────────────
  53. class ContentFilterError(Exception):
  54. """内容审查不通过,跳过该条样本,不重试"""
  55. def load_text(path: str) -> str:
  56. for enc in ["utf-8", "gbk", "gb2312", "gb18030"]:
  57. try:
  58. return Path(path).read_text(encoding=enc)
  59. except UnicodeDecodeError:
  60. continue
  61. raise ValueError(f"无法解码文件: {path}")
  62. async def llm_call(
  63. messages: list,
  64. model: str,
  65. temperature: float = 0.6,
  66. max_tokens: int = 4096,
  67. max_retries: int = 3,
  68. ) -> str:
  69. delay = 5.0
  70. for attempt in range(1, max_retries + 2): # +1 for the final attempt
  71. try:
  72. resp = await client.chat.completions.create(
  73. model=model,
  74. messages=messages,
  75. temperature=temperature,
  76. max_tokens=max_tokens,
  77. )
  78. return resp.choices[0].message.content
  79. except BadRequestError as e:
  80. err_code = getattr(e, "code", "") or ""
  81. # 阿里云内容审查:data_inspection_failed / content_filter 等
  82. if "data_inspection_failed" in str(e) or "content_filter" in err_code:
  83. raise ContentFilterError(f"内容审查不通过: {e}") from e
  84. raise # 其他 400 错误直接抛出
  85. except (RateLimitError, APIError) as e:
  86. if attempt > max_retries:
  87. raise
  88. print(f" [重试 {attempt}/{max_retries}] {type(e).__name__}: {e},{delay:.0f}s 后重试...")
  89. await asyncio.sleep(delay)
  90. delay = min(delay * 2, 60)
  91. def extract_json_block(text: str) -> dict:
  92. m = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
  93. json_str = m.group(1) if m else text.strip()
  94. try:
  95. return json.loads(json_str)
  96. except json.JSONDecodeError:
  97. json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
  98. return json.loads(json_str)
  99. def write_jsonl(samples: List[dict], path: Path) -> None:
  100. path.parent.mkdir(parents=True, exist_ok=True)
  101. with open(path, "w", encoding="utf-8") as f:
  102. for s in samples:
  103. # 去掉内部 _* 字段再写入
  104. out = {k: v for k, v in s.items() if not k.startswith("_")}
  105. f.write(json.dumps(out, ensure_ascii=False) + "\n")
  106. print(f" 写入 {len(samples)} 条 → {path}")
  107. # ──────────────────────────────────────────────────────────────
  108. # 故事状态累积
  109. # ──────────────────────────────────────────────────────────────
  110. def build_state_snapshot(analysis: dict, beat_index: int) -> dict:
  111. """
  112. 返回 beat_index 之前的故事状态快照。
  113. 额外字段(比单纯状态更丰富):
  114. - plot_line_events: {线索名 -> [事件描述列表]}
  115. - recent_beats: 最近 5 个 beat 的简要记录
  116. """
  117. state = {
  118. "plot_lines": deepcopy(analysis.get("outline", {}).get("plot_lines", [])),
  119. "characters": deepcopy(analysis.get("characters", [])),
  120. "plot_line_events": {}, # name -> [str]
  121. "recent_beats": [],
  122. }
  123. for b in analysis.get("beats", [])[:beat_index]:
  124. changes = b.get("state_changes", {})
  125. # 更新线索状态 + 记录事件历史
  126. for pl in changes.get("plot_lines", []):
  127. matched = False
  128. for line in state["plot_lines"]:
  129. if line["name"] == pl["name"]:
  130. line["status"] = pl["new_state"]
  131. matched = True
  132. break
  133. if not matched:
  134. state["plot_lines"].append(
  135. {"name": pl["name"], "status": pl["new_state"],
  136. "mice_type": "?", "description": pl.get("new_state", "")}
  137. )
  138. event = f"{pl.get('old_state', '?')} → {pl['new_state']}"
  139. state["plot_line_events"].setdefault(pl["name"], []).append(event)
  140. # 更新人物近期变化
  141. for ch in changes.get("characters", []):
  142. for char in state["characters"]:
  143. if char["name"] == ch["name"]:
  144. char.setdefault("recent_changes", []).append(ch["change"])
  145. char["recent_changes"] = char["recent_changes"][-3:]
  146. break
  147. # 记录近期节拍(保留最近 5 个)
  148. state["recent_beats"].append({
  149. "id": b.get("id", ""),
  150. "type": b["type"],
  151. "summary": b.get("summary", ""),
  152. "outcome": b.get("disaster", "") if b["type"] == "scene" else b.get("decision", ""),
  153. })
  154. state["recent_beats"] = state["recent_beats"][-5:]
  155. return state
  156. def get_last_disaster_decision(beats: List[dict], before_index: int) -> tuple:
  157. """返回 beat_index 之前最后一个 scene 的 disaster 和 最后一个 sequel 的 decision"""
  158. last_disaster = "无(故事开局)"
  159. last_decision = "无(故事开局)"
  160. for b in beats[:before_index]:
  161. if b["type"] == "scene":
  162. last_disaster = b.get("disaster", "")
  163. elif b["type"] == "sequel":
  164. last_decision = b.get("decision", "")
  165. return last_disaster, last_decision
  166. def format_story_notes(
  167. analysis: dict,
  168. state: dict,
  169. last_disaster: str,
  170. last_decision: str,
  171. ) -> str:
  172. """
  173. 生成故事笔记(约 2000-4000 字符)。
  174. 包含 core_question/next_steps(线索)、speaking_style/current_state(人物)、writing_insights(窗口级)。
  175. """
  176. parts = []
  177. # 1. 主线摘要
  178. main_plot = analysis.get("outline", {}).get("main_plot", "")
  179. if main_plot:
  180. parts.append(f"**主线**:{main_plot}")
  181. # 2. 活跃剧情线索(含 core_question, next_steps, 历史事件)
  182. active = [pl for pl in state["plot_lines"]
  183. if pl.get("status") not in ["已解决", "已关闭"]]
  184. resolved = [pl for pl in state["plot_lines"]
  185. if pl.get("status") in ["已解决", "已关闭"]]
  186. if active:
  187. lines = ["**活跃线索**:"]
  188. for pl in active:
  189. mice = pl.get("mice_type", "?")
  190. events = state.get("plot_line_events", {}).get(pl["name"], [])
  191. ev_str = f"(进展:{';'.join(events[-3:])})" if events else ""
  192. cq = pl.get("core_question", "")
  193. ns = pl.get("next_steps", "")
  194. extra = ""
  195. if cq:
  196. extra += f" 核心问:{cq}"
  197. if ns:
  198. extra += f" 待推进:{ns}"
  199. lines.append(
  200. f"- [{mice}] {pl['name']}({pl['status']}):"
  201. f"{pl.get('description', '')}{ev_str}{extra}"
  202. )
  203. if resolved:
  204. lines.append(f"- 已结:{'、'.join(p['name'] for p in resolved)}")
  205. parts.append("\n".join(lines))
  206. # 3. 人物状态(含 speaking_style, current_state, 性格, 关系, 近期变化)
  207. if state["characters"]:
  208. lines = ["**人物**:"]
  209. for c in state["characters"]:
  210. segs = [f"{c['name']}({c.get('role', '?')})目标:{c.get('goal', '')}"]
  211. traits = c.get("traits", [])
  212. if traits:
  213. segs.append(f"性格:{'、'.join(traits)}")
  214. style = c.get("speaking_style", [])
  215. if style:
  216. style_str = ",".join(style) if isinstance(style, list) else str(style)
  217. segs.append(f"说话风格:{style_str}")
  218. cur_state = c.get("current_state", "")
  219. if cur_state:
  220. segs.append(f"当前处境:{cur_state}")
  221. rels = c.get("relationships", {})
  222. if rels:
  223. rel_items = [f"{k}→{v}" for k, v in list(rels.items())[:4]]
  224. segs.append(f"关系:{';'.join(rel_items)}")
  225. recent = c.get("recent_changes", [])
  226. if recent:
  227. segs.append(f"近期:{';'.join(recent)}")
  228. lines.append("- " + "。".join(segs))
  229. parts.append("\n".join(lines))
  230. # 4. 近期节拍
  231. recent_beats = state.get("recent_beats", [])
  232. if recent_beats:
  233. lines = ["**近期节拍**:"]
  234. for b in recent_beats:
  235. tag = "场景" if b["type"] == "scene" else "后续"
  236. outcome_label = "结局" if b["type"] == "scene" else "决定"
  237. outcome = f" → {outcome_label}:{b['outcome']}" if b.get("outcome") else ""
  238. lines.append(f"- [{b['id']}·{tag}] {b['summary']}{outcome}")
  239. parts.append("\n".join(lines))
  240. # 5. 写作亮点(窗口级,来自 step1 提取的 writing_insights)
  241. wi = analysis.get("writing_insights", {})
  242. if wi:
  243. wi_lines = []
  244. for item in wi.get("techniques", []):
  245. wi_lines.append(f"- 技巧:{item}")
  246. for item in wi.get("shuang_designs", []):
  247. wi_lines.append(f"- 爽点设计:{item}")
  248. for item in wi.get("pacing", []):
  249. wi_lines.append(f"- 节奏:{item}")
  250. if wi_lines:
  251. parts.append("**写作亮点**:\n" + "\n".join(wi_lines))
  252. # 6. 悬而未决
  253. parts.append(
  254. f"**待解决**:上一场景结局:{last_disaster};上一个决定:{last_decision}"
  255. )
  256. return "\n\n".join(parts)
  257. def calc_position_percent(beat: dict, total_chars: int) -> float:
  258. return round(beat.get("position_start", 0) / max(total_chars, 1) * 100, 1)
  259. # ──────────────────────────────────────────────────────────────
  260. # Task 1:结构规划(Structure Planning)
  261. # ──────────────────────────────────────────────────────────────
# System prompt for Task 1 (structure planning). gen_task1_sample sends it as
# the system message of the generation request and also stores it as the
# system message of the emitted training sample. It describes the expected
# <think> reasoning, the scene-plan JSON and the Markdown notes update.
TASK1_SYSTEM = """\
你是资深网文作者,擅长基于故事笔记规划场景。
## 核心能力
1. **分析笔记**:理解当前故事状态、活跃线索、人物动态
2. **规划场景**:基于笔记设计下一个场景的结构
3. **更新笔记**:记录场景对故事状态的改变
## 工作流程
1. 仔细阅读故事笔记(当前状态、活跃线索、待办事项)
2. 在 `<think>` 中展示你的思考过程(800-1500字)
3. 输出场景规划(JSON 格式)
4. 输出笔记更新(Markdown 格式)
---
## Think 要求
在 `<think>` 标签中,展示你真实的创作思维过程。**不要求固定格式**,但需要包含以下核心要素:
必须包含的要素:
1. **笔记分析**:当前故事进行到哪里?哪些线索在推进?主要角色的目标、冲突、关系状态;笔记中标记的待推进事项和风险点
2. **方案推演**:至少考虑 2-3 种不同的场景设计方案;对比各方案的优缺点;说明为什么选择某个方案
3. **笔记更新计划**:这个场景会推进哪些线索?哪些人物状态会变化?需要新增或完成哪些待推进事项?
鼓励的思维方式:
- **跳跃联想**:从笔记的某个细节突然想到类似案例
- **自我质疑**:推翻之前的想法,重新思考
- **细节推敲**:对某个对话、动作、道具的反复打磨
- **灵感闪现**:突然意识到某个巧妙的设计
- **风险预警**:发现可能的逻辑漏洞或人设崩塌
不要求固定章节标题(如【笔记分析】【方案推演】),不需要按固定顺序展开,可以有口语化、跳跃、修正。
---
## 输出格式
### 1. 场景规划(JSON)
```json
{
"scene_type": "scene | sequel",
"goal": "角色目标",
"conflict_type": "冲突类型",
"conflict_description": "...",
"disaster": "场景结尾的灾难/转折(scene 类型必填)",
"sequel": {"reaction": "...", "dilemma": "...", "decision": "..."},
"pacing": "fast|medium|slow",
"dialogue_ratio": 0.4,
"shuang_point": {
"has_shuang": true,
"type": "打脸|升级|装逼|获得|碾压",
"mechanism": "实现机制"
},
"hooks": ["悬念1", "悬念2"],
"mice_threads": {
"推进": ["线索名"],
"开启": ["新线索名"],
"解决": ["已完成线索名"]
},
"estimated_words": 2000
}
```
### 2. 笔记更新(Markdown)
```markdown
## 笔记更新
### 剧情线索变化
- [线索名]:[旧状态] → [新状态]
- [新线索]:开启([简短描述])
### 人物状态变化
- [角色名]:[变化描述]
### 待推进更新
- [✓] [已完成事项]
- [ ] [新增事项](紧急/重要)
### 新增写作亮点(可选)
- [技巧/桥段]:[描述]
```
"""
# User-turn template stored in the Task 1 training sample. It is filled via
# str.format with: title, chapter, position_pct, story_notes, context_chars,
# context_text. Unlike the generation template, it carries no information
# about the beat that actually follows.
TASK1_USER_TMPL = """\
## 故事笔记
- 书名:{title}
- 当前位置:第 {chapter} 章,约 {position_pct}% 处
{story_notes}
---
## 上文(最近 {context_chars} 字)
{context_text}
## 任务
请基于故事笔记和上文,完成以下任务:
1. 分析当前故事状态(在 `<think>` 中展示你的思考过程)
2. 规划下一个场景的结构(JSON 格式)
3. 输出笔记更新(Markdown 格式)"""
# Generation-time prompt for Task 1. On top of the fields used by the
# training-sample user template it takes beat_type, beat_summary, beat_core
# and shuang_info, i.e. a summary of the real beat given to the LLM as
# reference. Literal braces of the JSON skeleton are doubled because the
# template goes through str.format.
# NOTE(review): "hooks" is a list of {type, content} objects here, while
# TASK1_SYSTEM shows a plain list of strings — confirm which schema is intended.
TASK1_COT_GEN_TMPL = """\
## 故事笔记
- 书名:{title}
- 当前位置:第 {chapter} 章,约 {position_pct}% 处
{story_notes}
---
## 上文(最近 {context_chars} 字)
{context_text}
## 参考信息(该节拍的实际内容摘要,仅用于帮你构建 CoT,禁止直接引用)
类型:{beat_type}
摘要:{beat_summary}
核心要素:{beat_core}
爽点:{shuang_info}
---
请以"事前规划"的视角展示你真实的创作思维过程(分析笔记状态、推演至少 2-3 个方案并对比优缺点、规划笔记更新),然后输出规划 JSON 和笔记更新。
<think>
[自由思考过程]
</think>
```json
{{
"scene_type": "scene | sequel",
"goal": "...",
"conflict_type": "人物冲突|环境冲突|内心冲突|信息冲突",
"conflict_description": "...",
"disaster": "...",
"sequel": {{"reaction": "...", "dilemma": "...", "decision": "..."}},
"pacing": "fast|medium|slow",
"dialogue_ratio": 0.4,
"shuang_point": {{
"has_shuang": true,
"type": "打脸|升级|装逼|获得|碾压",
"mechanism": "..."
}},
"hooks": [
{{"type": "chapter_end|mid_chapter", "content": "..."}}
],
"mice_threads": {{
"推进": ["线索名"],
"开启": ["新线索名"],
"解决": ["已完成线索名"]
}},
"estimated_words": 2000
}}
```
```markdown
## 笔记更新
### 剧情线索变化
- [线索名]:[旧状态] → [新状态]
### 人物状态变化
- [角色名]:[变化描述]
### 待推进更新
- [✓] [已完成]
- [ ] [新增](紧急/重要)
### 新增写作亮点(可选)
- [技巧]:[描述]
```"""
  398. def _beat_core_str(beat: dict) -> str:
  399. if beat["type"] == "scene":
  400. return (
  401. f"goal={beat.get('goal', '')} "
  402. f"conflict={beat.get('conflict_description', '')} "
  403. f"disaster={beat.get('disaster', '')}"
  404. )
  405. return (
  406. f"reaction={beat.get('reaction', '')} "
  407. f"dilemma={beat.get('dilemma', '')} "
  408. f"decision={beat.get('decision', '')}"
  409. )
  410. def _shuang_str(beat: dict) -> str:
  411. sp = beat.get("shuang_point", {})
  412. if not sp.get("has_shuang"):
  413. return "无"
  414. return f"{sp.get('type', '')}({sp.get('intensity', '')}):{sp.get('description', '')}"
  415. async def gen_task1_sample(
  416. i: int,
  417. beat: dict,
  418. analysis: dict,
  419. novel_text: str,
  420. context_chars: int,
  421. model: str,
  422. sem: asyncio.Semaphore,
  423. ) -> Optional[dict]:
  424. async with sem:
  425. meta = analysis.get("_meta", {})
  426. title = meta.get("novel_title", "未知")
  427. total_chars = meta.get("total_chars", len(novel_text))
  428. beats = analysis.get("beats", [])
  429. state = build_state_snapshot(analysis, i)
  430. last_disaster, last_decision = get_last_disaster_decision(beats, i)
  431. chapter = beat.get("chapter_start", "?")
  432. position_pct = calc_position_percent(beat, total_chars)
  433. ctx_start = max(0, beat["position_start"] - context_chars)
  434. context_text = novel_text[ctx_start: beat["position_start"]].strip()
  435. story_notes = format_story_notes(analysis, state, last_disaster, last_decision)
  436. shared_kwargs = dict(
  437. title=title,
  438. chapter=chapter,
  439. position_pct=position_pct,
  440. story_notes=story_notes,
  441. context_chars=context_chars,
  442. context_text=context_text,
  443. )
  444. # 生成 CoT + 规划 JSON
  445. cot_prompt = TASK1_COT_GEN_TMPL.format(
  446. beat_type=beat["type"],
  447. beat_summary=beat.get("summary", ""),
  448. beat_core=_beat_core_str(beat),
  449. shuang_info=_shuang_str(beat),
  450. **shared_kwargs,
  451. )
  452. messages = [
  453. {"role": "system", "content": TASK1_SYSTEM},
  454. {"role": "user", "content": cot_prompt},
  455. ]
  456. try:
  457. assistant_content = await llm_call(messages, model=model)
  458. except ContentFilterError as e:
  459. print(f" [Task1] beat {i+1} 内容审查拦截,跳过:{e}")
  460. return None
  461. except Exception as e:
  462. print(f" [Task1] beat {i+1} LLM 调用失败:{e}")
  463. return None
  464. # 训练样本:用户只看到 story_state + context,不知道 beat 实际内容
  465. user_content = TASK1_USER_TMPL.format(**shared_kwargs)
  466. return {
  467. "messages": [
  468. {"role": "system", "content": TASK1_SYSTEM},
  469. {"role": "user", "content": user_content},
  470. {"role": "assistant", "content": assistant_content},
  471. ],
  472. "metadata": {
  473. "task_type": "structure_planning",
  474. "source_file": meta.get("novel_title", ""),
  475. "chapter": f"第{chapter}章",
  476. "position_percent": position_pct,
  477. "mice_thread": beat.get("mice_thread", ""),
  478. "beat_id": beat.get("id", ""),
  479. "beat_type": beat["type"],
  480. "word_count": beat["position_end"] - beat["position_start"],
  481. },
  482. }
  483. # ──────────────────────────────────────────────────────────────
  484. # Task 2:场景续写(Scene Continuation)
  485. # ──────────────────────────────────────────────────────────────
  486. TASK2_SYSTEM = (
  487. "你是一位专业的网文作家,擅长写爽文、悬疑和情感类长篇小说,"
  488. "能够根据结构规划生成节奏流畅、爽点鲜明的正文。"
  489. )
# User-turn template stored in the Task 2 training sample. str.format fields:
# title, position_pct, story_notes_brief, context_text, structure_plan,
# target_words.
TASK2_USER_TMPL = """\
## 故事笔记(概要)
- 书名:{title},当前位置约 {position_pct}% 处
{story_notes_brief}
---
## 上文
{context_text}
## 结构规划
{structure_plan}
## 任务
请续写下一段(约 {target_words} 字),风格与上文保持一致。"""
# Generation-time prompt for Task 2. str.format fields: title, position_pct,
# story_notes_brief, context_text, structure_plan, beat_text_hint, actual_text.
# NOTE(review): the caller truncates beat_text_hint to 300 characters "to avoid
# leaking too much", yet the full beat text is also inserted as actual_text at
# the end of this prompt — confirm that this is intended.
TASK2_COT_GEN_TMPL = """\
## 故事笔记(概要)
- 书名:{title},当前位置约 {position_pct}% 处
{story_notes_brief}
---
## 上文
{context_text}
## 结构规划
{structure_plan}
## 参考信息(该节拍的实际续写内容,仅用于帮你构建 CoT,禁止逐句引用)
{beat_text_hint}
---
请以"事前决策"的视角自由写出写作思考过程(上文衔接方式、爽点植入、人物动机、对话设计等,无需固定段落),然后直接输出实际续写内容。
<think>
[自由思考过程]
</think>
{actual_text}"""
  518. async def gen_task2_sample(
  519. i: int,
  520. beat: dict,
  521. analysis: dict,
  522. novel_text: str,
  523. task1_samples: list,
  524. context_chars: int,
  525. model: str,
  526. sem: asyncio.Semaphore,
  527. ) -> Optional[dict]:
  528. async with sem:
  529. meta = analysis.get("_meta", {})
  530. title = meta.get("novel_title", "未知")
  531. total_chars = meta.get("total_chars", len(novel_text))
  532. beats = analysis.get("beats", [])
  533. state = build_state_snapshot(analysis, i)
  534. last_disaster, last_decision = get_last_disaster_decision(beats, i)
  535. position_pct = calc_position_percent(beat, total_chars)
  536. ctx_start = max(0, beat["position_start"] - context_chars)
  537. context_text = novel_text[ctx_start: beat["position_start"]].strip()
  538. beat_text = novel_text[beat["position_start"]: beat["position_end"]].strip()
  539. if not beat_text:
  540. return None
  541. # Task2 使用精简版笔记:只含活跃线索和人物,不含近期节拍(上文已涵盖)
  542. story_notes_brief = format_story_notes(analysis, state, last_disaster, last_decision)
  543. # 从 Task1 样本中提取结构规划(assistant 输出部分)
  544. structure_plan = ""
  545. if i < len(task1_samples) and task1_samples[i]:
  546. for msg in task1_samples[i]["messages"]:
  547. if msg["role"] == "assistant":
  548. structure_plan = msg["content"]
  549. break
  550. if not structure_plan:
  551. structure_plan = f"(Task1 未生成,beat 摘要:{beat.get('summary', '')})"
  552. target_words = max(500, (beat["position_end"] - beat["position_start"]) // 2)
  553. # 只给 LLM 前 300 字作为 hint,避免泄露太多
  554. beat_hint = beat_text[:300] + "..." if len(beat_text) > 300 else beat_text
  555. cot_prompt = TASK2_COT_GEN_TMPL.format(
  556. title=title,
  557. position_pct=position_pct,
  558. story_notes_brief=story_notes_brief,
  559. context_text=context_text,
  560. structure_plan=structure_plan,
  561. beat_text_hint=beat_hint,
  562. actual_text=beat_text,
  563. )
  564. messages = [
  565. {"role": "system", "content": TASK2_SYSTEM},
  566. {"role": "user", "content": cot_prompt},
  567. ]
  568. try:
  569. cot_part = await llm_call(messages, model=model)
  570. except ContentFilterError as e:
  571. print(f" [Task2] beat {i+1} 内容审查拦截,跳过:{e}")
  572. return None
  573. except Exception as e:
  574. print(f" [Task2] beat {i+1} LLM 调用失败:{e}")
  575. return None
  576. # 确保输出格式:<think>...</think>\n\n{实际正文}
  577. if "<think>" in cot_part and beat_text not in cot_part:
  578. # LLM 只生成了 CoT,拼接实际文本
  579. think_end = cot_part.find("</think>")
  580. if think_end != -1:
  581. think_block = cot_part[: think_end + len("</think>")]
  582. assistant_content = f"{think_block}\n\n{beat_text}"
  583. else:
  584. assistant_content = f"{cot_part}\n\n{beat_text}"
  585. else:
  586. assistant_content = cot_part
  587. user_content = TASK2_USER_TMPL.format(
  588. title=title,
  589. position_pct=position_pct,
  590. story_notes_brief=story_notes_brief,
  591. context_text=context_text,
  592. structure_plan=structure_plan,
  593. target_words=target_words,
  594. )
  595. return {
  596. "messages": [
  597. {"role": "system", "content": TASK2_SYSTEM},
  598. {"role": "user", "content": user_content},
  599. {"role": "assistant", "content": assistant_content},
  600. ],
  601. "metadata": {
  602. "task_type": "scene_continuation",
  603. "source_file": meta.get("novel_title", ""),
  604. "chapter": f"第{beat.get('chapter_start', '?')}章",
  605. "position_percent": calc_position_percent(beat, total_chars),
  606. "mice_thread": beat.get("mice_thread", ""),
  607. "beat_id": beat.get("id", ""),
  608. "beat_type": beat["type"],
  609. "word_count": len(beat_text),
  610. },
  611. }
  612. # ──────────────────────────────────────────────────────────────
  613. # Task 3:爽点注入(Shuang Point Injection)
  614. # ──────────────────────────────────────────────────────────────
  615. TASK3_SYSTEM = (
  616. "你是一位专业的网文编辑,擅长识别和设计爽点(打脸、升级、装逼、获得、碾压),"
  617. "能在不改变核心情节的前提下大幅提升情感冲击力。"
  618. )
# Generation-time prompt for Task 3. str.format fields: story_notes_brief,
# beat_text. It asks the LLM for a strict JSON object (has_shuang,
# shuang_type, intensity, flat_draft, cot, modification_notes) or
# {"has_shuang": false}. Literal braces are doubled for str.format, and the
# doubled backslashes put a literal "\n" escape inside the JSON example.
TASK3_GEN_TMPL = """\
## 故事背景(用于理解爽点来源)
{story_notes_brief}
---
## 原文(包含爽点的增强版)
{beat_text}
---
## 任务
1. 判断这段文字是否包含明显爽点(打脸/升级/装逼/获得/碾压)
2. 如果有,生成去掉爽点后的"平淡草稿"(保留核心情节事件,但去掉爽感设计)
3. 以编辑视角,写出重新注入爽点的完整思考过程(CoT)和修改说明
注意:CoT 应分析人物性格/关系如何使这个爽点成立,以及与当前剧情线索的联动
**输出格式(严格 JSON)**:
```json
{{
"has_shuang": true,
"shuang_type": "打脸|升级|装逼|获得|碾压",
"intensity": "low|medium|high",
"flat_draft": "去掉爽点后的平淡版本(完整文字)",
"cot": "<think>\\n[自由分析草稿问题和注入方案,结合人物特质和线索背景]\\n</think>",
"modification_notes": "注入位置:...\\n爽点类型:...\\n关键改动:..."
}}
```
如果不包含明显爽点,输出:`{{"has_shuang": false}}`"""
# User-turn template stored in the Task 3 training sample. str.format fields:
# story_notes_brief, flat_draft, shuang_type, intensity.
TASK3_USER_TMPL = """\
## 故事背景
{story_notes_brief}
---
## 平淡草稿
{flat_draft}
## 要求
- 爽点类型:{shuang_type}
- 强度:{intensity}(low=轻微强化 | medium=明显提升 | high=大幅改写)
- 不改变核心情节,只增强情感冲击力
- 结合人物性格特质和当前剧情线索设计爽感
## 任务
请注入爽点,输出增强版本。"""
  656. async def gen_task3_sample(
  657. i: int,
  658. beat: dict,
  659. analysis: dict,
  660. novel_text: str,
  661. model: str,
  662. sem: asyncio.Semaphore,
  663. ) -> Optional[dict]:
  664. # 只处理有爽点的 beat
  665. sp = beat.get("shuang_point", {})
  666. if not sp.get("has_shuang"):
  667. return None
  668. async with sem:
  669. meta = analysis.get("_meta", {})
  670. total_chars = meta.get("total_chars", len(novel_text))
  671. beats = analysis.get("beats", [])
  672. state = build_state_snapshot(analysis, i)
  673. last_disaster, last_decision = get_last_disaster_decision(beats, i)
  674. story_notes_brief = format_story_notes(analysis, state, last_disaster, last_decision)
  675. beat_text = novel_text[beat["position_start"]: beat["position_end"]].strip()
  676. if len(beat_text) < 200:
  677. return None
  678. # 生成平淡草稿 + CoT
  679. gen_prompt = TASK3_GEN_TMPL.format(
  680. story_notes_brief=story_notes_brief,
  681. beat_text=beat_text,
  682. )
  683. messages = [
  684. {"role": "system", "content": TASK3_SYSTEM},
  685. {"role": "user", "content": gen_prompt},
  686. ]
  687. try:
  688. raw = await llm_call(messages, model=model)
  689. except ContentFilterError as e:
  690. print(f" [Task3] beat {i+1} 内容审查拦截,跳过:{e}")
  691. return None
  692. except Exception as e:
  693. print(f" [Task3] beat {i+1} LLM 调用失败:{e}")
  694. return None
  695. try:
  696. result = extract_json_block(raw)
  697. except Exception:
  698. # 保存原始响应供排查
  699. debug_path = Path(f"/tmp/task3_beat{i+1}_debug.txt")
  700. debug_path.write_text(raw, encoding="utf-8")
  701. print(f" [Task3] beat {i+1} JSON 解析失败,原始响应已保存至 {debug_path},跳过")
  702. return None
  703. if not result.get("has_shuang"):
  704. return None
  705. flat_draft = result.get("flat_draft", "")
  706. cot = result.get("cot", "")
  707. modification_notes = result.get("modification_notes", "")
  708. shuang_type = result.get("shuang_type", sp.get("type", ""))
  709. intensity = result.get("intensity", sp.get("intensity", "medium"))
  710. if not flat_draft or not cot:
  711. return None
  712. # 训练样本
  713. user_content = TASK3_USER_TMPL.format(
  714. story_notes_brief=story_notes_brief,
  715. flat_draft=flat_draft,
  716. shuang_type=shuang_type,
  717. intensity=intensity,
  718. )
  719. # 输出:CoT + 增强版(原文)+ 修改说明
  720. assistant_content = (
  721. f"{cot}\n\n"
  722. f"{beat_text}\n\n"
  723. f"---\n**修改说明**:\n{modification_notes}"
  724. )
  725. return {
  726. "messages": [
  727. {"role": "system", "content": TASK3_SYSTEM},
  728. {"role": "user", "content": user_content},
  729. {"role": "assistant", "content": assistant_content},
  730. ],
  731. "metadata": {
  732. "task_type": "shuang_injection",
  733. "source_file": meta.get("novel_title", ""),
  734. "chapter": f"第{beat.get('chapter_start', '?')}章",
  735. "position_percent": calc_position_percent(beat, total_chars),
  736. "shuang_type": shuang_type,
  737. "intensity": intensity,
  738. "beat_id": beat.get("id", ""),
  739. "word_count": len(beat_text),
  740. },
  741. }
  742. # ──────────────────────────────────────────────────────────────
  743. # 主流程
  744. # ──────────────────────────────────────────────────────────────
  745. async def build_all(
  746. analysis_path: str,
  747. novel_path: str,
  748. output_dir: str,
  749. context_chars: int,
  750. skip_tasks: Set[int],
  751. model: str,
  752. concurrency: int,
  753. max_beats: Optional[int] = None,
  754. ):
  755. with open(analysis_path, encoding="utf-8") as f:
  756. analysis = json.load(f)
  757. novel_text = load_text(novel_path)
  758. beats = analysis.get("beats", [])
  759. if max_beats is not None:
  760. beats = beats[:max_beats]
  761. analysis = dict(analysis, beats=beats) # 局部视图,不修改文件
  762. out = Path(output_dir)
  763. sem = asyncio.Semaphore(concurrency)
  764. print(f"\n分析文件:{analysis_path}")
  765. print(f"节拍数:{len(beats)}")
  766. print(f"输出目录:{out}")
  767. print(f"并发数:{concurrency}\n")
  768. stats = {}
  769. # ── Task 1 ──────────────────────────────────
  770. task1_samples: List[Optional[dict]] = [None] * len(beats)
  771. if 1 not in skip_tasks:
  772. print("[Task 1] 结构规划(Structure Planning)...")
  773. tasks = [
  774. gen_task1_sample(i, b, analysis, novel_text, context_chars, model, sem)
  775. for i, b in enumerate(beats)
  776. ]
  777. results = await asyncio.gather(*tasks)
  778. task1_samples = list(results)
  779. valid = [s for s in task1_samples if s]
  780. write_jsonl(valid, out / "task1_structure_planning.jsonl")
  781. stats["task1"] = {"total": len(beats), "valid": len(valid)}
  782. print(f" Task1 完成:{len(valid)}/{len(beats)} 条有效\n")
  783. # ── Task 2 ──────────────────────────────────
  784. if 2 not in skip_tasks:
  785. print("[Task 2] 场景续写(Scene Continuation)...")
  786. tasks = [
  787. gen_task2_sample(
  788. i, b, analysis, novel_text, task1_samples, context_chars, model, sem
  789. )
  790. for i, b in enumerate(beats)
  791. ]
  792. results = await asyncio.gather(*tasks)
  793. valid = [s for s in results if s]
  794. write_jsonl(valid, out / "task2_scene_continuation.jsonl")
  795. stats["task2"] = {"total": len(beats), "valid": len(valid)}
  796. print(f" Task2 完成:{len(valid)}/{len(beats)} 条有效\n")
  797. # ── Task 3 ──────────────────────────────────
  798. if 3 not in skip_tasks:
  799. shuang_beats = [b for b in beats if b.get("shuang_point", {}).get("has_shuang")]
  800. print(f"[Task 3] 爽点注入(Shuang Point Injection)... (共 {len(shuang_beats)} 个有爽点的 beat)")
  801. tasks = [
  802. gen_task3_sample(i, b, analysis, novel_text, model, sem)
  803. for i, b in enumerate(beats)
  804. ]
  805. results = await asyncio.gather(*tasks)
  806. valid = [s for s in results if s]
  807. write_jsonl(valid, out / "task3_shuang_injection.jsonl")
  808. stats["task3"] = {
  809. "total": len(shuang_beats),
  810. "valid": len(valid),
  811. }
  812. print(f" Task3 完成:{len(valid)}/{len(shuang_beats)} 条有效\n")
  813. # ── 统计 ──────────────────────────────────
  814. stats_path = out / "stats.json"
  815. stats_path.write_text(json.dumps(stats, ensure_ascii=False, indent=2), encoding="utf-8")
  816. print(f"统计信息 → {stats_path}")
  817. total_valid = sum(v.get("valid", 0) for v in stats.values())
  818. print(f"\n全部完成。总有效样本数:{total_valid}")
  819. def main():
  820. parser = argparse.ArgumentParser(description="步骤2:生成三类 SFT 训练数据")
  821. parser.add_argument("--analysis", required=True, help="step1 输出的 analysis JSON")
  822. parser.add_argument("--novel", required=True, help="小说 txt 文件路径")
  823. parser.add_argument("--output-dir", required=True, help="输出目录")
  824. parser.add_argument(
  825. "--context-chars", type=int, default=800,
  826. help="Task1/2 的上文字符数(默认 800)",
  827. )
  828. parser.add_argument(
  829. "--skip-task", type=int, action="append", default=[],
  830. metavar="N", help="跳过某个任务(1/2/3),可多次指定",
  831. )
  832. parser.add_argument(
  833. "--concurrency", type=int, default=5,
  834. help="并发 LLM 调用数(默认 5)",
  835. )
  836. parser.add_argument("--model", default="qwen-plus", help="使用的模型名称")
  837. parser.add_argument(
  838. "--max-beats", type=int, default=None,
  839. help="只处理前 N 个 beat(用于试运行验证)",
  840. )
  841. args = parser.parse_args()
  842. asyncio.run(
  843. build_all(
  844. args.analysis,
  845. args.novel,
  846. args.output_dir,
  847. args.context_chars,
  848. set(args.skip_task),
  849. args.model,
  850. args.concurrency,
  851. args.max_beats,
  852. )
  853. )
  854. if __name__ == "__main__":
  855. main()