topic_aware_chunking.py

  1. """
  2. 主题感知分块
  3. """
from __future__ import annotations

from typing import List, Dict, Any

import numpy as np

from applications.api import get_basic_embedding
from applications.config import DEFAULT_MODEL, Chunk
from applications.utils.nlp import SplitTextIntoSentences, num_tokens, BoundaryDetector


class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
    """Shared pipeline: sentence splitting, sentence embedding, and topic-boundary detection."""

    # Document processing statuses.
    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    FINISHED_STATUS = 2
    FAILED_STATUS = 3

    def __init__(self, doc_id: str):
        super().__init__()
        self.doc_id = doc_id

    @staticmethod
    async def _encode_batch(texts: List[str]) -> np.ndarray:
        """Embed each text via the embedding API and stack the vectors into an (n, dim) array."""
        embs = []
        for t in texts:
            e = await get_basic_embedding(t, model=DEFAULT_MODEL)
            embs.append(np.array(e, dtype=np.float32))
        return np.stack(embs)

    async def _raw_chunk(self, text: str) -> Dict[str, Any]:
        """Split text into sentences, embed them, and detect topic boundaries.

        Returns an empty dict when the text yields no sentences.
        """
        sentence_list = self.jieba_sent_tokenize(text)
        if not sentence_list:
            return {}
        sentences_embeddings = await self._encode_batch(sentence_list)
        boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
        return {
            "sentence_list": sentence_list,
            "boundaries": boundaries,
            "embeddings": sentences_embeddings,
        }
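
# Boundary contract (inferred from how the packers below consume `boundaries`, not
# from the BoundaryDetector implementation itself): each boundary is a sentence
# index marking the last sentence of a topic segment, so a segment is the slice
# sentence_list[last_idx : b + 1].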


class TopicAwarePackerV1(TopicAwareChunker):
    def _pack_v1(
        self, sentence_list: List[str], boundaries: List[int], text_type: int, dataset_id: int
    ) -> List[Chunk]:
        """Greedily grow chunks toward target_tokens, preferring to cut at detected topic boundaries."""
        boundary_set = set(boundaries)
        chunks: List[Chunk] = []
        start = 0
        n = len(sentence_list)
        chunk_id = 0
        while start < n:
            end = start
            sent_count = 0
            while end < n and sent_count < self.max_sent_per_chunk:
                cur_tokens = num_tokens(" ".join(sentence_list[start : end + 1]))
                sent_count += 1
                if cur_tokens >= self.target_tokens:
                    # Target size reached: walk backwards to the nearest topic boundary.
                    cut = end
                    for b in range(end, start - 1, -1):
                        if b in boundary_set:
                            cut = b
                            break
                    # Only cut at the boundary if the chunk keeps enough sentences.
                    if cut - start + 1 >= self.min_sent_per_chunk:
                        end = cut
                    break
                end += 1
            text = " ".join(sentence_list[start : end + 1]).strip()
            tokens = num_tokens(text)
            chunk_id += 1
            chunk = Chunk(
                doc_id=self.doc_id,
                chunk_id=chunk_id,
                text=text,
                tokens=tokens,
                text_type=text_type,
                dataset_id=dataset_id,
            )
            chunks.append(chunk)
            start = end + 1
        return chunks

    async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
        raw_info = await self._raw_chunk(text)
        if not raw_info:
            return []
        return self._pack_v1(
            sentence_list=raw_info["sentence_list"],
            boundaries=raw_info["boundaries"],
            text_type=text_type,
            dataset_id=dataset_id,
        )


class TopicAwarePackerV2(TopicAwareChunker):
    def _pack_v2(
        self, sentence_list: List[str], boundaries: List[int], embeddings: np.ndarray, text_type: int, dataset_id: int
    ) -> List[Chunk]:
        """Build one segment per topic boundary, split oversized segments in half, and flag undersized chunks."""
        segments = []
        seg_embs = []
        last_idx = 0
        for b in boundaries + [len(sentence_list) - 1]:
            seg = sentence_list[last_idx : b + 1]
            if seg:
                segments.append(seg)
                # Mean sentence embedding per segment; computed only for non-empty
                # segments to avoid taking the mean of an empty slice.
                seg_embs.append(np.mean(embeddings[last_idx : b + 1], axis=0))
            last_idx = b + 1

        # Split segments that exceed the token budget roughly in half.
        final_segments = []
        for seg in segments:
            tokens = num_tokens("".join(seg))
            if tokens > self.max_tokens and len(seg) > 1:
                mid = len(seg) // 2
                final_segments.append(seg[:mid])
                final_segments.append(seg[mid:])
            else:
                final_segments.append(seg)

        chunks = []
        for index, seg in enumerate(final_segments, 1):
            text = "".join(seg)
            tokens = num_tokens(text)
            # If the chunk has too few tokens, hold it back for now.
            status = 2 if tokens < self.min_tokens else 1
            chunks.append(
                Chunk(
                    doc_id=self.doc_id,
                    dataset_id=dataset_id,
                    text=text,
                    chunk_id=index,
                    tokens=tokens,
                    text_type=text_type,
                    status=status,
                )
            )
        return chunks

    async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
        raw_info = await self._raw_chunk(text)
        if not raw_info:
            return []
        return self._pack_v2(
            sentence_list=raw_info["sentence_list"],
            boundaries=raw_info["boundaries"],
            embeddings=raw_info["embeddings"],
            text_type=text_type,
            dataset_id=dataset_id,
        )
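

# Usage sketch (illustrative, not part of the original module): drives one of the
# packers end to end. It assumes the `applications` services behind
# `get_basic_embedding` are reachable and that chunking parameters such as
# `target_tokens`, `max_sent_per_chunk`, `min_sent_per_chunk`, `max_tokens`, and
# `min_tokens` are supplied by `BoundaryDetector` or its configuration. The
# `doc_id`, `text_type`, and `dataset_id` values are placeholders, and Chunk is
# assumed to expose its constructor fields as attributes.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        packer = TopicAwarePackerV2(doc_id="doc-001")  # hypothetical document id
        chunks = await packer.chunk(
            text="First topic sentence. Another sentence on the same topic. A new topic starts here.",
            text_type=1,    # placeholder text_type
            dataset_id=1,   # placeholder dataset_id
        )
        for c in chunks:
            print(c.chunk_id, c.tokens, c.status)

    asyncio.run(_demo())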