imputation_utils.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import ruptures as rpt
  2. import numpy as np
  3. import pandas as pd
  4. from sklearn.preprocessing import StandardScaler
  5. from typing import List, Tuple
  6. def find_2d_data_bkps(X: List[Tuple[int, int]]) -> List[int]:
  7. X_clean = [point if point is not None else (np.nan, np.nan) for point in X]
  8. X = np.array(X_clean, dtype=float)
  9. X = pd.DataFrame(X).interpolate("linear").bfill().ffill().to_numpy()
  10. X_std = StandardScaler().fit_transform(X)
  11. algo = rpt.KernelCPD(kernel="rbf", jump=1).fit(X_std)
  12. bkps = algo.predict(pen=10)
  13. return bkps[:-1]
  14. def get_interval_average_bbox(
  15. bboxes: List[Tuple[int, int, int, int] | None], bkps: List[int]
  16. ) -> List[Tuple[int, int, int, int]]:
  17. average_bboxes = []
  18. for left, right in zip(bkps[:-1], bkps[1:]):
  19. bboxes_interval = bboxes[left:right]
  20. valid_bboxes = [bbox for bbox in bboxes_interval if bbox is not None]
  21. if len(valid_bboxes) > 0:
  22. average_bbox = np.mean(valid_bboxes, axis=0)
  23. average_bboxes.append(tuple(map(int, average_bbox)))
  24. else:
  25. average_bboxes.append(None)
  26. return average_bboxes
  27. def find_idxs_interval(idxs: List[int], bkps: List[int]) -> List[int]:
  28. def _find_idx_interval(_idx: int) -> int:
  29. left = 0
  30. right = len(bkps) - 2
  31. while left <= right:
  32. mid = (left + right) // 2
  33. if bkps[mid] <= _idx < bkps[mid + 1]:
  34. return mid
  35. elif _idx < bkps[mid]:
  36. right = mid - 1
  37. else:
  38. left = mid + 1
  39. return min(max(left, 0), len(bkps) - 2)
  40. intervals = []
  41. for idx in idxs:
  42. interval_idx = _find_idx_interval(idx)
  43. intervals.append(interval_idx)
  44. return intervals