topic_aware_chunking.py

  1. """
  2. 主题感知分块
  3. """
  4. from __future__ import annotations
  5. from typing import List, Dict, Any
  6. import numpy as np
  7. from applications.api import get_basic_embedding
  8. from applications.config import DEFAULT_MODEL, Chunk
  9. from applications.utils.nlp import SplitTextIntoSentences, num_tokens, BoundaryDetector


class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
    """Shared base: sentence splitting, embedding, and topic-boundary detection."""

    # Processing status constants.
    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    FINISHED_STATUS = 2
    FAILED_STATUS = 3

    def __init__(self, doc_id: str):
        super().__init__()
        self.doc_id = doc_id

    @staticmethod
    async def _encode_batch(texts: List[str]) -> np.ndarray:
        """Embed each text via the embedding API and stack the vectors into an (n, d) float32 matrix."""
        embs = []
        for t in texts:
            e = await get_basic_embedding(t, model=DEFAULT_MODEL)
            embs.append(np.array(e, dtype=np.float32))
        return np.stack(embs)

    async def _raw_chunk(self, text: str, dont_chunk: bool) -> Dict[str, Any]:
        """Split ``text`` into sentences, embed them, and detect topic boundaries.

        When ``dont_chunk`` is set, the whole text is treated as a single sentence.
        Returns an empty dict when no sentences are produced.
        """
        # sentence_list = self.jieba_sent_tokenize(text)
        if dont_chunk:
            return {
                "sentence_list": [text],
                "boundaries": [],
                "embeddings": await self._encode_batch([text]),
            }
        sentence_list = self.lang_chain_tokenize(text)
        if not sentence_list:
            return {}
        sentences_embeddings = await self._encode_batch(sentence_list)
        boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
        return {
            "sentence_list": sentence_list,
            "boundaries": boundaries,
            "embeddings": sentences_embeddings,
        }

    async def _book_chunk(self, sentence_list: List[str]) -> Dict[str, Any]:
        """Embed an already-split sentence list and detect topic boundaries."""
        sentences_embeddings = await self._encode_batch(sentence_list)
        boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
        return {
            "sentence_list": sentence_list,
            "boundaries": boundaries,
            "embeddings": sentences_embeddings,
        }
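

# TopicAwarePackerV1: greedily packs consecutive sentences until a target token
# budget is reached, preferring to cut at a detected topic boundary when doing
# so still leaves at least ``min_sent_per_chunk`` sentences in the chunk.
# ``target_tokens``, ``max_sent_per_chunk`` and ``min_sent_per_chunk`` are
# expected to be provided by ``BoundaryDetector`` (or another ancestor).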
class TopicAwarePackerV1(TopicAwareChunker):
    def _pack_v1(
        self,
        sentence_list: List[str],
        boundaries: List[int],
        text_type: int,
        dataset_id: int,
    ) -> List[Chunk]:
        boundary_set = set(boundaries)
        chunks: List[Chunk] = []
        start = 0
        n = len(sentence_list)
        chunk_id = 0
        while start < n:
            end = start
            sent_count = 0
            while end < n and sent_count < self.max_sent_per_chunk:
                cur_tokens = num_tokens(" ".join(sentence_list[start : end + 1]))
                sent_count += 1
                if cur_tokens >= self.target_tokens:
                    # Walk backwards to the nearest topic boundary, but only cut
                    # there if the chunk keeps at least min_sent_per_chunk sentences.
                    cut = end
                    for b in range(end, start - 1, -1):
                        if b in boundary_set:
                            cut = b
                            break
                    if cut - start + 1 >= self.min_sent_per_chunk:
                        end = cut
                    break
                end += 1
            text = " ".join(sentence_list[start : end + 1]).strip()
            tokens = num_tokens(text)
            chunk_id += 1
            chunk = Chunk(
                doc_id=self.doc_id,
                chunk_id=chunk_id,
                text=text,
                tokens=tokens,
                text_type=text_type,
                dataset_id=dataset_id,
            )
            chunks.append(chunk)
            start = end + 1
        return chunks

    async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
        # _raw_chunk requires an explicit dont_chunk flag; V1 always splits.
        raw_info = await self._raw_chunk(text, dont_chunk=False)
        if not raw_info:
            return []
        return self._pack_v1(
            sentence_list=raw_info["sentence_list"],
            boundaries=raw_info["boundaries"],
            text_type=text_type,
            dataset_id=dataset_id,
        )
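

# TopicAwarePackerV2: turns each detected topic segment into one chunk, halves
# segments whose token count exceeds ``max_tokens``, and flags chunks shorter
# than ``min_tokens`` with a distinct status so callers can treat them separately.
# ``max_tokens`` and ``min_tokens`` are expected to come from an ancestor class.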
class TopicAwarePackerV2(TopicAwareChunker):
    def _pack_v2(
        self,
        sentence_list: List[str],
        boundaries: List[int],
        embeddings: np.ndarray,
        text_type: int,
        dataset_id: int,
    ) -> List[Chunk]:
        # Group sentences into boundary-delimited segments.
        segments = []
        seg_embs = []  # mean embedding per segment (collected, not used below)
        last_idx = 0
        for b in boundaries + [len(sentence_list) - 1]:
            seg = sentence_list[last_idx : b + 1]
            if seg:
                segments.append(seg)
                seg_embs.append(np.mean(embeddings[last_idx : b + 1], axis=0))
            last_idx = b + 1
        # Split segments that exceed the token budget roughly in half.
        final_segments = []
        for seg in segments:
            tokens = num_tokens("".join(seg))
            if tokens > self.max_tokens and len(seg) > 1:
                mid = len(seg) // 2
                final_segments.append(seg[:mid])
                final_segments.append(seg[mid:])
            else:
                final_segments.append(seg)
        chunks = []
        for index, seg in enumerate(final_segments, 1):
            text = "".join(seg)
            tokens = num_tokens(text)
            # If the chunk is too short in tokens, mark it as unused for now.
            status = 2 if tokens < self.min_tokens else 1
            chunks.append(
                Chunk(
                    doc_id=self.doc_id,
                    dataset_id=dataset_id,
                    text=text,
                    chunk_id=index,
                    tokens=tokens,
                    text_type=text_type,
                    status=status,
                )
            )
        return chunks

    async def chunk(
        self, text: str, text_type: int, dataset_id: int, dont_chunk: bool
    ) -> List[Chunk]:
        raw_info = await self._raw_chunk(text, dont_chunk)
        if not raw_info:
            return []
        return self._pack_v2(
            sentence_list=raw_info["sentence_list"],
            boundaries=raw_info["boundaries"],
            embeddings=raw_info["embeddings"],
            text_type=text_type,
            dataset_id=dataset_id,
        )

    async def chunk_books(
        self, sentence_list: List[str], text_type: int, dataset_id: int
    ) -> List[Chunk]:
        raw_info = await self._book_chunk(sentence_list=sentence_list)
        if not raw_info:
            return []
        return self._pack_v2(
            sentence_list=raw_info["sentence_list"],
            boundaries=raw_info["boundaries"],
            embeddings=raw_info["embeddings"],
            text_type=text_type,
            dataset_id=dataset_id,
        )
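

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumes the
# ``applications.*`` dependencies above are importable, the embedding API is
# reachable, the mixin base classes can be constructed without extra
# arguments, and ``Chunk`` exposes the fields it is constructed with. The
# doc_id, text_type, and dataset_id values below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        packer = TopicAwarePackerV2(doc_id="demo-doc")  # placeholder doc id
        chunks = await packer.chunk(
            text="First topic sentence. Another sentence. A new topic starts here.",
            text_type=1,       # placeholder value
            dataset_id=1,      # placeholder value
            dont_chunk=False,  # let the packer split and detect boundaries
        )
        for c in chunks:
            print(c.chunk_id, c.tokens, c.text[:40])

    asyncio.run(_demo())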