@@ -1,26 +1,56 @@
+import json
import random
from dataclasses import dataclass
from itertools import chain
+from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
+import torch
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from transformers import AutoTokenizer

+from fish_speech.text import clean_text, g2p
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


-class TextDataset(IterableDataset):
+def split_by_rank_worker(files):
+    # We need to know the total number of devices
+    # to split the data properly
+
+    total_devices = 1
+    if is_initialized():
+        total_devices = get_world_size()
+
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        total_devices *= worker_info.num_workers
+
+    if len(files) < total_devices:
+        # Repeat the files N times to match the number of devices
+        files = files * (total_devices // len(files) + 1)
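+        # Note: after this repetition, the same file can be visited by more
+        # than one rank/worker within a single pass over the data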
+
+    # DDP
+    if is_initialized():
+        files = files[get_rank() :: get_world_size()]
+
+    # Split by worker
+    if worker_info is not None:
+        files = files[worker_info.id :: worker_info.num_workers]
+
+    return files
+
+
+class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
@@ -55,34 +85,8 @@ class TextDataset(IterableDataset):
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

-    def get_data_splits(self, files):
-        # We need to know the total number of devices
-        # to split the data properly
-
-        total_devices = 1
-        if is_initialized():
-            total_devices = get_world_size()
-
-        worker_info = get_worker_info()
-        if worker_info is not None:
-            total_devices *= worker_info.num_workers
-
-        if len(files) < total_devices:
-            # Repeat the files N times to match the number of devices
-            files = files * (total_devices // len(files) + 1)
-
-        # DDP
-        if is_initialized():
-            files = files[get_rank() :: get_world_size()]
-
-        # Split by worker
-        if worker_info is not None:
-            files = files[worker_info.id :: worker_info.num_workers]
-
-        return files
-
    def __iter__(self):
-        files = self.get_data_splits(self.files)
+        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
@@ -106,6 +110,127 @@ class TextDataset(IterableDataset):
        yield from texts


+# @dataclass
+# class DatasetLine:
+#     text: str
+#     semantic: str
+#     speaker: str
+
+
+class AutoAugTextDataset(IterableDataset):
+    """
+    Auto-augmentation dataset, grouped by speaker
+
+    1. Randomly concatenates multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalizes the text
+    3. Mixes text and phones
+    """
+
+    def __init__(
+        self,
+        jsonl_files: list[str],
+        seed: int = 42,
+        phones_prob: float = 0.5,
+        max_length: int = 1024,
+        order: Optional[list[str]] = None,
+        tokenizer: AutoTokenizer = None,
+    ):
+        super().__init__()
+
+        self.jsonl_files = jsonl_files
+        self.seed = seed
+        self.phones_prob = phones_prob
+        self.max_length = max_length
+        self.order = order
+        self.tokenizer = tokenizer
+
+        # Read all lines and group them by speaker
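+        # Each JSONL line is expected to look roughly like
+        # {"text": "...", "speaker": "...", "semantic": [1, 2, ...]};
+        # "speaker" may be absent (treated as None)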
+        self.speakers = {}
+        self.lines = []
+
+        for filename in self.jsonl_files:
+            lines = Path(filename).read_text().splitlines()
+            for json_line in lines:
+                line = json.loads(json_line)
+                speaker = line.get("speaker", None)
+
+                if speaker not in self.speakers:
+                    self.speakers[speaker] = []
+
+                self.lines.append(line)
+                self.speakers[speaker].append(line)
+
+        # Shuffle the lines
+        Random(seed).shuffle(self.lines)
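+        # With a fixed seed this shuffle is identical on every rank/worker,
+        # so split_by_rank_worker can later hand out disjoint slices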
+
+    def __iter__(self):
+        lines = split_by_rank_worker(self.lines)
+        random.shuffle(lines)
+
+        for line in lines:
+            yield self.augment(line)
+
+    def tokenize_sentence(
+        self, sentence: str, semantic: list[int], mode: str = "sample"
+    ):
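+        # mode="sample": use phones with probability `phones_prob`;
+        # mode="phones": always use phones; mode="text": keep the cleaned text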
+        sentence = clean_text(sentence)
+
+        if (
+            mode == "sample" and (random.random() < self.phones_prob)
+        ) or mode == "phones":
+            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])
+
+        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
+
+        tokens = self.tokenizer.encode(
+            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
+        )
+        return sentence, semantic, len(tokens)
+
+    def augment(self, line):
+        speaker = line.get("speaker", None)
+
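+        # Each sample is rendered as an instruction-style prompt:
+        #   [INST] {text and/or phones} [/INST] {semantic tokens} </s>
+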
+        # With 20% probability, force pure text or pure phones
+        mode = "sample"
+        if random.random() < 0.2:
+            mode = random.choice(["text", "phones"])
+
+        if speaker is None:
+            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
+            return {"text": f"[INST] {a} [/INST] {b} </s>"}
+
+        # Sample a target token budget from a truncated normal distribution
+        a = torch.tensor([0], dtype=torch.float32)
+        torch.nn.init.trunc_normal_(
+            a,
+            mean=self.max_length // 2,
+            std=self.max_length // 4,
+            a=0,
+            b=self.max_length,
+        )
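+        # Keep a small margin below the sampled budget, presumably to leave
+        # room for the surrounding special tokens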
+        remaining_tokens = a.long().item() - 4
+
+        final_text, final_semantic = [], []
+
+        # Shuffle the indices of this speaker's lines
+        idxs = list(range(len(self.speakers[speaker])))
+        random.shuffle(idxs)
+
+        while remaining_tokens > 0 and len(idxs) > 0:
+            line = self.speakers[speaker][idxs.pop()]
+            text, semantic, length = self.tokenize_sentence(
+                line["text"], line["semantic"], mode=mode
+            )
+            remaining_tokens -= length
+            final_text.append(text)
+            final_semantic.append(semantic)
+
+        final_text = " ".join(final_text)
+        final_semantic = " ".join(final_semantic)
+
+        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
+
+
@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
@@ -164,8 +289,8 @@ class InterleaveDataset(IterableDataset):
class TextDataModule(LightningDataModule):
    def __init__(
        self,
-        train_dataset: Union[TextDataset, InterleaveDataset],
-        val_dataset: Union[TextDataset, InterleaveDataset],
+        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
+        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
@@ -198,26 +323,33 @@ class TextDataModule(LightningDataModule):

if __name__ == "__main__":
-    dm = TextDataModule(
-        InterleaveDataset(
-            datasets=[
-                TextDataset(
-                    prefix="en/en_part_",
-                ),
-                TextDataset(
-                    prefix="zh/zh_part_",
-                ),
-                TextDataset(
-                    prefix="ja/ja_part_",
-                ),
-            ],
-            probabilities=[0.8, 0.1, 0.1],
-        ),
-        TextDataset(
-            files="ja/ja_part_{00000..00159}",
+    import json
+
+    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
+    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
+    # with open("test.jsonl", "w") as f:
+    #     for i in all_files:
+    #         wav_file = i.with_suffix(".wav")
+    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
+    #         eta_tokens = duration * 25
+    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
+    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")
+
+    ds = AutoAugTextDataset(
+        jsonl_files=["test.jsonl"],
+        order=["en"],
+        tokenizer=AutoTokenizer.from_pretrained(
+            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
        ),
+    )
+
+    dm = TextDataModule(
+        train_dataset=ds,
+        val_dataset=ds,
+        tokenizer=ds.tokenizer,
        batch_size=2,
-        tokenizer=AutoTokenizer.from_pretrained("bert-base-multilingual-cased"),
+        max_length=1024,
+        num_workers=0,
    )

    for batch in dm.train_dataloader():