View Source

Greatly improved data loading speed

Lengyue · 2 years ago
parent
commit
f8cfbb4ac0

+ 4 - 0
dockerfile

@@ -26,6 +26,10 @@ RUN pip3 install --upgrade pip && \
     pip3 install ninja packaging && \
     pip3 install git+https://github.com/Dao-AILab/flash-attention.git
 
+# Setup rust-data-server
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
+    cd data_server && cargo build --release
+
 # Project Env
 WORKDIR /exp
 

+ 1 - 4
fish_speech/configs/text2semantic.yaml

@@ -11,6 +11,7 @@ trainer:
   gradient_clip_algorithm: 'norm'
   max_steps: 1_000_000
   precision: bf16-true
+  limit_val_batches: 10
 
 # Dataset Configuration
 tokenizer:
@@ -20,15 +21,11 @@ tokenizer:
 # Dataset Configuration
 train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
-  files: [ data/quantized-dataset-1205.protos ]
   tokenizer: ${tokenizer}
-  split: train
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
-  files: [ data/quantized-dataset-1205.protos ]
   tokenizer: ${tokenizer}
-  split: val
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule

+ 0 - 83
fish_speech/datasets/data_server.py

@@ -1,83 +0,0 @@
-import asyncio
-import random
-
-import grpc
-from loguru import logger
-
-from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest, SampledData
-from fish_speech.datasets.protos.text_data_pb2_grpc import (
-    DataServiceServicer,
-    DataServiceStub,
-    add_DataServiceServicer_to_server,
-)
-from fish_speech.datasets.protos.text_data_stream import read_pb_stream
-
-
-class DataService(DataServiceServicer):
-    def __init__(
-        self,
-        files: list[str],
-    ):
-        super().__init__()
-
-        self.files = files
-
-        # Read all lines, and group by speaker
-        self.groups = []
-        self.weights = []
-
-        count = 0
-        for filename in self.files:
-            with open(filename, "rb") as f:
-                for text_data in read_pb_stream(f):
-                    self.groups.append(text_data)
-                    self.weights.append(len(text_data.sentences))
-                    count += 1
-
-                    if count % 10000 == 0:
-                        logger.info(f"Read {count} groups of text data")
-                        break
-
-    def SampleData(self, request: SampleDataRequest, context):
-        group = random.choices(self.groups, weights=self.weights, k=1)[0]
-        k = min(request.num_samples, len(group.sentences))
-        samples = random.choices(group.sentences, k=k)
-
-        return SampledData(
-            samples=samples,
-        )
-
-
-async def run():
-    async with grpc.aio.insecure_channel("localhost:50051") as channel:
-        stub = DataServiceStub(channel)
-        import time
-
-        from tqdm import tqdm
-
-        start = time.time()
-        for _ in tqdm(range(10000)):
-            await stub.SampleData(SampleDataRequest(num_samples=50))
-
-        print(
-            f"Time taken: {time.time() - start}, {10000 / (time.time() - start)} samples/s"
-        )
-
-
-async def serve():
-    server = grpc.aio.server()
-    add_DataServiceServicer_to_server(
-        DataService(files=["data/quantized-dataset-1205.protos"]), server
-    )
-    listen_addr = "127.0.0.1:50051"
-    server.add_insecure_port(listen_addr)
-    print(f"Starting server on {listen_addr}")
-    await server.start()
-    await server.wait_for_termination()
-
-
-if __name__ == "__main__":
-    # asyncio.run(serve())
-    # Launch 14 workers
-
-    asyncio.run(run())

+ 21 - 53
fish_speech/datasets/text.py

@@ -7,6 +7,7 @@ from pathlib import Path
 from random import Random
 from typing import Optional, Union
 
+import grpc
 import numpy as np
 import pyarrow.parquet as pq
 import torch
@@ -18,8 +19,8 @@ from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
-from fish_speech.datasets.protos.text_data_pb2 import Semantics
-from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
+from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
 from fish_speech.text.symbols import pad as pad_symbol
 from fish_speech.text.symbols import pu_symbols
 from fish_speech.utils import RankedLogger
@@ -145,49 +146,26 @@ class AutoAugTextDataset(IterableDataset):
 
     def __init__(
         self,
-        files: list[str],
+        server: str = "localhost:50051",
         seed: int = 42,
         phones_prob: float = 0.3,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
-        split: Optional[str] = None,
     ):
         super().__init__()
 
-        self.files = files
         self.seed = seed
         self.phones_prob = phones_prob
         self.max_length = max_length
         self.tokenizer = tokenizer
 
         # Read all lines, and group by speaker
-        self.groups = []
-        count = 0
-        for filename in self.files:
-            with open(filename, "rb") as f:
-                for text_data in read_pb_stream(f):
-                    self.groups.append(text_data)
-                    count += 1
-
-                    if count % 10000 == 0:
-                        log.info(f"Read {count} groups of text data")
-
-        # Shuffle the lines
-        Random(seed).shuffle(self.groups)
-
-        if split == "train":
-            self.groups = self.groups[:-500]
-        elif split == "val":
-            self.groups = self.groups[-500:]
+        self.channel = grpc.insecure_channel(server)
+        self.stub = DataServiceStub(self.channel)
 
     def __iter__(self):
-        groups = split_by_rank_worker(self.groups)
-        random.shuffle(groups)
-
-        for group in groups:
-            x = self.augment(group)
-            if x is not None:
-                yield x
+        while True:
+            yield self.augment()
 
     def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
         if (
@@ -208,12 +186,11 @@ class AutoAugTextDataset(IterableDataset):
         )
         return sentence, len(tokens)
 
-    def augment(self, group):
+    def augment(self):
         # 50% to pure text or pure phones
-        # mode = "sample"
-        # if random.random() < 0.5:
-        #     mode = random.choice(["text", "phones"])
-        mode = "phones"
+        mode = "sample"
+        if random.random() < 0.5:
+            mode = random.choice(["text", "phones"])
 
         # Random sample based on speaker using a truncated normal distribution
         a = torch.tensor([0], dtype=torch.float32)
@@ -229,15 +206,15 @@ class AutoAugTextDataset(IterableDataset):
         final_text, final_semantic = [], []
 
         # Shuffle unique lines
-        idxs = list(range(len(group.sentences)))
-        random.shuffle(idxs)
-
-        if len(idxs) == 0:
+        request = SampleDataRequest(num_samples=50)
+        response = self.stub.SampleData(request)
+        if len(response.samples) == 0:
             # Invalid group
             return None
 
-        while remaining_tokens > 0 and len(idxs) > 0:
-            sentence = group.sentences[idxs.pop()]
+        samples = list(response.samples)
+        while remaining_tokens > 0 and len(samples) > 0:
+            sentence = samples.pop()
             text, length = self.tokenize_sentence(
                 sentence.text, sentence.phones, mode=mode
             )
@@ -393,18 +370,9 @@ class TextDataModule(LightningDataModule):
 if __name__ == "__main__":
     import json
 
-    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
-    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
-    # with open("test.jsonl", "w") as f:
-    #     for i in all_files:
-    #         wav_file = i.with_suffix(".wav")
-    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
-    #         eta_tokens = duration * 25
-    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
-    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")
+    from tqdm import tqdm
 
     ds = AutoAugTextDataset(
-        files=["data/quantized-dataset-1205.protos"],
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
     )
 
@@ -417,5 +385,5 @@ if __name__ == "__main__":
         num_workers=0,
     )
 
-    for batch in dm.train_dataloader():
-        print(batch)
+    for batch in tqdm(dm.train_dataloader()):
+        pass