# sampler.py

from typing import Iterator, List, Optional

import numpy as np
import torch
from torch.utils.data import BatchSampler, Dataset


class SingleDatasetBatchSampler(BatchSampler):
    """
    A batch sampler that draws each batch from a single dataset and handles distribution across GPUs.

    Args:
        datasets (List[Dataset]): List of datasets to sample from
        global_batch_size (int): Global batch size (will be divided across GPUs)
        drop_last (bool): Whether to drop the last incomplete batch
        generator (Optional[torch.Generator]): Random number generator
    """
    def __init__(
        self,
        datasets: List[Dataset],
        global_batch_size: int,
        drop_last: bool = True,
        generator: Optional[torch.Generator] = None,
    ):
        self.datasets = datasets
        self.global_batch_size = global_batch_size
        self.drop_last = drop_last
        self.generator = generator or torch.Generator()
        self.initial_seed = self.generator.initial_seed()
        # Calculate dataset sizes and create index mappings
        self.dataset_sizes = [len(dataset) for dataset in datasets]
        # Cumulative offsets: start index of each dataset in the concatenated index space
        self.cumsum_sizes = np.cumsum([0] + self.dataset_sizes).tolist()
        self.total_size = sum(self.dataset_sizes)
        # Create shuffled indices for each dataset
        self.indices_per_dataset = [
            torch.randperm(size, generator=self.generator).tolist() for size in self.dataset_sizes
        ]
        self.current_positions = [0] * len(datasets)
        self.available_datasets = list(range(len(datasets)))
        # Largest position that still leaves room for a full global batch (the trailing remainder is dropped)
        self.max_positions = [
            (size // self.global_batch_size) * self.global_batch_size for size in self.dataset_sizes
        ]
    def __iter__(self) -> Iterator[List[int]]:
        # Reset state at the start of each epoch
        self.current_positions = [0] * len(self.datasets)
        self.available_datasets = list(range(len(self.datasets)))
        # Remaining (unconsumed) samples per dataset; shrinks as batches are drawn
        self.current_data_lengths = list(self.dataset_sizes)
        while self.available_datasets:
            # Build sampling probabilities over the available datasets,
            # proportional to how much data each one has left
            lengths = [self.current_data_lengths[i] for i in self.available_datasets]
            total_length = sum(lengths)
            if total_length <= 0:
                break  # nothing left to sample
            probs = torch.tensor(lengths, dtype=torch.float) / total_length
            # Pick a dataset
            dataset_idx_in_available = torch.multinomial(probs, num_samples=1, generator=self.generator).item()
            dataset_idx = self.available_datasets[dataset_idx_in_available]
            # Fetch the next batch from that dataset
            dataset_indices = self.indices_per_dataset[dataset_idx]
            current_pos = self.current_positions[dataset_idx]
            end_pos = current_pos + self.global_batch_size
            if end_pos <= self.max_positions[dataset_idx]:
                # Offset local indices into the concatenated index space
                batch_indices = [idx + self.cumsum_sizes[dataset_idx] for idx in dataset_indices[current_pos:end_pos]]
                self.current_positions[dataset_idx] = end_pos
                self.current_data_lengths[dataset_idx] = self.dataset_sizes[dataset_idx] - end_pos
                # Remove the dataset once it can no longer supply a full batch
                if end_pos >= self.max_positions[dataset_idx]:
                    self.available_datasets.remove(dataset_idx)
                yield batch_indices
            else:
                # Not enough samples left for a full batch
                self.available_datasets.remove(dataset_idx)
  68. def set_epoch(self, epoch):
  69. """
  70. Sets the epoch for this sampler.
  71. Args:
  72. epoch (int): Epoch number
  73. """
  74. torch_gen = torch.Generator()
  75. # Set seed based on epoch to ensure different shuffling each epoch
  76. new_seed = self.initial_seed + epoch
  77. torch_gen.manual_seed(new_seed)
  78. self.generator.manual_seed(new_seed)
  79. # Reshuffle indices for each dataset
  80. self.indices_per_dataset = [torch.randperm(size, generator=torch_gen).tolist() for size in self.dataset_sizes]
    @property
    def batch_size(self) -> int:
        return self.global_batch_size

    def __len__(self) -> int:
        return sum(size // self.global_batch_size for size in self.dataset_sizes)
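

# Minimal usage sketch. It assumes the sampler's batch indices address the concatenation of
# `datasets`, so the datasets are wrapped in torch.utils.data.ConcatDataset here; the toy
# datasets, batch size, and seed below are illustrative placeholders, not part of the module.
if __name__ == "__main__":
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    datasets = [
        TensorDataset(torch.arange(10).float()),        # toy dataset A (10 samples)
        TensorDataset(torch.arange(100, 130).float()),  # toy dataset B (30 samples)
    ]
    generator = torch.Generator().manual_seed(0)
    sampler = SingleDatasetBatchSampler(datasets, global_batch_size=4, generator=generator)

    loader = DataLoader(ConcatDataset(datasets), batch_sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle per-dataset indices each epoch
        for batch in loader:
            # Every batch comes from exactly one of the underlying datasets
            pass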