소스 검색

Support load data without data server.

Lengyue 2 년 전
부모
커밋
bf67e01fc4
1개의 변경된 파일74개의 추가작업 그리고 8개의 파일을 삭제
  1. 74 8
      fish_speech/datasets/text.py

+ 74 - 8
fish_speech/datasets/text.py

@@ -16,8 +16,9 @@ from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
-from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
+from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest, SampledData
 from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
 from fish_speech.text.parser import clean_text
 from fish_speech.text.symbols import pad as pad_symbol
 from fish_speech.text.symbols import pu_symbols
@@ -192,6 +193,9 @@ class AutoAugTextDataset(IterableDataset):
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
         use_speaker: bool = True,
+        use_data_server: bool = True,
+        proto_files: str = "data",
+        causual: bool = True,
     ):
         """
         Args:
@@ -202,6 +206,10 @@ class AutoAugTextDataset(IterableDataset):
             interactive_prob: probability to use interactive mode
             max_length: max length of the text
             tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            use_data_server: use data server or local data
+            proto_files: proto buf files if using local data
+            causual: use causual sampling when using local data, disable will lead to random sampling
         """
 
         super().__init__()
@@ -213,10 +221,32 @@ class AutoAugTextDataset(IterableDataset):
         self.repetition_prob = repetition_prob
         self.interactive_prob = interactive_prob
         self.use_speaker = use_speaker
+        self.use_data_server = use_data_server
+        self.proto_files = proto_files
+        self.causual = causual
 
-        # Read all lines, and group by speaker
-        self.channel = grpc.insecure_channel(server)
-        self.stub = DataServiceStub(self.channel)
+        if use_data_server is True:
+            self.channel = grpc.insecure_channel(server)
+            self.stub = DataServiceStub(self.channel)
+        else:
+            self.init_mock_data_server()
+
+    def init_mock_data_server(self):
+        self.groups = []
+        count = 0
+        for filename in self.proto_files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    self.groups.append(text_data)
+                    count += 1
+
+                    if count % 1000 == 0:
+                        log.info(f"Read {count} groups of data")
+
+        log.info(f"Read total {count} groups of data")
+
+        # Shuffle the lines
+        Random(self.seed).shuffle(self.groups)
 
     def __iter__(self):
         while True:
@@ -243,6 +273,37 @@ class AutoAugTextDataset(IterableDataset):
         )
         return sentence, len(tokens)
 
+    def sample_data(self):
+        # Shuffle unique lines, estimate that each sample is at least 20 tokens
+        num_samples = self.max_length // 20
+
+        if self.use_data_server:
+            request = SampleDataRequest(num_samples=num_samples)
+            return self.stub.SampleData(request)
+
+        # choice group based on their number of samples
+        group = random.choices(
+            self.groups, weights=[len(i.sentences) for i in self.groups], k=1
+        )[0]
+
+        if self.causual:
+            # Sample in order
+            if num_samples >= len(group.sentences):
+                samples = group.sentences
+            else:
+                begin = random.randint(0, len(group.sentences) - num_samples)
+                samples = group.sentences[begin : begin + num_samples]
+        else:
+            samples = random.choices(
+                group.sentences, k=min(num_samples, len(group.sentences))
+            )
+
+        return SampledData(
+            source=group.source,
+            name=group.name,
+            samples=samples,
+        )
+
     def augment(self):
         # 50% to pure text or pure phones
         mode = "sample"
@@ -261,10 +322,7 @@ class AutoAugTextDataset(IterableDataset):
         remaining_tokens = a.long().item() - 4
 
         final_text, final_semantic = [], []
-
-        # Shuffle unique lines, estimate that each sample is at least 20 tokens
-        request = SampleDataRequest(num_samples=self.max_length // 20)
-        response = self.stub.SampleData(request)
+        response = self.sample_data()
         if len(response.samples) == 0:
             # Invalid group
             return None
@@ -500,10 +558,18 @@ class TextDataModule(LightningDataModule):
 if __name__ == "__main__":
     from tqdm import tqdm
 
+    # ds = AutoAugTextDataset(
+    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
+    #     use_speaker=True,
+    #     interactive_prob=1.0,
+    # )
+
     ds = AutoAugTextDataset(
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
         use_speaker=True,
         interactive_prob=1.0,
+        use_data_server=False,
+        proto_files=["data/wenet-speech.protos"],
     )
 
     dm = TextDataModule(