topic_aware_chunking.py

  1. """
  2. 主题感知分块
  3. """
from __future__ import annotations

import re, uuid, math
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Any, Tuple, Optional

import optuna
import numpy as np
from sentence_transformers import SentenceTransformer, util


# ---------- Utilities ----------
def simple_sent_tokenize(text: str) -> List[str]:
    """Split text into sentences on CJK/Latin terminal punctuation and newlines."""
    text = re.sub(r"\n{2,}", "\n", text)
    parts = re.split(r"([。!?!?;;]+)\s*|\n+", text)
    sents, buf = [], ""
    for p in parts:
        if p is None:
            continue
        if re.match(r"[。!?!?;;]+", p or ""):
            buf += p or ""
            if buf.strip():
                sents.append(buf.strip())
                buf = ""
        elif p.strip() == "":
            if buf.strip():
                sents.append(buf.strip())
                buf = ""
        else:
            buf += p or ""
    if buf.strip():
        sents.append(buf.strip())
    # Merge very short fragments into a neighbor to avoid tiny "sentences".
    merged = []
    for s in sents:
        if merged and (len(s) < 10 or len(merged[-1]) < 10):
            merged[-1] += " " + s
        else:
            merged.append(s)
    return [s for s in merged if s.strip()]

def approx_tokens(text: str) -> int:
    """Cheap token estimator (≈2.5 chars/token for zh, ≈0.75 words/token for en)."""
    # This is a heuristic; replace with tiktoken if desired.
    cjk = re.findall(r"[\u4e00-\u9fff]", text)
    others = re.sub(r"[\u4e00-\u9fff]", " ", text).split()
    return max(1, int(len(cjk) / 2.5 + len(others) / 0.75))
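
# Worked example for the heuristic above (shown for intuition only): the string
# "检索增强生成 improves recall" has 6 CJK characters and 2 Latin words, so
# approx_tokens returns max(1, int(6 / 2.5 + 2 / 0.75)) = int(2.4 + 2.67) = 5.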

# ---------- Knowledge Graph Stub ----------
class KGClassifier:
    """
    Hierarchical topic classifier using embedding prototypes per node.
    Replace `kg_spec` with your KG; each node keeps a centroid embedding.
    """

    def __init__(self, model: SentenceTransformer, kg_spec: Dict[str, Any]):
        """
        kg_spec example:
        {
            "root": {
                "name": "root",
                "children": [
                    {"name": "Computer Science", "children": [
                        {"name": "NLP", "children": [{"name": "RAG", "children": []}]}]},
                    {"name": "Finance", "children": [{"name": "AP/AR", "children": []}]}]}
        }
        """
        self.model = model
        self.root = kg_spec["root"]
        self._embed_cache = {}  # name -> vector

        def build_centroid(node):
            name = node["name"]
            if name not in self._embed_cache:
                self._embed_cache[name] = self.model.encode(
                    name, normalize_embeddings=True
                )
            for ch in node.get("children", []):
                build_centroid(ch)

        build_centroid(self.root)

    def classify(self, text_emb: np.ndarray, topk: int = 3) -> Tuple[List[str], float]:
        """
        Return (topic_path, purity). Purity is the sigmoid-squashed margin between the
        best and second-best child at each level, averaged over the path.
        Note: `topk` is accepted for API compatibility but currently unused; only the
        single best child is followed at each level.
        """
        path, purities = [], []
        node = self.root
        while True:
            # score the children of the current node
            children = node.get("children", [])
            if not children:
                break
            scores = []
            for ch in children:
                vec = self._embed_cache[ch["name"]]
                scores.append((ch, float(util.cos_sim(text_emb, vec).item())))
            scores.sort(key=lambda x: x[1], reverse=True)
            best, second = scores[0], (scores[1] if len(scores) > 1 else (None, -1.0))
            path.append(best[0]["name"])
            margin = max(0.0, best[1] - max(second[1], -1.0))
            purities.append(1 / (1 + math.exp(-5 * margin)))  # squash to (0, 1)
            node = best[0]
        purity = float(np.mean(purities)) if purities else 1.0
        return path, purity
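
# Standalone usage sketch for KGClassifier (hypothetical two-branch spec; the node
# names and the example sentence are illustrative only):
#
#     model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
#     kg = KGClassifier(model, {"root": {"name": "root", "children": [
#         {"name": "NLP", "children": []}, {"name": "Finance", "children": []}]}})
#     emb = model.encode("Retrieval-augmented generation pipelines", normalize_embeddings=True)
#     path, purity = kg.classify(emb)
#     # path lists node names from the root downward, e.g. ["NLP"]; purity is higher
#     # when each level has a clear winner.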

# ---------- Core Chunker ----------
@dataclass
class Chunk:
    id: str
    text: str
    tokens: int
    topics: List[str] = field(default_factory=list)
    topic_purity: float = 1.0
    meta: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ChunkerConfig:
    model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    target_tokens: int = 80
    max_tokens: int = 80
    overlap_ratio: float = 0.12
    boundary_threshold: float = 0.50  # similarity drop boundary (lower -> more cuts)
    min_sent_per_chunk: int = 1
    max_sent_per_chunk: int = 8
    enable_adaptive_boundary: bool = True
    enable_kg: bool = True
    topic_purity_floor: float = 0.65
    kg_topk: int = 3
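
# Minimal configuration sketch (values are illustrative, not tuned): override only the
# fields you care about and let the dataclass defaults fill in the rest, e.g.
#
#     cfg = ChunkerConfig(target_tokens=120, max_tokens=160, boundary_threshold=0.55)
#     chunker = TopicAwareChunker(cfg)   # no kg_spec -> the KG refinement stage is skipped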

class TopicAwareChunker:
    def __init__(self, cfg: ChunkerConfig, kg_spec: Optional[Dict[str, Any]] = None):
        self.cfg = cfg
        self.model = SentenceTransformer(
            cfg.model_name, device="cpu"
        )  # set gpu if available
        self.model.max_seq_length = 512
        self.kg = (
            KGClassifier(self.model, kg_spec) if (cfg.enable_kg and kg_spec) else None
        )

    # ---------- Public API ----------
    def chunk(self, text: str) -> List[Chunk]:
        sents = simple_sent_tokenize(text)
        if not sents:
            return []
        sent_embs = self.model.encode(sents, normalize_embeddings=True)
        boundaries = self._detect_boundaries(sents, sent_embs)
        raw_chunks = self._pack_by_boundaries(sents, sent_embs, boundaries)
        final_chunks = self._classify_and_refine(raw_chunks)
        return final_chunks

    # ---------- Boundary detection ----------
    def _detect_boundaries(self, sents: List[str], embs: np.ndarray) -> List[int]:
        if len(sents) < 2:
            return []
        # adjacent-pair similarities: entry i compares sents[i] with sents[i + 1]
        sims = util.cos_sim(embs[:-1], embs[1:]).cpu().numpy().diagonal()
        cut_scores = 1 - sims  # higher means more likely boundary
        # use np.ptp instead of ndarray.ptp (NumPy 2.0 compatibility)
        rng = np.ptp(cut_scores) if np.ptp(cut_scores) > 0 else 1e-6
        cut_scores = (cut_scores - cut_scores.min()) / (rng + 1e-6)
        boundaries = []
        for i, score in enumerate(cut_scores):
            # score i corresponds to the sentence pair (i, i + 1); check sents[i]
            # for discourse cues near its end
            sent_to_check = sents[i]
            snippet = sent_to_check[-20:] if sent_to_check else ""
            turn = (
                0.1
                if re.search(
                    r"(因此|但是|综上|然而|另一方面|In conclusion|However|Therefore)",
                    snippet,
                )
                else 0.0
            )
            fig = (
                0.1
                if re.search(
                    r"(见下图|如表|表\s*\d+|图\s*\d+|Figure|Table)", sent_to_check
                )
                else 0.0
            )
            adj_score = score + turn + fig
            if adj_score >= self.cfg.boundary_threshold:
                boundaries.append(i)
        return boundaries
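
    # Boundary semantics sketch (hypothetical scores, threshold 0.5): a returned index
    # i means "prefer a cut between sents[i] and sents[i + 1]". For example,
    #     adj_scores = [0.2, 0.7, 0.4, 0.9]  ->  boundaries = [1, 3]
    # i.e. packing will tend to break after the 2nd and the 4th sentence.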

    # ---------- Packing ----------
    def _pack_by_boundaries(
        self, sents: List[str], embs: np.ndarray, boundaries: List[int]
    ) -> List[Chunk]:
        """Greedy pack around boundaries to meet target length & sentence counts."""
        boundary_set = set(boundaries)
        chunks: List[Chunk] = []
        start = 0
        n = len(sents)
        while start < n:
            end = start
            cur_tokens = 0
            sent_count = 0
            while end < n and sent_count < self.cfg.max_sent_per_chunk:
                cur_tokens = approx_tokens(" ".join(sents[start : end + 1]))
                sent_count += 1
                if cur_tokens >= self.cfg.target_tokens:
                    # try to cut at the nearest boundary at or before 'end'
                    cut = end
                    # search backward to the nearest boundary within the window
                    for b in range(end, start - 1, -1):
                        if b in boundary_set:
                            cut = b
                            break
                    # avoid too-small chunks
                    if cut - start + 1 >= self.cfg.min_sent_per_chunk:
                        end = cut
                    break
                end += 1
            # finalize chunk
            text = " ".join(sents[start : end + 1]).strip()
            tokens = approx_tokens(text)
            chunk = Chunk(id=str(uuid.uuid4()), text=text, tokens=tokens)
            chunks.append(chunk)
            # soft overlap: repeat tail sentences of this chunk at the head of the next
            next_start = end + 1
            if self.cfg.overlap_ratio > 0 and end + 1 < n:
                overlap_tokens = int(tokens * self.cfg.overlap_ratio)
                # approximate the overlap by whole sentences, counted from the tail
                t = 0
                overlap_count = 0
                for s in reversed(sents[start : end + 1]):
                    t += approx_tokens(s)
                    overlap_count += 1
                    if t >= overlap_tokens:
                        break
                # start the next chunk a few sentences back, but always move forward
                next_start = max(start + 1, end + 1 - overlap_count)
            start = next_start
        return chunks

    # ---------- KG classify & refine ----------
    def _classify_and_refine(self, chunks: List[Chunk]) -> List[Chunk]:
        if not self.kg:
            return chunks
        refined: List[Chunk] = []
        for ch in chunks:
            emb = self.model.encode(ch.text, normalize_embeddings=True)
            topics, purity = self.kg.classify(emb, topk=self.cfg.kg_topk)
            ch.topics, ch.topic_purity = topics, purity
            # If purity is low, try a secondary split inside the chunk
            if purity < self.cfg.topic_purity_floor:
                sub = self._refine_chunk_by_topic(ch)
                refined.extend(sub)
            else:
                refined.append(ch)
        return refined

    def _refine_chunk_by_topic(self, chunk: Chunk) -> List[Chunk]:
        """Second-pass split inside a low-purity chunk."""
        sents = simple_sent_tokenize(chunk.text)
        if len(sents) <= self.cfg.min_sent_per_chunk * 2:
            return [chunk]
        embs = self.model.encode(sents, normalize_embeddings=True)
        # force more boundaries by lowering threshold a bit
        orig = self.cfg.boundary_threshold
        try:
            self.cfg.boundary_threshold = max(0.3, orig - 0.1)
            boundaries = self._detect_boundaries(sents, embs)
            sub_chunks = self._pack_by_boundaries(sents, embs, boundaries)
            # inherit topics again
            final = []
            for ch in sub_chunks:
                emb = self.model.encode(ch.text, normalize_embeddings=True)
                topics, purity = self.kg.classify(emb, topk=self.cfg.kg_topk)
                ch.topics, ch.topic_purity = topics, purity
                final.append(ch)
            return final
        finally:
            self.cfg.boundary_threshold = orig

# ---------- Auto-tuning (unsupervised objective) ----------
class UnsupervisedEvaluator:
    """
    Build a score: higher is better.
    - Intra-chunk coherence (avg similarity of neighboring sentences)
    - Inter-chunk separation (low similarity of chunk medoids to neighbors)
    - Length penalty (deviation from target_tokens)
    - Topic purity reward (if KG is enabled)
    """

    def __init__(
        self, model: SentenceTransformer, target_tokens: int, kg_weight: float = 0.5
    ):
        self.model = model
        self.target = target_tokens
        self.kg_weight = kg_weight

    def score(self, chunks: List[Chunk], kg_present: bool = True) -> float:
        if not chunks:
            return -1e6
        # Intra coherence: reward high similarity between neighboring sentences
        intra = []
        for ch in chunks:
            sents = simple_sent_tokenize(ch.text)
            if len(sents) < 2:
                continue
            embs = self.model.encode(sents, normalize_embeddings=True)
            sims = util.cos_sim(embs[:-1], embs[1:]).cpu().numpy().diagonal()
            intra.append(float(np.mean(sims)))
        intra_score = float(np.mean(intra)) if intra else 0.0
        # Inter separation: penalize adjacent chunk similarity
        if len(chunks) > 1:
            reps = self.model.encode(
                [c.text for c in chunks], normalize_embeddings=True
            )
            adj = []
            for i in range(len(chunks) - 1):
                adj.append(float(util.cos_sim(reps[i], reps[i + 1]).item()))
            inter_penalty = float(np.mean(adj))
        else:
            inter_penalty = 0.0
        # Length penalty
        dev = [abs(c.tokens - self.target) / max(1, self.target) for c in chunks]
        len_penalty = float(np.mean(dev))
        # Topic purity
        if kg_present:
            pur = [c.topic_purity for c in chunks]
            purity = float(np.mean(pur))
        else:
            purity = 0.0
        # Final score
        return (
            intra_score
            - 0.6 * inter_penalty
            - 0.4 * len_penalty
            + self.kg_weight * purity
        )
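
# Worked example of the objective above (hypothetical numbers): with intra_score = 0.62,
# inter_penalty = 0.35, len_penalty = 0.20, purity = 0.80 and kg_weight = 0.5, the score
# is 0.62 - 0.6 * 0.35 - 0.4 * 0.20 + 0.5 * 0.80 = 0.73, so coherent, well-separated,
# on-length, topically pure chunks score higher.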

def auto_tune_params(
    raw_texts: List[str],
    kg_spec: Optional[Dict[str, Any]] = None,
    n_trials: int = 20,
    seed: int = 42,
) -> ChunkerConfig:
    """Bayesian-like search with Optuna to find a good config on your corpus."""

    def objective(trial: optuna.Trial):
        cfg = ChunkerConfig(
            target_tokens=trial.suggest_int("target_tokens", 30, 400, step=10),
            max_tokens=trial.suggest_int("max_tokens", 30, 520, step=10),
            overlap_ratio=trial.suggest_float("overlap_ratio", 0.05, 0.25, step=0.05),
            boundary_threshold=trial.suggest_float(
                "boundary_threshold", 0.45, 0.75, step=0.05
            ),
            min_sent_per_chunk=trial.suggest_int("min_sent_per_chunk", 2, 4),
            max_sent_per_chunk=trial.suggest_int("max_sent_per_chunk", 8, 16),
            enable_adaptive_boundary=True,
            enable_kg=(kg_spec is not None),
            topic_purity_floor=trial.suggest_float(
                "topic_purity_floor", 0.55, 0.8, step=0.05
            ),
        )
        chunker = TopicAwareChunker(cfg, kg_spec=kg_spec)
        evaluator = UnsupervisedEvaluator(
            chunker.model, cfg.target_tokens, kg_weight=0.5 if kg_spec else 0.0
        )
        # Evaluate across a small sample
        scores = []
        for t in raw_texts:
            chunks = chunker.chunk(t)
            s = evaluator.score(chunks, kg_present=(kg_spec is not None))
            scores.append(s)
        return float(np.mean(scores))

    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(direction="maximize", sampler=sampler)
    study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
    best_params = study.best_params
    return ChunkerConfig(
        target_tokens=best_params["target_tokens"],
        max_tokens=best_params["max_tokens"],
        overlap_ratio=best_params["overlap_ratio"],
        boundary_threshold=best_params["boundary_threshold"],
        min_sent_per_chunk=best_params["min_sent_per_chunk"],
        max_sent_per_chunk=best_params["max_sent_per_chunk"],
        enable_adaptive_boundary=True,
        enable_kg=(kg_spec is not None),
        topic_purity_floor=best_params["topic_purity_floor"],
    )

# ---------- Example usage ----------
if __name__ == "__main__":
    # Mixed-topic Chinese sample text: a short passage about RAG followed by a long
    # passage about basketball shooting form, which exercises both boundary detection
    # and KG topic classification.
    sample_text = """
    RAG(Retrieval-Augmented Generation)是一种增强生成的技术。
    在复杂知识问答中,RAG 通过检索相关文档片段来改善答案质量。
    然而,分块策略会显著影响检索召回与可引用性。
    因此,我们提出一种主题感知的分块方法,结合 Transformer 边界探测与知识图谱层次分类。
    然后,我们讲一个新的主题,篮球
    这个也就是罚球动作。一般原地动作分为两种。
    第一种原地投篮动作是先下蹲,做好投篮的发力前上举动作,然后竖直向上伸直身体,右臂顺势在身体向上的过程中竖直向上将球向上投出。这种原地投篮的好处是,发力轻松,可以借助身体向上竖直的这个力度的趋势,帮助投篮发力,会让投篮的力气减少很多。尤其是在比赛后半程体力不好的时候,依然可以做到很高的命中略。这种投篮的要领是:主动的竖直向上的意识。我们以前就经常强调竖直起跳和竖直的概念,但是,同样看起来是竖直,但是用出来的效果却很不同,这主要就是技巧的关系了。这个技巧的精髓就在于“主动意识”。在你练习这种投篮的时候,每一次,都要在下蹲以后,明确的在脑子里想着,要竖直向上发力。双腿要竖直向上用力,整个身体也是这样,而且,最为重要的是,你一定要在练习的时候每次都要主动的去想,然后刻意的去竖直向上。这样,长久下去,养成习惯,你的这种投篮才会稳定。这里我们要顺便强调之前的一篇文章,就是录像纠错法,我们这里之所以一再强调要主动意识的竖直上起,就是因为,在录像上,未必能看得出来这个问题。也就是说,你的录像虽然看起来你是竖直起跳的,但是你没有一个主动的也就是刻意的竖直起跳的意识的话,这个球也不是竖直起跳。另外,相反的,如果你在视频上看到自己不是竖直起跳,但是实际上这个球是你使用了竖直起跳的主动意识来发力的。那么,尽管看起来不是很竖直,却依然可以很稳定。也就是说,眼睛会欺骗你,一定要注重你的意识。
    """
    kg_spec = {
        "root": {
            "name": "root",
            "children": [
                {
                    "name": "Computer Science",
                    "children": [
                        {"name": "NLP", "children": [{"name": "RAG", "children": []}]}
                    ],
                },
                {"name": "Finance", "children": [{"name": "AP/AR", "children": []}]},
                {
                    "name": "体育",
                    "children": [
                        {"name": "篮球", "children": [{"name": "投篮", "children": []}]}
                    ],
                },
            ],
        }
    }
    cfg = auto_tune_params([sample_text], kg_spec=kg_spec, n_trials=10, seed=42)
    chunker = TopicAwareChunker(cfg, kg_spec=kg_spec)
    chunks = chunker.chunk(sample_text)
    for i, ch in enumerate(chunks, 1):
        print(f"\n== Chunk {i} ==")
        print("Tokens:", ch.tokens)
        print("Topics:", " / ".join(ch.topics), "Purity:", round(ch.topic_purity, 3))
        print(ch.text)
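
    # Optional: persist the chunks for downstream indexing. A minimal sketch using
    # dataclasses.asdict (imported above); the output filename is illustrative.
    import json

    with open("chunks_preview.json", "w", encoding="utf-8") as f:
        json.dump([asdict(c) for c in chunks], f, ensure_ascii=False, indent=2)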