Browse Source

Optimize webui & better concat repeat dataloader

Lengyue 1 năm trước
mục cha
commit
3ddf3d975c
2 tập tin đã thay đổi với 37 bổ sung và 2 xóa
  1. 37 1
      fish_speech/datasets/concat_repeat.py
  2. 0 1
      tools/webui.py

+ 37 - 1
fish_speech/datasets/concat_repeat.py

@@ -1,4 +1,5 @@
 import bisect
+import random
 from typing import Iterable
 
 from torch.utils.data import Dataset, IterableDataset
@@ -32,7 +33,7 @@ class ConcatRepeatDataset(Dataset):
         for d in self.datasets:
             assert not isinstance(
                 d, IterableDataset
-            ), "ConcatDataset does not support IterableDataset"
+            ), "ConcatRepeatDataset does not support IterableDataset"
 
         self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
 
@@ -50,3 +51,38 @@ class ConcatRepeatDataset(Dataset):
         dataset = self.datasets[dataset_idx]
 
         return dataset[sample_idx % len(dataset)]
+
+
+class ConcatWeightedIterableDataset(IterableDataset):
+    """Endless weighted mixture of several iterable datasets.
+
+    Every step picks one dataset at random, with probability proportional to
+    its weight, and yields that dataset's next sample. A dataset that runs
+    out is restarted from its beginning, so iteration never terminates on its
+    own; the consumer decides when to stop.
+
+    Args:
+        datasets: Non-empty iterable of ``IterableDataset`` instances.
+        weights: One sampling weight per dataset. The weights are divided by
+            their sum before use, so they do not have to add up to 1.
+
+    Raises:
+        AssertionError: If ``datasets`` is empty, its length differs from
+            that of ``weights``, an entry is not an ``IterableDataset``, or
+            the weights do not sum to a positive value.
+    """
+
+    # Source datasets, materialised into a list so they can be indexed and
+    # re-iterated.
+    datasets: list[IterableDataset]
+    # Sampling probabilities: the input weights divided by their sum.
+    weights: list[float]
+
+    def __init__(self, datasets: Iterable[IterableDataset], weights: list[float]):
+        super().__init__()
+
+        self.datasets = list(datasets)
+
+        # Validate before normalising so that a bad argument produces the
+        # intended message instead of an unrelated arithmetic error.
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+        assert len(self.datasets) == len(
+            weights
+        ), "datasets and weights should have the same length"
+
+        for d in self.datasets:
+            assert isinstance(
+                d, IterableDataset
+            ), "ConcatWeightedIterableDataset only supports IterableDataset"
+
+        total_weight = sum(weights)
+        assert total_weight > 0, "weights should sum to a positive value"
+        self.weights = [w / total_weight for w in weights]
+
+    def __iter__(self):
+        """Yield samples forever, choosing the source dataset by weight.
+
+        Raises:
+            RuntimeError: If the chosen dataset yields nothing even right
+                after being restarted, i.e. it is empty.
+        """
+        all_datasets = [iter(dataset) for dataset in self.datasets]
+        ids = list(range(len(self.datasets)))
+
+        while True:
+            chosen_dataset = random.choices(ids, self.weights)[0]
+
+            try:
+                sample = next(all_datasets[chosen_dataset])
+            except StopIteration:
+                # The chosen dataset is exhausted: restart it and take the
+                # first sample of the new pass.
+                all_datasets[chosen_dataset] = iter(self.datasets[chosen_dataset])
+                try:
+                    sample = next(all_datasets[chosen_dataset])
+                except StopIteration:
+                    # A StopIteration leaking out of a generator body is
+                    # converted into an opaque RuntimeError (PEP 479); raise
+                    # the same exception type with a message that names the
+                    # offending dataset.
+                    raise RuntimeError(
+                        f"dataset at index {chosen_dataset} is empty "
+                        "and cannot be sampled"
+                    ) from None
+
+            yield sample

+ 0 - 1
tools/webui.py

@@ -44,7 +44,6 @@ HEADER_MD = f"""# Fish Speech
 
 TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
 SPACE_IMPORTED = False
-cached_audio = np.zeros((1,))
 
 
 def build_html_error_message(error):