boundary_detector.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from __future__ import annotations
  2. import re
  3. from typing import List
  4. import numpy as np
  5. from sklearn.preprocessing import minmax_scale
  6. from applications.config import ChunkerConfig
  7. class BoundaryDetector(ChunkerConfig):
  8. def __init__(self):
  9. self.signal_boost_turn = 0.20
  10. self.signal_boost_fig = 0.20
  11. self.min_gap = 1
  12. @staticmethod
  13. def cosine_sim(u: np.ndarray, v: np.ndarray) -> float:
  14. """计算余弦相似度"""
  15. return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))
  16. def turn_signal(self, text: str) -> float:
  17. pattern = r"(因此|但是|综上所述?|然而|另一方面|总之|结论是|In conclusion\b|To conclude\b|However\b|Therefore\b|Thus\b|On the other hand\b)"
  18. if re.search(pattern, text, flags=re.IGNORECASE):
  19. return self.signal_boost_turn
  20. return 0.0
  21. def figure_signal(self, text: str) -> float:
  22. pattern = r"(见下图|如下图所示|如表所示|如下表所示|表\s*\d+[::]?|图\s*\d+[::]?|Figure\s*\d+|Table\s*\d+)"
  23. if re.search(pattern, text, flags=re.IGNORECASE):
  24. return self.signal_boost_fig
  25. return 0.0
  26. def detect_boundaries(
  27. self, sentence_list: List[str], embs: np.ndarray, debug: bool = False
  28. ) -> List[int]:
  29. sims = np.array(
  30. [self.cosine_sim(embs[i], embs[i + 1]) for i in range(len(embs) - 1)]
  31. )
  32. cut_scores = 1 - sims
  33. cut_scores = minmax_scale(cut_scores) if len(cut_scores) > 0 else []
  34. boundaries = []
  35. last_boundary = -999
  36. for index, base_score in enumerate(cut_scores):
  37. sent_to_check = (
  38. sentence_list[index]
  39. if index < len(sentence_list)
  40. else sentence_list[-1]
  41. )
  42. snippet = sent_to_check[-20:] if sent_to_check else ""
  43. adj_score = (
  44. base_score
  45. + self.turn_signal(snippet)
  46. + self.figure_signal(sent_to_check)
  47. )
  48. if adj_score >= self.boundary_threshold and (
  49. index - last_boundary >= self.min_gap
  50. ):
  51. boundaries.append(index)
  52. last_boundary = index
  53. # Debug 输出
  54. if debug:
  55. print(
  56. f"[{index}] sim={sims[index]:.3f}, cut={base_score:.3f}, adj={adj_score:.3f}, boundary={index in boundaries}"
  57. )
  58. return boundaries