
Implement text auto aug

Lengyue 2 years ago
parent
commit ba1e088246

+ 2 - 2
fish_speech/configs/llama_finetune.yaml

@@ -18,12 +18,12 @@ tokenizer:
 
 # Dataset Configuration
 train_dataset:
-  - _target_: fish_speech.datasets.text.TextDataset
+  - _target_: fish_speech.datasets.text.StreamTextDataset
     repo: fishaudio/cn-hubert-25hz-vq
     prefix: 'data/train'
 
 val_dataset:
-  _target_: fish_speech.datasets.text.TextDataset
+  _target_: fish_speech.datasets.text.StreamTextDataset
   repo: fishaudio/cn-hubert-25hz-vq
   prefix: 'data/test'
 

+ 3 - 3
fish_speech/configs/llama_pretrain.yaml

@@ -28,11 +28,11 @@ tokenizer:
 dataset:
   _target_: fish_speech.datasets.text.InterleaveDataset
   datasets:
-    - _target_: fish_speech.datasets.text.TextDataset
+    - _target_: fish_speech.datasets.text.StreamTextDataset
       prefix: 'en/'
-    - _target_: fish_speech.datasets.text.TextDataset
+    - _target_: fish_speech.datasets.text.StreamTextDataset
       prefix: 'zh/'
-    - _target_: fish_speech.datasets.text.TextDataset
+    - _target_: fish_speech.datasets.text.StreamTextDataset
       prefix: 'ja/'
   probabilities: [0.4, 0.3, 0.3]
   seed: 42

+ 182 - 50
fish_speech/datasets/text.py

@@ -1,26 +1,56 @@
+import json
 import random
 from dataclasses import dataclass
 from itertools import chain
+from pathlib import Path
 from random import Random
 from typing import Optional, Union
 
 import numpy as np
 import pyarrow.parquet as pq
+import torch
 from datasets.download.streaming_download_manager import xopen
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
+from fish_speech.text import clean_text, g2p
 from fish_speech.utils import RankedLogger
 from fish_speech.utils.braceexpand import braceexpand
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
 
-class TextDataset(IterableDataset):
+def split_by_rank_worker(files):
+    # We need to know the total number of devices
+    # to split the data properly
+
+    total_devices = 1
+    if is_initialized():
+        total_devices = get_world_size()
+
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        total_devices *= worker_info.num_workers
+
+    if len(files) < total_devices:
+        # Repeat the files N times to match the number of devices
+        files = files * (total_devices // len(files) + 1)
+
+    # DDP
+    if is_initialized():
+        files = files[get_rank() :: get_world_size()]
+
+    # Split by worker
+    if worker_info is not None:
+        files = files[worker_info.id :: worker_info.num_workers]
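+    # Illustrative example (assumed setup): with world_size=2 and num_workers=2,
+    # rank 0 / worker 1 keeps files[0::2][1::2], i.e. original indices 2, 6, 10, ...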
+
+    return files
+
+
+class StreamTextDataset(IterableDataset):
     def __init__(
         self,
         files: Optional[Union[list[str], str]] = None,
@@ -55,34 +85,8 @@ class TextDataset(IterableDataset):
         self.files = sorted(files)
         Random(seed).shuffle(self.files)
 
-    def get_data_splits(self, files):
-        # We need to know the total number of devices
-        # to split the data properly
-
-        total_devices = 1
-        if is_initialized():
-            total_devices = get_world_size()
-
-        worker_info = get_worker_info()
-        if worker_info is not None:
-            total_devices *= worker_info.num_workers
-
-        if len(files) < total_devices:
-            # Repeat the files N times to match the number of devices
-            files = files * (total_devices // len(files) + 1)
-
-        # DDP
-        if is_initialized():
-            files = files[get_rank() :: get_world_size()]
-
-        # Split by worker
-        if worker_info is not None:
-            files = files[worker_info.id :: worker_info.num_workers]
-
-        return files
-
     def __iter__(self):
-        files = self.get_data_splits(self.files)
+        files = split_by_rank_worker(self.files)
         random.shuffle(files)
 
         for filename in files:
@@ -106,6 +110,127 @@ class TextDataset(IterableDataset):
                 yield from texts
 
 
+# @dataclass
+# class DatasetLine:
+#     text: str
+#     semantic: str
+#     speaker: str
+
+
+class AutoAugTextDataset(IterableDataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+    3. Mix text and phones
+    """
+
+    def __init__(
+        self,
+        jsonl_files: list[str],
+        seed: int = 42,
+        phones_prob: float = 0.5,
+        max_length: int = 1024,
+        order: Optional[list[str]] = None,
+        tokenizer: AutoTokenizer = None,
+    ):
+        super().__init__()
+
+        self.jsonl_files = jsonl_files
+        self.seed = seed
+        self.phones_prob = phones_prob
+        self.max_length = max_length
+        self.order = order
+        self.tokenizer = tokenizer
+
+        # Read all lines, and group by speaker
+        self.speakers = {}
+        self.lines = []
+
+        for filename in self.jsonl_files:
+            lines = Path(filename).read_text().splitlines()
+            for json_line in lines:
+                line = json.loads(json_line)
+                speaker = line.get("speaker", None)
+
+                if speaker not in self.speakers:
+                    self.speakers[speaker] = []
+
+                self.lines.append(line)
+                self.speakers[speaker].append(line)
+
+        # Shuffle the lines
+        Random(seed).shuffle(self.lines)
+
+    def __iter__(self):
+        lines = split_by_rank_worker(self.lines)
+        random.shuffle(lines)
+
+        for line in lines:
+            yield self.augment(line)
+
+    def tokenize_sentence(
+        self, sentence: str, semantic: list[int], mode: str = "sample"
+    ):
+        sentence = clean_text(sentence)
+
+        if (
+            mode == "sample" and (random.random() < self.phones_prob)
+        ) or mode == "phones":
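+            # Replace the sentence with its phone sequence; g2p yields pairs and only
+            # the second element (the phone/token string) is kept here.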
+            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])
+
+        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
+
+        tokens = self.tokenizer.encode(
+            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
+        )
+        return sentence, semantic, len(tokens)
+
+    def augment(self, line):
+        speaker = line.get("speaker", None)
+
+        # 20% chance to use pure text or pure phones
+        mode = "sample"
+        if random.random() < 0.2:
+            mode = random.choice(["text", "phones"])
+
+        if speaker is None:
+            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
+            return {"text": f"[INST] {a} [/INST] {b} </s>"}
+
+        # Random sample based on speaker using a truncated normal distribution
+        a = torch.tensor([0], dtype=torch.float32)
+        torch.nn.init.trunc_normal_(
+            a,
+            mean=self.max_length // 2,
+            std=self.max_length // 4,
+            a=0,
+            b=self.max_length,
+        )
+        remaining_tokens = a.long().item() - 4
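+        # Target length is drawn around max_length/2 and clipped to [0, max_length];
+        # 4 tokens are subtracted, presumably to leave room for the [INST] / [/INST] /
+        # </s> markers added below (assumption).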
+
+        final_text, final_semantic = [], []
+
+        # Shuffle this speaker's lines (each used at most once) and pop until the budget is spent
+        idxs = list(range(len(self.speakers[speaker])))
+        random.shuffle(idxs)
+
+        while remaining_tokens > 0 and len(idxs) > 0:
+            line = self.speakers[speaker][idxs.pop()]
+            text, semantic, length = self.tokenize_sentence(
+                line["text"], line["semantic"], mode=mode
+            )
+            remaining_tokens -= length
+            final_text.append(text)
+            final_semantic.append(semantic)
+
+        final_text = " ".join(final_text)
+        final_semantic = " ".join(final_semantic)
+
+        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
+
+
 @dataclass
 class TextDataCollator:
     tokenizer: AutoTokenizer
@@ -164,8 +289,8 @@ class InterleaveDataset(IterableDataset):
 class TextDataModule(LightningDataModule):
     def __init__(
         self,
-        train_dataset: Union[TextDataset, InterleaveDataset],
-        val_dataset: Union[TextDataset, InterleaveDataset],
+        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
+        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
         batch_size: int = 32,
         tokenizer: AutoTokenizer = None,
         max_length: int = 1024,
@@ -198,26 +323,33 @@
 
 
 if __name__ == "__main__":
-    dm = TextDataModule(
-        InterleaveDataset(
-            datasets=[
-                TextDataset(
-                    prefix="en/en_part_",
-                ),
-                TextDataset(
-                    prefix="zh/zh_part_",
-                ),
-                TextDataset(
-                    prefix="ja/ja_part_",
-                ),
-            ],
-            probabilities=[0.8, 0.1, 0.1],
-        ),
-        TextDataset(
-            files="ja/ja_part_{00000..00159}",
+    import json
+
+    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
+    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
+    # with open("test.jsonl", "w") as f:
+    #     for i in all_files:
+    #         wav_file = i.with_suffix(".wav")
+    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
+    #         eta_tokens = duration * 25
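+    #         # duration assumes 16-bit mono 44.1 kHz PCM; 25 tokens/s matches the 25 Hz VQ codebook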
+    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
+    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")
+
+    ds = AutoAugTextDataset(
+        jsonl_files=["test.jsonl"],
+        order=["en"],
+        tokenizer=AutoTokenizer.from_pretrained(
+            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
         ),
+    )
+
+    dm = TextDataModule(
+        train_dataset=ds,
+        val_dataset=ds,
+        tokenizer=ds.tokenizer,
         batch_size=2,
-        tokenizer=AutoTokenizer.from_pretrained("bert-base-multilingual-cased"),
+        max_length=1024,
+        num_workers=0,
     )
 
     for batch in dm.train_dataloader():