- """
- 主题感知分块
- """
- from __future__ import annotations
- from typing import List, Dict, Any
- import numpy as np
- from applications.api import get_basic_embedding
- from applications.config import DEFAULT_MODEL, Chunk
- from applications.utils.nlp import SplitTextIntoSentences, num_tokens, BoundaryDetector


class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
    """Base class for topic-aware chunking.

    Sentence splitting and topic-boundary detection come from the
    SplitTextIntoSentences and BoundaryDetector mixins.
    """

    # Processing status constants.
    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    FINISHED_STATUS = 2
    FAILED_STATUS = 3

    def __init__(self, doc_id: str):
        super().__init__()
        self.doc_id = doc_id

    @staticmethod
    async def _encode_batch(texts: List[str]) -> np.ndarray:
        """Embed each text and stack the vectors into an (n, dim) float32 matrix."""
        embs = []
        for t in texts:
            e = await get_basic_embedding(t, model=DEFAULT_MODEL)
            embs.append(np.array(e, dtype=np.float32))
        return np.stack(embs)

    async def _raw_chunk(self, text: str) -> Dict[str, Any]:
        """Split text into sentences, embed them, and detect topic boundaries."""
        # sentence_list = self.jieba_sent_tokenize(text)
        sentence_list = self.lang_chain_tokenize(text)
        if not sentence_list:
            return {}
        sentences_embeddings = await self._encode_batch(sentence_list)
        boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
        return {
            "sentence_list": sentence_list,
            "boundaries": boundaries,
            "embeddings": sentences_embeddings,
        }


class TopicAwarePackerV1(TopicAwareChunker):
    """Packs sentences into chunks, preferring to cut at topic boundaries
    once a chunk approaches the target token budget."""

    def _pack_v1(
        self,
        sentence_list: List[str],
        boundaries: List[int],
        text_type: int,
        dataset_id: int,
    ) -> List[Chunk]:
        boundary_set = set(boundaries)
        chunks: List[Chunk] = []
        start = 0
        n = len(sentence_list)
        chunk_id = 0
        while start < n:
            end = start
            sent_count = 0
            # Grow the window sentence by sentence until the token target or
            # the per-chunk sentence cap is reached.
            while end < n and sent_count < self.max_sent_per_chunk:
                cur_tokens = num_tokens(" ".join(sentence_list[start : end + 1]))
                sent_count += 1
                if cur_tokens >= self.target_tokens:
                    # Cut back to the nearest topic boundary at or before the
                    # current sentence, but only if the resulting chunk still
                    # has at least min_sent_per_chunk sentences.
                    cut = end
                    for b in range(end, start - 1, -1):
                        if b in boundary_set:
                            cut = b
                            break
                    if cut - start + 1 >= self.min_sent_per_chunk:
                        end = cut
                        break
                end += 1
            text = " ".join(sentence_list[start : end + 1]).strip()
            tokens = num_tokens(text)
            chunk_id += 1
            chunk = Chunk(
                doc_id=self.doc_id,
                chunk_id=chunk_id,
                text=text,
                tokens=tokens,
                text_type=text_type,
                dataset_id=dataset_id,
            )
            chunks.append(chunk)
            start = end + 1
        return chunks

    async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
        raw_info = await self._raw_chunk(text)
        if not raw_info:
            return []
        return self._pack_v1(
            sentence_list=raw_info["sentence_list"],
            boundaries=raw_info["boundaries"],
            text_type=text_type,
            dataset_id=dataset_id,
        )


class TopicAwarePackerV2(TopicAwareChunker):
    """Packs whole boundary-delimited segments into chunks, halving segments
    that exceed the token limit and flagging chunks that are too short."""

    def _pack_v2(
        self,
        sentence_list: List[str],
        boundaries: List[int],
        embeddings: np.ndarray,
        text_type: int,
        dataset_id: int,
    ) -> List[Chunk]:
        # Slice the sentence list into segments at every detected boundary.
        segments = []
        seg_embs = []
        last_idx = 0
        for b in boundaries + [len(sentence_list) - 1]:
            seg = sentence_list[last_idx : b + 1]
            if seg:
                segments.append(seg)
                # Mean sentence embedding as the segment representation.
                seg_embs.append(np.mean(embeddings[last_idx : b + 1], axis=0))
            last_idx = b + 1
        # Split any over-long segment roughly in half (by sentence count).
        final_segments = []
        for seg in segments:
            tokens = num_tokens("".join(seg))
            if tokens > self.max_tokens and len(seg) > 1:
                mid = len(seg) // 2
                final_segments.append(seg[:mid])
                final_segments.append(seg[mid:])
            else:
                final_segments.append(seg)
        chunks = []
        for index, seg in enumerate(final_segments, 1):
            text = "".join(seg)
            tokens = num_tokens(text)
            # If the chunk is too short in tokens, set it aside for now.
            status = 2 if tokens < self.min_tokens else 1
            chunks.append(
                Chunk(
                    doc_id=self.doc_id,
                    dataset_id=dataset_id,
                    text=text,
                    chunk_id=index,
                    tokens=tokens,
                    text_type=text_type,
                    status=status,
                )
            )
        return chunks

    async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
        raw_info = await self._raw_chunk(text)
        if not raw_info:
            return []
        return self._pack_v2(
            sentence_list=raw_info["sentence_list"],
            boundaries=raw_info["boundaries"],
            embeddings=raw_info["embeddings"],
            text_type=text_type,
            dataset_id=dataset_id,
        )
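

# Minimal usage sketch (commented out, not part of the module). It assumes an
# async entry point and that the applications.* dependencies are importable and
# configured; `sample_text` and the text_type / dataset_id values are
# illustrative placeholders only.
#
# import asyncio
#
# async def _demo() -> None:
#     packer = TopicAwarePackerV2(doc_id="doc-001")
#     chunks = await packer.chunk(sample_text, text_type=1, dataset_id=1)
#     for c in chunks:
#         print(c.chunk_id, c.tokens, c.status)
#
# asyncio.run(_demo())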