
Chunk strategy optimization

luojunhui 2 weeks ago
parent
commit
09123baab5

+ 5 - 5
applications/async_task/chunk_task.py

@@ -3,16 +3,16 @@ from typing import List
 
 from applications.api import get_basic_embedding
 from applications.utils.async_utils import run_tasks_with_asyncio_task_group
-from applications.utils.chunks import TopicAwareChunker, LLMClassifier
+from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
 from applications.utils.milvus import async_insert_chunk
 from applications.utils.mysql import ContentChunks, Contents
-from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL
+from applications.config import Chunk, DEFAULT_MODEL
 from applications.config import ELASTIC_SEARCH_INDEX
 
 
-class ChunkEmbeddingTask(TopicAwareChunker):
-    def __init__(self, cfg: ChunkerConfig, doc_id, resource):
-        super().__init__(cfg, doc_id)
+class ChunkEmbeddingTask(TopicAwarePackerV2):
+    def __init__(self, doc_id, resource):
+        super().__init__(doc_id)
         self.chunk_manager = None
         self.content_manager = None
         self.mysql_client = resource.mysql_client

+ 2 - 0
applications/config/base_chunk.py

@@ -25,6 +25,8 @@ class Chunk:
 @dataclass
 class ChunkerConfig:
     target_tokens: int = 256
+    max_tokens: int = 512
+    min_tokens: int = 64
     boundary_threshold: float = 0.8
     min_sent_per_chunk: int = 3
     max_sent_per_chunk: int = 10
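
The two new bounds are plain dataclass defaults, and because BoundaryDetector (added below) subclasses ChunkerConfig, packers read them as class attributes instead of receiving a cfg object. A minimal sketch of the resulting token-budget invariant, using only the values shown in this diff:

from applications.config import ChunkerConfig

cfg = ChunkerConfig()
# min_tokens < target_tokens < max_tokens: 64 < 256 < 512
assert cfg.min_tokens < cfg.target_tokens < cfg.max_tokens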

+ 3 - 2
applications/utils/chunks/__init__.py

@@ -1,7 +1,8 @@
-from .topic_aware_chunking import TopicAwareChunker
+from .topic_aware_chunking import TopicAwarePackerV1, TopicAwarePackerV2
 from .llm_classifier import LLMClassifier
 
 __all__ = [
-    "TopicAwareChunker",
     "LLMClassifier",
+    "TopicAwarePackerV1",
+    "TopicAwarePackerV2",
 ]

+ 93 - 140
applications/utils/chunks/topic_aware_chunking.py

@@ -4,87 +4,13 @@
 
 from __future__ import annotations
 
-import re
-from typing import List
+from typing import List, Dict, Any
 
 import numpy as np
-from sklearn.preprocessing import minmax_scale
 
 from applications.api import get_basic_embedding
-from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig
-from applications.utils.nlp import SplitTextIntoSentences, num_tokens
-
-# from .llm_classifier import LLMClassifier
-
-
-# sentence boundary strategy
-class BoundaryDetector:
-    def __init__(self, cfg: ChunkerConfig, debug: bool = False):
-        self.cfg = cfg
-        self.debug = debug
-        # signal boost factors
-        self.signal_boost_turn = 0.20
-        self.signal_boost_fig = 0.20
-        self.min_gap = 1
-
-    @staticmethod
-    def cosine_sim(u: np.ndarray, v: np.ndarray) -> float:
-        """计算余弦相似度"""
-        return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))
-
-    def detect_boundaries(
-        self, sentence_list: List[str], embs: np.ndarray
-    ) -> List[int]:
-        # 1. adjacent-sentence similarity
-        sims = np.array(
-            [self.cosine_sim(embs[i], embs[i + 1]) for i in range(len(embs) - 1)]
-        )
-        cut_scores = 1 - sims
-
-        # 2. normalize cut_scores to [0,1]
-        cut_scores = minmax_scale(cut_scores) if len(cut_scores) > 0 else []
-
-        boundaries = []
-        last_boundary = -999
-        for index, base_score in enumerate(cut_scores):
-            sent_to_check = (
-                sentence_list[index]
-                if index < len(sentence_list)
-                else sentence_list[-1]
-            )
-            snippet = sent_to_check[-20:] if sent_to_check else ""
-
-            turn = (
-                self.signal_boost_turn
-                if re.search(
-                    r"(因此|但是|综上|然而|另一方面|In conclusion|However|Therefore)",
-                    snippet,
-                )
-                else 0.0
-            )
-            fig = (
-                self.signal_boost_fig
-                if re.search(
-                    r"(见下图|如表|表\s*\d+|图\s*\d+|Figure|Table)", sent_to_check
-                )
-                else 0.0
-            )
-
-            adj_score = base_score + turn + fig
-
-            if adj_score >= self.cfg.boundary_threshold and (
-                index - last_boundary >= self.min_gap
-            ):
-                boundaries.append(index)
-                last_boundary = index
-
-            # Debug output
-            if self.debug:
-                print(
-                    f"[{index}] sim={sims[index]:.3f}, cut={base_score:.3f}, adj={adj_score:.3f}, boundary={index in boundaries}"
-                )
-
-        return boundaries
+from applications.config import DEFAULT_MODEL, Chunk
+from applications.utils.nlp import SplitTextIntoSentences, num_tokens, BoundaryDetector
 
 
 class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
@@ -93,9 +19,8 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
     FINISHED_STATUS = 2
     FAILED_STATUS = 3
 
-    def __init__(self, cfg: ChunkerConfig, doc_id: str):
-        super().__init__(cfg)
-        # self.classifier = LLMClassifier()
+    def __init__(self, doc_id: str):
+        super().__init__()
         self.doc_id = doc_id
 
     @staticmethod
@@ -106,13 +31,23 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
             embs.append(np.array(e, dtype=np.float32))
         return np.stack(embs)
 
-    def _pack_by_boundaries(
-        self,
-        sentence_list: List[str],
-        boundaries: List[int],
-        text_type: int,
-        dataset_id: int,
-    ) -> List[Chunk]:
+    async def _raw_chunk(self, text: str) -> Dict[str, Any]:
+        sentence_list = self.jieba_sent_tokenize(text)
+        if not sentence_list:
+            return {}
+
+        sentences_embeddings = await self._encode_batch(sentence_list)
+        boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
+        return {
+            "sentences_list": sentence_list,
+            "boundaries": boundaries,
+            "embeddings": sentences_embeddings,
+        }
+
+
+class TopicAwarePackerV1(TopicAwareChunker):
+
+    def _pack_v1(self, sentence_list: List[str], boundaries: List[int], text_type: int, dataset_id: int) -> List[Chunk]:
         boundary_set = set(boundaries)
         chunks: List[Chunk] = []
         start = 0
@@ -121,16 +56,16 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
         while start < n:
             end = start
             sent_count = 0
-            while end < n and sent_count < self.cfg.max_sent_per_chunk:
+            while end < n and sent_count < self.max_sent_per_chunk:
                 cur_tokens = num_tokens(" ".join(sentence_list[start : end + 1]))
                 sent_count += 1
-                if cur_tokens >= self.cfg.target_tokens:
+                if cur_tokens >= self.target_tokens:
                     cut = end
                     for b in range(end, start - 1, -1):
                         if b in boundary_set:
                             cut = b
                             break
-                    if cut - start + 1 >= self.cfg.min_sent_per_chunk:
+                    if cut - start + 1 >= self.min_sent_per_chunk:
                         end = cut
                     break
                 end += 1
@@ -150,58 +85,76 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
             start = end + 1
         return chunks
 
-    async def _refine_chunk_by_topic(self, chunk: Chunk) -> List[Chunk]:
-        sentence_list = self.jieba_sent_tokenize(chunk.text)
-        if len(sentence_list) <= self.cfg.min_sent_per_chunk * 2:
-            return [chunk]
-
-        embs = await self._encode_batch(sentence_list)
-        orig = self.cfg.boundary_threshold
-        try:
-            self.cfg.boundary_threshold = max(0.3, orig - 0.1)
-            boundaries = self.detect_boundaries(sentence_list, embs)
-            sub_chunks = self._pack_by_boundaries(sentence_list, boundaries)
-
-            final = []
-            for ch in sub_chunks:
-                topics, purity = await self.kg.classify(ch.text, topk=self.cfg.kg_topk)
-                ch.topics, ch.topic_purity = topics, purity
-                final.append(ch)
-            return final
-        finally:
-            self.cfg.boundary_threshold = orig
+    async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
+        raw_info = await self._raw_chunk(text)
+        if not raw_info:
+            return []
+
+        return self._pack_v1(
+            sentence_list=raw_info["sentence_list"],
+            boundaries=raw_info["boundaries"],
+            text_type=text_type,
+            dataset_id=dataset_id,
+        )
+
+
+class TopicAwarePackerV2(TopicAwareChunker):
+
+    def _pack_v2(
+        self, sentence_list: List[str], boundaries: List[int], embeddings: np.ndarray, text_type: int, dataset_id: int
+    ) -> List[Chunk]:
+        segments = []
+        seg_embs = []
+        last_idx = 0
+        for b in boundaries + [len(sentence_list) - 1]:
+            seg = sentence_list[last_idx:b + 1]
+            seg_emb = np.mean(embeddings[last_idx:b + 1], axis=0)
+            if seg:
+                segments.append(seg)
+                seg_embs.append(seg_emb)
+            last_idx = b + 1
+
+
+        final_segments = []
+        for seg in segments:
+            tokens = num_tokens("".join(seg))
+            if tokens > self.max_tokens and len(seg) > 1:
+                mid = len(seg) // 2
+                final_segments.append(seg[:mid])
+                final_segments.append(seg[mid:])
+            else:
+                final_segments.append(seg)
+
+        chunks = []
+        for index, seg in enumerate(final_segments, 1):
+            text = "".join(seg)
+            tokens = num_tokens(text)
+            # chunks below the minimum token budget are set aside for now
+            status = 2 if tokens < self.min_tokens else 1
+            chunks.append(
+                Chunk(
+                    doc_id=self.doc_id,
+                    dataset_id=dataset_id,
+                    text=text,
+                    chunk_id=index,
+                    tokens=num_tokens(text),
+                    text_type=text_type,
+                    status=status
+                )
+            )
+
+        return chunks
 
     async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
-        sentence_list = self.jieba_sent_tokenize(text)
-        if not sentence_list:
+        raw_info = await self._raw_chunk(text)
+        if not raw_info:
             return []
 
-        sentences_embeddings = await self._encode_batch(sentence_list)
-        boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
-        raw_chunks = self._pack_by_boundaries(
-            sentence_list, boundaries, text_type, dataset_id
+        return self._pack_v2(
+            sentence_list=raw_info["sentence_list"],
+            boundaries=raw_info["boundaries"],
+            embeddings=raw_info["embeddings"],
+            text_type=text_type,
+            dataset_id=dataset_id,
         )
-        return raw_chunks
-
-
-# async def main():
-#     cfg = ChunkerConfig()
-#     sample_text = """
-#         RAG(Retrieval-Augmented Generation)是一种增强生成的技术。
-#         在复杂知识问答中,RAG 通过检索相关文档片段来改善答案质量。
-#         然而,分块策略会显著影响检索召回与可引用性。
-#         因此,我们提出一种主题感知的分块方法,结合 Transformer 边界探测与知识图谱层次分类。
-#         然后,我们讲一个新的主题,篮球
-#         这个也就是罚球动作。一般原地动作分为两种。
-#         第一种原地投篮动作是先下蹲,做好投篮的发力前上举动作,然后竖直向上伸直身体,右臂顺势在身体向上的过程中竖直向上将球向上投出。这种原地投篮的好处是,发力轻松,可以借助身体向上竖直的这个力度的趋势,帮助投篮发力,会让投篮的力气减少很多。尤其是在比赛后半程体力不好的时候,依然可以做到很高的命中略。这种投篮的要领是:主动的竖直向上的意识。我们以前就经常强调竖直起跳和竖直的概念,但是,同样看起来是竖直,但是用出来的效果却很不同,这主要就是技巧的关系了。这个技巧的精髓就在于“主动意识”。在你练习这种投篮的时候,每一次,都要在下蹲以后,明确的在脑子里想着,要竖直向上发力。双腿要竖直向上用力,整个身体也是这样,而且,最为重要的是,你一定要在练习的时候每次都要主动的去想,然后刻意的去竖直向上。这样,长久下去,养成习惯,你的这种投篮才会稳定。这里我们要顺便强调之前的一篇文章,就是录像纠错法,我们这里之所以一再强调要主动意识的竖直上起,就是因为,在录像上,未必能看得出来这个问题。也就是说,你的录像虽然看起来你是竖直起跳的,但是你没有一个主动的也就是刻意的竖直起跳的意识的话,这个球也不是竖直起跳。另外,相反的,如果你在视频上看到自己不是竖直起跳,但是实际上这个球是你使用了竖直起跳的主动意识来发力的。那么,尽管看起来不是很竖直,却依然可以很稳定。也就是说,眼睛会欺骗你,一定要注重你的意识。
-#     """
-#     chunker = TopicAwareChunker(cfg)
-#     chunks = await chunker.chunk(sample_text)
-#
-#     for c in chunks:
-#         print(f"[{c.tokens} tokens] {c.topic} purity={c.topic_purity:.2f}")
-#         print(c.text)
-#
-#
-# if __name__ == "__main__":
-#     asyncio.run(main())
+
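
A minimal usage sketch of the new V2 packer, assuming the embedding backend behind get_basic_embedding is reachable; the doc_id, sample text, text_type, and dataset_id values are illustrative:

import asyncio

from applications.utils.chunks import TopicAwarePackerV2


async def demo() -> None:
    packer = TopicAwarePackerV2(doc_id="doc-example")
    text = "RAG retrieves supporting passages before generation. However, the chunking strategy strongly affects recall."
    chunks = await packer.chunk(text, text_type=1, dataset_id=1)
    for c in chunks:
        # status 1: usable chunk; status 2: below the minimum token budget, set aside
        print(c.chunk_id, c.tokens, c.status)


asyncio.run(demo())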

+ 8 - 1
applications/utils/nlp/__init__.py

@@ -1,5 +1,12 @@
+from .boundary_detector import BoundaryDetector
 from .cal_tokens import num_tokens
 from .language_detect import detect_language
 from .split_text_into_sentences import SplitTextIntoSentences
 
-__all__ = ["SplitTextIntoSentences", "detect_language", "num_tokens"]
+
+__all__ = [
+    "SplitTextIntoSentences",
+    "detect_language",
+    "num_tokens",
+    "BoundaryDetector",
+]

+ 70 - 0
applications/utils/nlp/boundary_detector.py

@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import re
+from typing import List
+
+import numpy as np
+from sklearn.preprocessing import minmax_scale
+from applications.config import ChunkerConfig
+
+
+class BoundaryDetector(ChunkerConfig):
+    def __init__(self):
+        self.signal_boost_turn = 0.20
+        self.signal_boost_fig = 0.20
+        self.min_gap = 1
+
+    @staticmethod
+    def cosine_sim(u: np.ndarray, v: np.ndarray) -> float:
+        """计算余弦相似度"""
+        return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))
+
+    def turn_signal(self, text: str) -> float:
+        pattern = r"(因此|但是|综上所述?|然而|另一方面|总之|结论是|In conclusion\b|To conclude\b|However\b|Therefore\b|Thus\b|On the other hand\b)"
+        if re.search(pattern, text, flags=re.IGNORECASE):
+            return self.signal_boost_turn
+        return 0.0
+
+    def figure_signal(self, text: str) -> float:
+        pattern = r"(见下图|如下图所示|如表所示|如下表所示|表\s*\d+[::]?|图\s*\d+[::]?|Figure\s*\d+|Table\s*\d+)"
+        if re.search(pattern, text, flags=re.IGNORECASE):
+            return self.signal_boost_fig
+        return 0.0
+
+    def detect_boundaries(
+        self, sentence_list: List[str], embs: np.ndarray, debug: bool = False
+    ) -> List[int]:
+        sims = np.array(
+            [self.cosine_sim(embs[i], embs[i + 1]) for i in range(len(embs) - 1)]
+        )
+        cut_scores = 1 - sims
+        cut_scores = minmax_scale(cut_scores) if len(cut_scores) > 0 else []
+
+        boundaries = []
+        last_boundary = -999
+        for index, base_score in enumerate(cut_scores):
+            sent_to_check = (
+                sentence_list[index]
+                if index < len(sentence_list)
+                else sentence_list[-1]
+            )
+            snippet = sent_to_check[-20:] if sent_to_check else ""
+            adj_score = (
+                base_score
+                + self.turn_signal(snippet)
+                + self.figure_signal(sent_to_check)
+            )
+
+            if adj_score >= self.boundary_threshold and (
+                index - last_boundary >= self.min_gap
+            ):
+                boundaries.append(index)
+                last_boundary = index
+
+            # Debug 输出
+            if debug:
+                print(
+                    f"[{index}] sim={sims[index]:.3f}, cut={base_score:.3f}, adj={adj_score:.3f}, boundary={index in boundaries}"
+                )
+
+        return boundaries
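
A toy sketch of the extracted detector with made-up 2-D embeddings: the sharp similarity drop between the second and third sentences is the only gap that survives min-max scaling and the default 0.8 threshold:

import numpy as np

from applications.utils.nlp import BoundaryDetector

detector = BoundaryDetector()
sentences = ["Sentence one.", "Sentence two.", "A new topic starts here.", "And it continues."]
embeddings = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]], dtype=np.float32)
# Expected output: [1], i.e. a boundary right after the second sentence
print(detector.detect_boundaries(sentences, embeddings, debug=True))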

+ 1 - 7
routes/buleprint.py

@@ -65,13 +65,7 @@ async def chunk():
         return jsonify({"error": "error  text"})
     resource = get_resource_manager()
     doc_id = f"doc-{uuid.uuid4()}"
-    chunk_task = ChunkEmbeddingTask(
-        resource.mysql_client,
-        resource.milvus_client,
-        cfg=ChunkerConfig(),
-        doc_id=doc_id,
-        es_pool=resource.es_client,
-    )
+    chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
     doc_id = await chunk_task.deal(body)
     return jsonify({"doc_id": doc_id})