# boundary_detector.py
from __future__ import annotations
import re
from typing import List
import numpy as np
from sklearn.preprocessing import minmax_scale
from applications.config import ChunkerConfig
  7. class BoundaryDetector(ChunkerConfig):
  8. def __init__(self):
  9. self.signal_boost_turn = 0.20
  10. self.signal_boost_fig = 0.20
  11. self.min_gap = 1
  12. @staticmethod
  13. def cosine_sim(u: np.ndarray, v: np.ndarray) -> float:
  14. """计算余弦相似度"""
  15. return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))
  16. def turn_signal(self, text: str) -> float:
  17. pattern = r"(因此|但是|综上所述?|然而|另一方面|总之|结论是|In conclusion\b|To conclude\b|However\b|Therefore\b|Thus\b|On the other hand\b)"
  18. if re.search(pattern, text, flags=re.IGNORECASE):
  19. return self.signal_boost_turn
  20. return 0.0
  21. def figure_signal(self, text: str) -> float:
  22. pattern = r"(见下图|如下图所示|如表所示|如下表所示|表\s*\d+[::]?|图\s*\d+[::]?|Figure\s*\d+|Table\s*\d+)"
  23. if re.search(pattern, text, flags=re.IGNORECASE):
  24. return self.signal_boost_fig
  25. return 0.0
  26. def detect_boundaries(
  27. self, sentence_list: List[str], embs: np.ndarray, debug: bool = False
  28. ) -> List[int]:
  29. sims = np.array(
  30. [self.cosine_sim(embs[i], embs[i + 1]) for i in range(len(embs) - 1)]
  31. )
  32. cut_scores = 1 - sims
  33. cut_scores = minmax_scale(cut_scores) if len(cut_scores) > 0 else []
  34. boundaries = []
  35. last_boundary = -999
  36. for index, base_score in enumerate(cut_scores):
  37. sent_to_check = (
  38. sentence_list[index]
  39. if index < len(sentence_list)
  40. else sentence_list[-1]
  41. )
  42. snippet = sent_to_check[-20:] if sent_to_check else ""
  43. adj_score = (
  44. base_score
  45. + self.turn_signal(snippet)
  46. + self.figure_signal(sent_to_check)
  47. )
  48. if adj_score >= self.boundary_threshold and (
  49. index - last_boundary >= self.min_gap
  50. ):
  51. boundaries.append(index)
  52. last_boundary = index
  53. # Debug 输出
  54. if debug:
  55. print(
  56. f"[{index}] sim={sims[index]:.3f}, cut={base_score:.3f}, adj={adj_score:.3f}, boundary={index in boundaries}"
  57. )
  58. return boundaries
  59. def detect_boundaries_v2(
  60. self, sentence_list: List[str], embs: np.ndarray, debug: bool = False
  61. ) -> List[int]:
  62. """
  63. 约束:相邻 boundary(含开头到第一个 boundary)之间的句子数 ∈ [3, 10]
  64. boundary 的含义:作为“段落末句”的索引(与 pack 时的 b 含义一致)
  65. """
  66. n = len(sentence_list)
  67. if n <= 1 or embs is None or len(embs) != n:
  68. return []
  69. # --- 基础打分 ---
  70. sims = np.array([self.cosine_sim(embs[i], embs[i + 1]) for i in range(n - 1)])
  71. cut_scores = 1 - sims
  72. cut_scores = minmax_scale(cut_scores) if len(cut_scores) > 0 else np.array([])
  73. # 组合信号:内容转折/图片编号等
  74. adj_scores = np.zeros_like(cut_scores)
  75. for i in range(len(cut_scores)):
  76. sent_to_check = sentence_list[i] if i < n else sentence_list[-1]
  77. snippet = sent_to_check[-20:] if sent_to_check else ""
  78. adj_scores[i] = (
  79. cut_scores[i]
  80. + self.turn_signal(snippet)
  81. + self.figure_signal(sent_to_check)
  82. )
  83. # --- 3-10 句强约束切分 ---
  84. MIN_SIZE = self.min_sent_per_chunk
  85. MAX_SIZE = self.max_sent_per_chunk
  86. thr = getattr(self, "boundary_threshold", 0.5)
  87. boundaries: List[int] = []
  88. last_boundary = -1 # 作为上一个“段末句”的索引(开头前为 -1)
  89. best_idx = None # 记录当前窗口内(已达 MIN_SIZE)的最高分切点
  90. best_score = -1e9
  91. for i in range(n - 1): # i 表示把 i 作为“段末句”的候选
  92. seg_len = i - last_boundary # 若切在 i,本段包含的句数 = i - last_boundary
  93. # 更新当前窗口最佳候选(仅在达到最低长度后才可记为候选)
  94. if seg_len >= MIN_SIZE:
  95. if adj_scores[i] > best_score:
  96. best_score = float(adj_scores[i])
  97. best_idx = i
  98. cut_now = False
  99. cut_at = None
  100. if seg_len < MIN_SIZE:
  101. # 不足 3 句,绝不切
  102. pass
  103. elif adj_scores[i] >= thr and seg_len <= MAX_SIZE:
  104. # 在 [3,10] 区间且过阈值,直接切
  105. cut_now = True
  106. cut_at = i
  107. elif seg_len == MAX_SIZE:
  108. # 已到 10 句必须切:优先用窗口内最高分位置
  109. cut_now = True
  110. cut_at = best_idx if best_idx is not None else i
  111. if cut_now:
  112. boundaries.append(cut_at)
  113. last_boundary = cut_at
  114. best_idx = None
  115. best_score = -1e9
  116. if debug:
  117. print(
  118. f"[{i}] sim={sims[i]:.3f}, cut={cut_scores[i]:.3f}, "
  119. f"adj={adj_scores[i]:.3f}, len={seg_len}, "
  120. f"cut={'Y@' + str(cut_at) if cut_now else 'N'}"
  121. )
  122. # --- 收尾:避免最后一段 < 3 句 ---
  123. # pack 时会额外补上末尾 n-1 作为最终 boundary,因此尾段长度为 (n-1 - last_boundary)
  124. tail_len = (n - 1) - last_boundary
  125. if tail_len < MIN_SIZE and boundaries:
  126. # 需要把“最后一个 boundary”往前/后微调到一个可行区间内
  127. prev_last = boundaries[-2] if len(boundaries) >= 2 else -1
  128. # 新的最后切点需满足:
  129. # 1) 前一段长度在 [3,10] => j ∈ [prev_last+3, prev_last+10]
  130. # 2) 尾段长度在 [3,10] => j ∈ [n-1-10, n-1-3]
  131. lower = max(prev_last + MIN_SIZE, (n - 1) - MAX_SIZE)
  132. upper = min(prev_last + MAX_SIZE, (n - 1) - MIN_SIZE)
  133. if lower <= upper:
  134. # 在允许区间里找 adj_scores 最高的位置
  135. window = adj_scores[lower : upper + 1]
  136. j = int(np.argmax(window)) + lower
  137. if j != boundaries[-1]:
  138. boundaries[-1] = j
  139. if debug:
  140. print(
  141. f"[fix-tail] move last boundary -> {j}, tail_len={n - 1 - j}"
  142. )
  143. else:
  144. # 没有可行区间:退化为合并尾段(删掉最后一个 boundary)
  145. dropped = boundaries.pop()
  146. if debug:
  147. print(f"[fix-tail] drop last boundary {dropped} to avoid tiny tail")
  148. return boundaries