""" 主题感知分块 """ from __future__ import annotations from typing import List, Dict, Any import numpy as np from applications.api import get_basic_embedding from applications.config import DEFAULT_MODEL, Chunk from applications.utils.nlp import SplitTextIntoSentences, num_tokens, BoundaryDetector class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences): INIT_STATUS = 0 PROCESSING_STATUS = 1 FINISHED_STATUS = 2 FAILED_STATUS = 3 def __init__(self, doc_id: str): super().__init__() self.doc_id = doc_id @staticmethod async def _encode_batch(texts: List[str]) -> np.ndarray: embs = [] for t in texts: e = await get_basic_embedding(t, model=DEFAULT_MODEL) embs.append(np.array(e, dtype=np.float32)) return np.stack(embs) async def _raw_chunk(self, text: str) -> Dict[str, Any]: sentence_list = self.jieba_sent_tokenize(text) if not sentence_list: return {} sentences_embeddings = await self._encode_batch(sentence_list) boundaries = self.detect_boundaries(sentence_list, sentences_embeddings) return { "sentence_list": sentence_list, "boundaries": boundaries, "embeddings": sentences_embeddings, } class TopicAwarePackerV1(TopicAwareChunker): def _pack_v1(self, sentence_list: List[str], boundaries: List[int], text_type: int, dataset_id: int) -> List[Chunk]: boundary_set = set(boundaries) chunks: List[Chunk] = [] start = 0 n = len(sentence_list) chunk_id = 0 while start < n: end = start sent_count = 0 while end < n and sent_count < self.max_sent_per_chunk: cur_tokens = num_tokens(" ".join(sentence_list[start : end + 1])) sent_count += 1 if cur_tokens >= self.target_tokens: cut = end for b in range(end, start - 1, -1): if b in boundary_set: cut = b break if cut - start + 1 >= self.min_sent_per_chunk: end = cut break end += 1 text = " ".join(sentence_list[start : end + 1]).strip() tokens = num_tokens(text) chunk_id += 1 chunk = Chunk( doc_id=self.doc_id, chunk_id=chunk_id, text=text, tokens=tokens, text_type=text_type, dataset_id=dataset_id, ) chunks.append(chunk) start = end + 1 return chunks async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]: raw_info = await self._raw_chunk(text) if not raw_info: return [] return self._pack_v1( sentence_list=raw_info["sentence_list"], boundaries=raw_info["boundaries"], text_type=text_type, dataset_id=dataset_id, ) class TopicAwarePackerV2(TopicAwareChunker): def _pack_v2( self, sentence_list: List[str], boundaries: List[int], embeddings: np.ndarray, text_type: int, dataset_id: int ) -> List[Chunk]: segments = [] seg_embs = [] last_idx = 0 for b in boundaries + [len(sentence_list) - 1]: seg = sentence_list[last_idx:b + 1] seg_emb = np.mean(embeddings[last_idx:b + 1], axis=0) if seg: segments.append(seg) seg_embs.append(seg_emb) last_idx = b + 1 final_segments = [] for seg in segments: tokens = num_tokens("".join(seg)) if tokens > self.max_tokens and len(seg) > 1: mid = len(seg) // 2 final_segments.append(seg[:mid]) final_segments.append(seg[mid:]) else: final_segments.append(seg) chunks = [] for index, seg in enumerate(final_segments, 1): text = "".join(seg) tokens = num_tokens(text) # 如果 token 过短,则暂时不用 status = 2 if tokens < self.min_tokens else 1 chunks.append( Chunk( doc_id=self.doc_id, dataset_id=dataset_id, text=text, chunk_id=index, tokens=num_tokens(text), text_type=text_type, status=status ) ) return chunks async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]: raw_info = await self._raw_chunk(text) if not raw_info: return [] return self._pack_v2( sentence_list=raw_info["sentence_list"], boundaries=raw_info["boundaries"], embeddings=raw_info["embeddings"], text_type=text_type, dataset_id=dataset_id, )