|
@@ -1,452 +0,0 @@
|
|
|
-"""
|
|
|
-主题感知分块
|
|
|
-"""
|
|
|
-
|
|
|
-from __future__ import annotations
|
|
|
-
|
|
|
-import re, uuid, math
|
|
|
-from dataclasses import dataclass, field, asdict
|
|
|
-from typing import List, Dict, Any, Tuple, Optional
|
|
|
-
|
|
|
-import optuna
|
|
|
-import numpy as np
|
|
|
-
|
|
|
-from sentence_transformers import SentenceTransformer, util
|
|
|
-
|
|
|
-from applications.utils import SplitTextIntoSentences, detect_language, num_tokens
|
|
|
-
|
|
|
-
|
|
|
-# ---------- Utilities ----------
|
|
|
-def simple_sent_tokenize(text: str) -> List[str]:
|
|
|
- text = re.sub(r"\n{2,}", "\n", text)
|
|
|
- parts = re.split(r"([。!?!?;;]+)\s*|\n+", text)
|
|
|
- sents, buf = [], ""
|
|
|
- for p in parts:
|
|
|
- if p is None:
|
|
|
- continue
|
|
|
- if re.match(r"[。!?!?;;]+", p or ""):
|
|
|
- buf += p or ""
|
|
|
- if buf.strip():
|
|
|
- sents.append(buf.strip())
|
|
|
- buf = ""
|
|
|
- elif p.strip() == "":
|
|
|
- if buf.strip():
|
|
|
- sents.append(buf.strip())
|
|
|
- buf = ""
|
|
|
- else:
|
|
|
- buf += p or ""
|
|
|
- if buf.strip():
|
|
|
- sents.append(buf.strip())
|
|
|
-
|
|
|
- merged = []
|
|
|
- for s in sents:
|
|
|
- if merged and (len(s) < 10 or len(merged[-1]) < 10):
|
|
|
- merged[-1] += " " + s
|
|
|
- else:
|
|
|
- merged.append(s)
|
|
|
- return [s for s in merged if s.strip()]
|
|
|
-
|
|
|
-
|
|
|
-def approx_tokens(text: str) -> int:
|
|
|
- """Cheap token estimator (≈4 chars/token for zh, ≈0.75 words/token for en)."""
|
|
|
- # This is a heuristic; replace with tiktoken if desired.
|
|
|
- cjk = re.findall(r"[\u4e00-\u9fff]", text)
|
|
|
- others = re.sub(r"[\u4e00-\u9fff]", " ", text).split()
|
|
|
- return max(1, int(len(cjk) / 2.5 + len(others) / 0.75))
|
|
|
-
|
|
|
-
|
|
|
-# ---------- Knowledge Graph Stub ----------
|
|
|
-class KGClassifier:
|
|
|
- """
|
|
|
- Hierarchical topic classifier using embedding prototypes per node.
|
|
|
- Replace `nodes` with your KG; each node keeps a centroid embedding.
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, model: SentenceTransformer, kg_spec: Dict[str, Any]):
|
|
|
- """
|
|
|
- kg_spec example:
|
|
|
- {
|
|
|
- "root": {
|
|
|
- "name": "root",
|
|
|
- "children": [
|
|
|
- {"name": "Computer Science", "children":[
|
|
|
- {"name":"NLP", "children":[{"name":"RAG", "children":[]}]}]},
|
|
|
- {"name": "Finance", "children":[{"name":"AP/AR", "children":[]}]}]}
|
|
|
- }
|
|
|
- """
|
|
|
- self.model = model
|
|
|
- self.root = kg_spec["root"]
|
|
|
- self._embed_cache = {} # name -> vector
|
|
|
-
|
|
|
- def build_centroid(node):
|
|
|
- name = node["name"]
|
|
|
- if name not in self._embed_cache:
|
|
|
- self._embed_cache[name] = self.model.encode(
|
|
|
- name, normalize_embeddings=True
|
|
|
- )
|
|
|
- for ch in node.get("children", []):
|
|
|
- build_centroid(ch)
|
|
|
-
|
|
|
- build_centroid(self.root)
|
|
|
-
|
|
|
- def classify(self, text_emb: np.ndarray, topk: int = 3) -> Tuple[List[str], float]:
|
|
|
- """
|
|
|
- Return (topic_path, purity). Purity is soft max margin across levels.
|
|
|
- """
|
|
|
- path, purities = [], []
|
|
|
- node = self.root
|
|
|
- while True:
|
|
|
- # score current node children
|
|
|
- children = node.get("children", [])
|
|
|
- if not children:
|
|
|
- break
|
|
|
- scores = []
|
|
|
- for ch in children:
|
|
|
- vec = self._embed_cache[ch["name"]]
|
|
|
- scores.append((ch, float(util.cos_sim(text_emb, vec).item())))
|
|
|
- scores.sort(key=lambda x: x[1], reverse=True)
|
|
|
- best, second = scores[0], (scores[1] if len(scores) > 1 else (None, -1.0))
|
|
|
- path.append(best[0]["name"])
|
|
|
- margin = max(0.0, (best[1] - max(second[1], -1.0)))
|
|
|
- purities.append(1 / (1 + math.exp(-5 * margin))) # squash to (0,1)
|
|
|
- node = best[0]
|
|
|
- purity = float(np.mean(purities)) if purities else 1.0
|
|
|
- return path, purity
|
|
|
-
|
|
|
-
|
|
|
-# ---------- Core Chunker ----------
|
|
|
-@dataclass
|
|
|
-class Chunk:
|
|
|
- id: str
|
|
|
- text: str
|
|
|
- tokens: int
|
|
|
- topics: List[str] = field(default_factory=list)
|
|
|
- topic_purity: float = 1.0
|
|
|
- meta: Dict[str, Any] = field(default_factory=dict)
|
|
|
-
|
|
|
-
|
|
|
-@dataclass
|
|
|
-class ChunkerConfig:
|
|
|
- model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
|
|
- target_tokens: int = 80
|
|
|
- max_tokens: int = 80
|
|
|
- overlap_ratio: float = 0.12
|
|
|
- boundary_threshold: float = 0.50 # similarity drop boundary (lower -> more cuts)
|
|
|
- min_sent_per_chunk: int = 1
|
|
|
- max_sent_per_chunk: int = 8
|
|
|
- enable_adaptive_boundary: bool = True
|
|
|
- enable_kg: bool = True
|
|
|
- topic_purity_floor: float = 0.65
|
|
|
- kg_topk: int = 3
|
|
|
-
|
|
|
-
|
|
|
-class TopicAwareChunker:
|
|
|
- def __init__(self, cfg: ChunkerConfig, kg_spec: Optional[Dict[str, Any]] = None):
|
|
|
- self.cfg = cfg
|
|
|
- self.model = SentenceTransformer(
|
|
|
- cfg.model_name, device="cpu"
|
|
|
- ) # set gpu if available
|
|
|
- self.model.max_seq_length = 512
|
|
|
- self.kg = (
|
|
|
- KGClassifier(self.model, kg_spec) if (cfg.enable_kg and kg_spec) else None
|
|
|
- )
|
|
|
-
|
|
|
- # ---------- Public API ----------
|
|
|
- def chunk(self, text: str) -> List[Chunk]:
|
|
|
- sents = simple_sent_tokenize(text)
|
|
|
- if not sents:
|
|
|
- return []
|
|
|
- sent_embs = self.model.encode(sents, normalize_embeddings=True)
|
|
|
- boundaries = self._detect_boundaries(sents, sent_embs)
|
|
|
- raw_chunks = self._pack_by_boundaries(sents, sent_embs, boundaries)
|
|
|
- final_chunks = self._classify_and_refine(raw_chunks)
|
|
|
- return final_chunks
|
|
|
-
|
|
|
- # ---------- Boundary detection ----------
|
|
|
- def _detect_boundaries(self, sents: List[str], embs: np.ndarray) -> List[int]:
|
|
|
- sims = util.cos_sim(embs[:-1], embs[1:]).cpu().numpy().reshape(-1)
|
|
|
- cut_scores = 1 - sims # higher means more likely boundary
|
|
|
-
|
|
|
- # use np.ptp instead of ndarray.ptp (NumPy 2.0 compatibility)
|
|
|
- rng = np.ptp(cut_scores) if np.ptp(cut_scores) > 0 else 1e-6
|
|
|
- cut_scores = (cut_scores - cut_scores.min()) / (rng + 1e-6)
|
|
|
-
|
|
|
- boundaries = []
|
|
|
- for i, score in enumerate(cut_scores):
|
|
|
- # 对应的是句对 (i, i+1),这里可以检查 sents[i] 或 sents[i+1]
|
|
|
- sent_to_check = sents[i] if i < len(sents) else sents[-1]
|
|
|
- # 防御性写法,避免越界
|
|
|
- snippet = sent_to_check[-20:] if sent_to_check else ""
|
|
|
-
|
|
|
- turn = (
|
|
|
- 0.1
|
|
|
- if re.search(
|
|
|
- r"(因此|但是|综上|然而|另一方面|In conclusion|However|Therefore)",
|
|
|
- snippet,
|
|
|
- )
|
|
|
- else 0.0
|
|
|
- )
|
|
|
- fig = (
|
|
|
- 0.1
|
|
|
- if re.search(
|
|
|
- r"(见下图|如表|表\s*\d+|图\s*\d+|Figure|Table)", sent_to_check
|
|
|
- )
|
|
|
- else 0.0
|
|
|
- )
|
|
|
-
|
|
|
- adj_score = score + turn + fig
|
|
|
- if adj_score >= self.cfg.boundary_threshold:
|
|
|
- boundaries.append(i)
|
|
|
-
|
|
|
- return boundaries
|
|
|
-
|
|
|
- # ---------- Packing ----------
|
|
|
- def _pack_by_boundaries(
|
|
|
- self, sents: List[str], embs: np.ndarray, boundaries: List[int]
|
|
|
- ) -> List[Chunk]:
|
|
|
- """Greedy pack around boundaries to meet target length & sentence counts."""
|
|
|
- boundary_set = set(boundaries)
|
|
|
- chunks: List[Chunk] = []
|
|
|
- start = 0
|
|
|
- n = len(sents)
|
|
|
- while start < n:
|
|
|
- end = start
|
|
|
- cur_tokens = 0
|
|
|
- sent_count = 0
|
|
|
- last_boundary = start - 1
|
|
|
- while end < n and sent_count < self.cfg.max_sent_per_chunk:
|
|
|
- cur_tokens = approx_tokens(" ".join(sents[start : end + 1]))
|
|
|
- sent_count += 1
|
|
|
- if cur_tokens >= self.cfg.target_tokens:
|
|
|
- # try to cut at nearest boundary to 'end'
|
|
|
- cut = end
|
|
|
- # search backward to nearest boundary within window
|
|
|
- for b in range(end, start - 1, -1):
|
|
|
- if b in boundary_set:
|
|
|
- cut = b
|
|
|
- break
|
|
|
- # avoid too small chunks
|
|
|
- if cut - start + 1 >= self.cfg.min_sent_per_chunk:
|
|
|
- end = cut
|
|
|
- break
|
|
|
- end += 1
|
|
|
-
|
|
|
- # finalize chunk
|
|
|
- text = " ".join(sents[start : end + 1]).strip()
|
|
|
- tokens = approx_tokens(text)
|
|
|
- chunk = Chunk(id=str(uuid.uuid4()), text=text, tokens=tokens)
|
|
|
- chunks.append(chunk)
|
|
|
-
|
|
|
- # soft overlap (append tail sentences of current as head of next)
|
|
|
- if self.cfg.overlap_ratio > 0 and end + 1 < n:
|
|
|
- overlap_tokens = int(tokens * self.cfg.overlap_ratio)
|
|
|
- # approximate by sentences
|
|
|
- overlap_sents = []
|
|
|
- t = 0
|
|
|
- for s in reversed(sents[start : end + 1]):
|
|
|
- t += approx_tokens(s)
|
|
|
- overlap_sents.append(s)
|
|
|
- if t >= overlap_tokens:
|
|
|
- break
|
|
|
- # prepend to next start by reducing start index backward (not altering original sents)
|
|
|
- start = end + 1
|
|
|
- return chunks
|
|
|
-
|
|
|
- # ---------- KG classify & refine ----------
|
|
|
- def _classify_and_refine(self, chunks: List[Chunk]) -> List[Chunk]:
|
|
|
- if not self.kg:
|
|
|
- return chunks
|
|
|
- refined: List[Chunk] = []
|
|
|
- for ch in chunks:
|
|
|
- emb = self.model.encode(ch.text, normalize_embeddings=True)
|
|
|
- topics, purity = self.kg.classify(emb, topk=self.cfg.kg_topk)
|
|
|
- ch.topics, ch.topic_purity = topics, purity
|
|
|
- # If purity is low, try a secondary split inside the chunk
|
|
|
- if purity < self.cfg.topic_purity_floor:
|
|
|
- sub = self._refine_chunk_by_topic(ch)
|
|
|
- refined.extend(sub)
|
|
|
- else:
|
|
|
- refined.append(ch)
|
|
|
- return refined
|
|
|
-
|
|
|
- def _refine_chunk_by_topic(self, chunk: Chunk) -> List[Chunk]:
|
|
|
- """Second-pass split inside a low-purity chunk."""
|
|
|
- sents = simple_sent_tokenize(chunk.text)
|
|
|
- if len(sents) <= self.cfg.min_sent_per_chunk * 2:
|
|
|
- return [chunk]
|
|
|
- embs = self.model.encode(sents, normalize_embeddings=True)
|
|
|
- # force more boundaries by lowering threshold a bit
|
|
|
- orig = self.cfg.boundary_threshold
|
|
|
- try:
|
|
|
- self.cfg.boundary_threshold = max(0.3, orig - 0.1)
|
|
|
- boundaries = self._detect_boundaries(sents, embs)
|
|
|
- sub_chunks = self._pack_by_boundaries(sents, embs, boundaries)
|
|
|
- # inherit topics again
|
|
|
- final = []
|
|
|
- for ch in sub_chunks:
|
|
|
- emb = self.model.encode(ch.text, normalize_embeddings=True)
|
|
|
- topics, purity = self.kg.classify(emb, topk=self.cfg.kg_topk)
|
|
|
- ch.topics, ch.topic_purity = topics, purity
|
|
|
- final.append(ch)
|
|
|
- return final
|
|
|
- finally:
|
|
|
- self.cfg.boundary_threshold = orig
|
|
|
-
|
|
|
-
|
|
|
-# ---------- Auto-tuning (unsupervised objective) ----------
|
|
|
-class UnsupervisedEvaluator:
|
|
|
- """
|
|
|
- Build a score: higher is better.
|
|
|
- - Intra-chunk coherence (avg similarity of neighboring sentences)
|
|
|
- - Inter-chunk separation (low similarity of chunk medoids to neighbors)
|
|
|
- - Length penalty (deviation from target_tokens)
|
|
|
- - Topic purity reward (if KG is enabled)
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(
|
|
|
- self, model: SentenceTransformer, target_tokens: int, kg_weight: float = 0.5
|
|
|
- ):
|
|
|
- self.model = model
|
|
|
- self.target = target_tokens
|
|
|
- self.kg_weight = kg_weight
|
|
|
-
|
|
|
- def score(self, chunks: List[Chunk], kg_present: bool = True) -> float:
|
|
|
- if not chunks:
|
|
|
- return -1e6
|
|
|
- # Intra coherence: reward high
|
|
|
- intra = []
|
|
|
- for ch in chunks:
|
|
|
- sents = simple_sent_tokenize(ch.text)
|
|
|
- if len(sents) < 2:
|
|
|
- continue
|
|
|
- embs = self.model.encode(sents, normalize_embeddings=True)
|
|
|
- sims = util.cos_sim(embs[:-1], embs[1:]).cpu().numpy().reshape(-1)
|
|
|
- intra.append(float(np.mean(sims)))
|
|
|
- intra_score = float(np.mean(intra)) if intra else 0.0
|
|
|
-
|
|
|
- # Inter separation: penalize adjacent chunk similarity
|
|
|
- if len(chunks) > 1:
|
|
|
- reps = self.model.encode(
|
|
|
- [c.text for c in chunks], normalize_embeddings=True
|
|
|
- )
|
|
|
- adj = []
|
|
|
- for i in range(len(chunks) - 1):
|
|
|
- adj.append(float(util.cos_sim(reps[i], reps[i + 1]).item()))
|
|
|
- inter_penalty = float(np.mean(adj))
|
|
|
- else:
|
|
|
- inter_penalty = 0.0
|
|
|
-
|
|
|
- # Length penalty
|
|
|
- dev = [abs(c.tokens - self.target) / max(1, self.target) for c in chunks]
|
|
|
- len_penalty = float(np.mean(dev))
|
|
|
-
|
|
|
- # Topic purity
|
|
|
- if kg_present:
|
|
|
- pur = [c.topic_purity for c in chunks]
|
|
|
- purity = float(np.mean(pur))
|
|
|
- else:
|
|
|
- purity = 0.0
|
|
|
-
|
|
|
- # Final score
|
|
|
- return (
|
|
|
- intra_score
|
|
|
- - 0.6 * inter_penalty
|
|
|
- - 0.4 * len_penalty
|
|
|
- + self.kg_weight * purity
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-def auto_tune_params(
|
|
|
- raw_texts: List[str],
|
|
|
- kg_spec: Optional[Dict[str, Any]] = None,
|
|
|
- n_trials: int = 20,
|
|
|
- seed: int = 42,
|
|
|
-) -> ChunkerConfig:
|
|
|
- """Bayesian-like search with Optuna to find a good config on your corpus."""
|
|
|
-
|
|
|
- def objective(trial: optuna.Trial):
|
|
|
- cfg = ChunkerConfig(
|
|
|
- target_tokens=trial.suggest_int("target_tokens", 30, 400, step=10),
|
|
|
- max_tokens=trial.suggest_int("max_tokens", 30, 520, step=10),
|
|
|
- overlap_ratio=trial.suggest_float("overlap_ratio", 0.05, 0.25, step=0.05),
|
|
|
- boundary_threshold=trial.suggest_float(
|
|
|
- "boundary_threshold", 0.45, 0.75, step=0.05
|
|
|
- ),
|
|
|
- min_sent_per_chunk=trial.suggest_int("min_sent_per_chunk", 2, 4),
|
|
|
- max_sent_per_chunk=trial.suggest_int("max_sent_per_chunk", 8, 16),
|
|
|
- enable_adaptive_boundary=True,
|
|
|
- enable_kg=(kg_spec is not None),
|
|
|
- topic_purity_floor=trial.suggest_float(
|
|
|
- "topic_purity_floor", 0.55, 0.8, step=0.05
|
|
|
- ),
|
|
|
- )
|
|
|
- chunker = TopicAwareChunker(cfg, kg_spec=kg_spec)
|
|
|
- evaluator = UnsupervisedEvaluator(
|
|
|
- chunker.model, cfg.target_tokens, kg_weight=0.5 if kg_spec else 0.0
|
|
|
- )
|
|
|
-
|
|
|
- # Evaluate across a small sample
|
|
|
- scores = []
|
|
|
- for t in raw_texts:
|
|
|
- chunks = chunker.chunk(t)
|
|
|
- s = evaluator.score(chunks, kg_present=(kg_spec is not None))
|
|
|
- scores.append(s)
|
|
|
- return float(np.mean(scores))
|
|
|
-
|
|
|
- sampler = optuna.samplers.TPESampler(seed=seed)
|
|
|
- study = optuna.create_study(direction="maximize", sampler=sampler)
|
|
|
- study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
|
|
|
- best_params = study.best_params
|
|
|
-
|
|
|
- return ChunkerConfig(
|
|
|
- target_tokens=best_params["target_tokens"],
|
|
|
- max_tokens=best_params["max_tokens"],
|
|
|
- overlap_ratio=best_params["overlap_ratio"],
|
|
|
- boundary_threshold=best_params["boundary_threshold"],
|
|
|
- min_sent_per_chunk=best_params["min_sent_per_chunk"],
|
|
|
- max_sent_per_chunk=best_params["max_sent_per_chunk"],
|
|
|
- enable_adaptive_boundary=True,
|
|
|
- enable_kg=(kg_spec is not None),
|
|
|
- topic_purity_floor=best_params["topic_purity_floor"],
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-# ---------- Example usage ----------
|
|
|
-if __name__ == "__main__":
|
|
|
- sample_text = """
|
|
|
- RAG(Retrieval-Augmented Generation)是一种增强生成的技术。
|
|
|
- 在复杂知识问答中,RAG 通过检索相关文档片段来改善答案质量。
|
|
|
- 然而,分块策略会显著影响检索召回与可引用性。
|
|
|
- 因此,我们提出一种主题感知的分块方法,结合 Transformer 边界探测与知识图谱层次分类。
|
|
|
- 然后,我们讲一个新的主题,篮球
|
|
|
- 这个也就是罚球动作。一般原地动作分为两种。
|
|
|
- 第一种原地投篮动作是先下蹲,做好投篮的发力前上举动作,然后竖直向上伸直身体,右臂顺势在身体向上的过程中竖直向上将球向上投出。这种原地投篮的好处是,发力轻松,可以借助身体向上竖直的这个力度的趋势,帮助投篮发力,会让投篮的力气减少很多。尤其是在比赛后半程体力不好的时候,依然可以做到很高的命中略。这种投篮的要领是:主动的竖直向上的意识。我们以前就经常强调竖直起跳和竖直的概念,但是,同样看起来是竖直,但是用出来的效果却很不同,这主要就是技巧的关系了。这个技巧的精髓就在于“主动意识”。在你练习这种投篮的时候,每一次,都要在下蹲以后,明确的在脑子里想着,要竖直向上发力。双腿要竖直向上用力,整个身体也是这样,而且,最为重要的是,你一定要在练习的时候每次都要主动的去想,然后刻意的去竖直向上。这样,长久下去,养成习惯,你的这种投篮才会稳定。这里我们要顺便强调之前的一篇文章,就是录像纠错法,我们这里之所以一再强调要主动意识的竖直上起,就是因为,在录像上,未必能看得出来这个问题。也就是说,你的录像虽然看起来你是竖直起跳的,但是你没有一个主动的也就是刻意的竖直起跳的意识的话,这个球也不是竖直起跳。另外,相反的,如果你在视频上看到自己不是竖直起跳,但是实际上这个球是你使用了竖直起跳的主动意识来发力的。那么,尽管看起来不是很竖直,却依然可以很稳定。也就是说,眼睛会欺骗你,一定要注重你的意识。
|
|
|
- """
|
|
|
- kg_spec = {
|
|
|
- "root": {
|
|
|
- "name": "root",
|
|
|
- "children": [
|
|
|
- {
|
|
|
- "name": "Computer Science",
|
|
|
- "children": [
|
|
|
- {"name": "NLP", "children": [{"name": "RAG", "children": []}]}
|
|
|
- ],
|
|
|
- },
|
|
|
- {"name": "Finance", "children": [{"name": "AP/AR", "children": []}]},
|
|
|
- {
|
|
|
- "name": "体育",
|
|
|
- "children": [
|
|
|
- {"name": "篮球", "children": [{"name": "投篮", "children": []}]}
|
|
|
- ],
|
|
|
- },
|
|
|
- ],
|
|
|
- }
|
|
|
- }
|
|
|
- cfg = auto_tune_params([sample_text], kg_spec=kg_spec, n_trials=10, seed=42)
|
|
|
- chunker = TopicAwareChunker(cfg, kg_spec=kg_spec)
|
|
|
- chunks = chunker.chunk(sample_text)
|
|
|
- for i, ch in enumerate(chunks, 1):
|
|
|
- print(f"\n== Chunk {i} ==")
|
|
|
- print("Tokens:", ch.tokens)
|
|
|
- print("Topics:", " / ".join(ch.topics), "Purity:", round(ch.topic_purity, 3))
|
|
|
- print(ch.text)
|