Browse source code

Implement text auto aug

Lengyue, 2 years ago
Parent
Commit
ba1e088246

+ 2 - 2
fish_speech/configs/llama_finetune.yaml

@@ -18,12 +18,12 @@ tokenizer:
 
 # Dataset Configuration
 train_dataset:
-  - _target_: fish_speech.datasets.text.TextDataset
+  - _target_: fish_speech.datasets.text.StreamTextDataset
     repo: fishaudio/cn-hubert-25hz-vq
     prefix: 'data/train'
 
 val_dataset:
-  _target_: fish_speech.datasets.text.TextDataset
+  _target_: fish_speech.datasets.text.StreamTextDataset
   repo: fishaudio/cn-hubert-25hz-vq
   prefix: 'data/test'
 

+ 3 - 3
fish_speech/configs/llama_pretrain.yaml

@@ -28,11 +28,11 @@ tokenizer:
 dataset:
   _target_: fish_speech.datasets.text.InterleaveDataset
   datasets:
-    - _target_: fish_speech.datasets.text.TextDataset
+    - _target_: fish_speech.datasets.text.StreamTextDataset
       prefix: 'en/'
-    - _target_: fish_speech.datasets.text.TextDataset
+    - _target_: fish_speech.datasets.text.StreamTextDataset
       prefix: 'zh/'
-    - _target_: fish_speech.datasets.text.TextDataset
+    - _target_: fish_speech.datasets.text.StreamTextDataset
       prefix: 'ja/'
   probabilities: [0.4, 0.3, 0.3]
   seed: 42

+ 182 - 50
fish_speech/datasets/text.py

@@ -1,26 +1,56 @@
+import json
 import random
 from dataclasses import dataclass
 from itertools import chain
+from pathlib import Path
 from random import Random
 from typing import Optional, Union
 
 import numpy as np
 import pyarrow.parquet as pq
+import torch
 from datasets.download.streaming_download_manager import xopen
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
+from fish_speech.text import clean_text, g2p
 from fish_speech.utils import RankedLogger
 from fish_speech.utils.braceexpand import braceexpand
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
 
-class TextDataset(IterableDataset):
+def split_by_rank_worker(files):
+    """Shard *files* across DDP ranks and DataLoader workers.
+
+    The list is first repeated if it is shorter than the total number of
+    consumers, then strided first by DDP rank and then by worker id, so each
+    (rank, worker) pair iterates a disjoint interleaved slice.
+
+    NOTE(review): assumes `files` is non-empty — `len(files) == 0` would
+    divide by zero below; confirm callers guarantee this.
+    """
+    # We need to know the total number of devices
+    # to split the data properly
+
+    total_devices = 1
+    if is_initialized():
+        total_devices = get_world_size()
+
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        total_devices *= worker_info.num_workers
+
+    if len(files) < total_devices:
+        # Repeat the files N times to match the number of devices
+        files = files * (total_devices // len(files) + 1)
+
+    # DDP
+    if is_initialized():
+        files = files[get_rank() :: get_world_size()]
+
+    # Split by worker
+    if worker_info is not None:
+        files = files[worker_info.id :: worker_info.num_workers]
+
+    return files
+
+
+class StreamTextDataset(IterableDataset):
     def __init__(
         self,
         files: Optional[Union[list[str], str]] = None,
@@ -55,34 +85,8 @@ class TextDataset(IterableDataset):
         self.files = sorted(files)
         Random(seed).shuffle(self.files)
 
-    def get_data_splits(self, files):
-        # We need to know the total number of devices
-        # to split the data properly
-
-        total_devices = 1
-        if is_initialized():
-            total_devices = get_world_size()
-
-        worker_info = get_worker_info()
-        if worker_info is not None:
-            total_devices *= worker_info.num_workers
-
-        if len(files) < total_devices:
-            # Repeat the files N times to match the number of devices
-            files = files * (total_devices // len(files) + 1)
-
-        # DDP
-        if is_initialized():
-            files = files[get_rank() :: get_world_size()]
-
-        # Split by worker
-        if worker_info is not None:
-            files = files[worker_info.id :: worker_info.num_workers]
-
-        return files
-
     def __iter__(self):
-        files = self.get_data_splits(self.files)
+        files = split_by_rank_worker(self.files)
         random.shuffle(files)
 
         for filename in files:
@@ -106,6 +110,127 @@ class TextDataset(IterableDataset):
                 yield from texts
 
 
+# @dataclass
+# class DatasetLine:
+#     text: str
+#     semantic: str
+#     speaker: str
+
+
+class AutoAugTextDataset(IterableDataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+    3. Mix text and phones
+    """
+
+    def __init__(
+        self,
+        jsonl_files: list[str],
+        seed: int = 42,
+        phones_prob: float = 0.5,
+        max_length: int = 1024,
+        order: Optional[list[str]] = None,
+        tokenizer: AutoTokenizer = None,
+    ):
+        """
+        Args:
+            jsonl_files: paths to JSONL files; each line is a JSON object
+                with at least "text" and "semantic", optionally "speaker".
+            seed: seed for the one-time deterministic shuffle of all lines.
+            phones_prob: probability that a sentence is converted to phones
+                when mode == "sample" (see tokenize_sentence).
+            max_length: soft token budget used when concatenating sentences
+                of the same speaker.
+            tokenizer: used only to count tokens for the length budget.
+        """
+        # NOTE(review): `order` is forwarded verbatim to g2p — presumably a
+        # language priority list (e.g. ["en"]); confirm against g2p's API.
+        super().__init__()
+
+        self.jsonl_files = jsonl_files
+        self.seed = seed
+        self.phones_prob = phones_prob
+        self.max_length = max_length
+        self.order = order
+        self.tokenizer = tokenizer
+
+        # Read all lines, and group by speaker
+        self.speakers = {}
+        self.lines = []
+
+        for filename in self.jsonl_files:
+            lines = Path(filename).read_text().splitlines()
+            for json_line in lines:
+                line = json.loads(json_line)
+                speaker = line.get("speaker", None)
+
+                if speaker not in self.speakers:
+                    self.speakers[speaker] = []
+
+                # NOTE(review): lines without a speaker are also bucketed
+                # under the key None, but augment() never reads that bucket —
+                # it is dead memory.
+                self.lines.append(line)
+                self.speakers[speaker].append(line)
+
+        # Shuffle the lines
+        Random(seed).shuffle(self.lines)
+
+    def __iter__(self):
+        # Each (rank, worker) pair gets its own shard; order within the
+        # shard is re-randomized on every epoch via the global RNG.
+        lines = split_by_rank_worker(self.lines)
+        random.shuffle(lines)
+
+        for line in lines:
+            yield self.augment(line)
+
+    def tokenize_sentence(
+        self, sentence: str, semantic: list[int], mode: str = "sample"
+    ):
+        """Clean one sentence, optionally convert it to phones, and render
+        its semantic codes as `<semantic_i>` tokens.
+
+        mode: "sample" converts to phones with probability `phones_prob`,
+        "phones" always converts, anything else ("text") never converts.
+
+        Returns (sentence, semantic_string, token_count) where token_count
+        is measured with `self.tokenizer` (no special tokens).
+        """
+        sentence = clean_text(sentence)
+
+        if (
+            mode == "sample" and (random.random() < self.phones_prob)
+        ) or mode == "phones":
+            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])
+
+        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
+
+        tokens = self.tokenizer.encode(
+            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
+        )
+        return sentence, semantic, len(tokens)
+
+    def augment(self, line):
+        """Build one training string; for speakered lines, concatenate
+        multiple sentences of the same speaker up to a sampled token budget.
+        """
+        speaker = line.get("speaker", None)
+
+        # 20% to pure text or pure phones
+        mode = "sample"
+        if random.random() < 0.2:
+            mode = random.choice(["text", "phones"])
+
+        # Speakerless lines are emitted as a single prompt/response pair.
+        if speaker is None:
+            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
+            return {"text": f"[INST] {a} [/INST] {b} </s>"}
+
+        # Random sample based on speaker using a truncated normal distribution
+        # (mean max_length/2, truncated to [0, max_length]).
+        a = torch.tensor([0], dtype=torch.float32)
+        torch.nn.init.trunc_normal_(
+            a,
+            mean=self.max_length // 2,
+            std=self.max_length // 4,
+            a=0,
+            b=self.max_length,
+        )
+        # NOTE(review): the sampled budget can be <= 4, making
+        # remaining_tokens <= 0 and yielding an empty sample — confirm this
+        # is acceptable downstream.
+        remaining_tokens = a.long().item() - 4
+
+        final_text, final_semantic = [], []
+
+        # Shuffle unique lines
+        idxs = list(range(len(self.speakers[speaker])))
+        random.shuffle(idxs)
+
+        # Greedily pack same-speaker sentences; the last one may overshoot
+        # the budget since the length check happens after appending.
+        while remaining_tokens > 0 and len(idxs) > 0:
+            line = self.speakers[speaker][idxs.pop()]
+            text, semantic, length = self.tokenize_sentence(
+                line["text"], line["semantic"], mode=mode
+            )
+            remaining_tokens -= length
+            final_text.append(text)
+            final_semantic.append(semantic)
+
+        final_text = " ".join(final_text)
+        final_semantic = " ".join(final_semantic)
+
+        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
+
+
 @dataclass
 class TextDataCollator:
     tokenizer: AutoTokenizer
@@ -164,8 +289,8 @@ class InterleaveDataset(IterableDataset):
 class TextDataModule(LightningDataModule):
     def __init__(
         self,
-        train_dataset: Union[TextDataset, InterleaveDataset],
-        val_dataset: Union[TextDataset, InterleaveDataset],
+        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
+        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
         batch_size: int = 32,
         tokenizer: AutoTokenizer = None,
         max_length: int = 1024,
@@ -198,26 +323,33 @@ class TextDataModule(LightningDataModule):
 
 
 if __name__ == "__main__":
-    dm = TextDataModule(
-        InterleaveDataset(
-            datasets=[
-                TextDataset(
-                    prefix="en/en_part_",
-                ),
-                TextDataset(
-                    prefix="zh/zh_part_",
-                ),
-                TextDataset(
-                    prefix="ja/ja_part_",
-                ),
-            ],
-            probabilities=[0.8, 0.1, 0.1],
-        ),
-        TextDataset(
-            files="ja/ja_part_{00000..00159}",
+    import json
+
+    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
+    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
+    # with open("test.jsonl", "w") as f:
+    #     for i in all_files:
+    #         wav_file = i.with_suffix(".wav")
+    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
+    #         eta_tokens = duration * 25
+    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
+    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")
+
+    ds = AutoAugTextDataset(
+        jsonl_files=["test.jsonl"],
+        order=["en"],
+        tokenizer=AutoTokenizer.from_pretrained(
+            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
         ),
+    )
+
+    dm = TextDataModule(
+        train_dataset=ds,
+        val_dataset=ds,
+        tokenizer=ds.tokenizer,
         batch_size=2,
-        tokenizer=AutoTokenizer.from_pretrained("bert-base-multilingual-cased"),
+        max_length=1024,
+        num_workers=0,
     )
 
     for batch in dm.train_dataloader():