@@ -8,7 +8,6 @@ from random import Random
 from typing import Optional, Union

 import numpy as np
-import orjson
 import pyarrow.parquet as pq
 import torch
 import torch.nn.functional as F
@@ -16,10 +15,13 @@ from datasets.download.streaming_download_manager import xopen
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer

-from fish_speech.text import clean_text, g2p
+from fish_speech.datasets.protos.text_data_pb2 import Semantics
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.symbols import pad as pad_symbol
+from fish_speech.text.symbols import pu_symbols
 from fish_speech.utils import RankedLogger
 from fish_speech.utils.braceexpand import braceexpand

@@ -132,13 +134,6 @@ class StreamTextDataset(IterableDataset):
             yield from texts


-# @dataclass
-# class DatasetLine:
-#     text: str
-#     semantic: str
-#     speaker: str
-
-
 class AutoAugTextDataset(IterableDataset):
     """
     Auto Augment Dataset by Speaker
@@ -150,87 +145,79 @@ class AutoAugTextDataset(IterableDataset):

     def __init__(
         self,
-        jsonl_files: list[str],
+        files: list[str],
         seed: int = 42,
-        phones_prob: float = 0.5,
+        phones_prob: float = 0.3,
         max_length: int = 1024,
-        order: Optional[list[str]] = None,
         tokenizer: AutoTokenizer = None,
+        split: Optional[str] = None,
     ):
         super().__init__()

-        self.jsonl_files = jsonl_files
+        self.files = files
         self.seed = seed
         self.phones_prob = phones_prob
         self.max_length = max_length
-        self.order = order
         self.tokenizer = tokenizer

         # Read all lines, and group by speaker
         self.groups = []
-        from tqdm import tqdm
-
-        for filename in self.jsonl_files:
-            with open(filename, "r") as f:
-                for json_line in tqdm(f):
-                    if json_line.strip() == "":
-                        continue
-
-                    line = orjson.loads(json_line)
-                    # for i in line["sentences"]:
-                    #     # Save memory
-                    #     i["semantics"] = np.array(i["semantics"], dtype=np.uint16)
-                    self.groups.append(line)
+        count = 0
+        for filename in self.files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    self.groups.append(text_data)
+                    count += 1

-        import sys
+                    if count % 10000 == 0:
+                        log.info(f"Read {count} groups of text data")

-        print(sys.getsizeof(self.groups) / 1024 / 1024)
         # Shuffle the lines
-        Random(seed).shuffle(self.lines)
+        Random(seed).shuffle(self.groups)

-    def __iter__(self):
-        lines = split_by_rank_worker(self.lines)
-        random.shuffle(lines)
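+        # Hold out the last 500 groups as a fixed validation split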
+        if split == "train":
+            self.groups = self.groups[:-500]
+        elif split == "val":
+            self.groups = self.groups[-500:]

-        for line in lines:
-            yield self.augment(line)
+    def __iter__(self):
+        groups = split_by_rank_worker(self.groups)
+        random.shuffle(groups)

-    def tokenize_sentence(
-        self, sentence: str, semantic: list[int], mode: str = "sample"
-    ):
-        sentence = clean_text(sentence)
+        for group in groups:
+            x = self.augment(group)
+            if x is not None:
+                yield x

+    def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
         if (
             mode == "sample" and (random.random() < self.phones_prob)
         ) or mode == "phones":
-            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])
-
-        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
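+            # Render each phone as a dedicated <p:...> token; punctuation and the
+            # pad symbol are kept verbatim so the tokenizer sees them unchanged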
+            sentence = " ".join(
+                [
+                    (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
+                    for i in phones
+                ]
+            )

         tokens = self.tokenizer.encode(
-            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
+            f"{sentence}", max_length=10**6, add_special_tokens=False
         )
-        return sentence, semantic, len(tokens)
+        return sentence, len(tokens)

-    def augment(self, line):
-        speaker = line.get("speaker", None)
-
-        # 20% to pure text or pure phones
+    def augment(self, group):
+        # 50% to pure text or pure phones
         mode = "sample"
-        if random.random() < 0.2:
+        if random.random() < 0.5:
             mode = random.choice(["text", "phones"])

-        if speaker is None:
-            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
-            return {"text": f"[INST] {a} [/INST] {b} </s>"}
-
         # Random sample based on speaker using a truncated normal distribution
         a = torch.tensor([0], dtype=torch.float32)
         torch.nn.init.trunc_normal_(
             a,
             mean=self.max_length // 2,
             std=self.max_length // 4,
-            a=0,
+            a=10,
             b=self.max_length,
         )
         remaining_tokens = a.long().item() - 4
@@ -238,85 +225,97 @@ class AutoAugTextDataset(IterableDataset):
         final_text, final_semantic = [], []

         # Shuffle unique lines
-        idxs = list(range(len(self.speakers[speaker])))
+        idxs = list(range(len(group.sentences)))
         random.shuffle(idxs)

+        if len(idxs) == 0:
+            # Invalid group
+            return None
+
         while remaining_tokens > 0 and len(idxs) > 0:
-            line = self.speakers[speaker][idxs.pop()]
-            text, semantic, length = self.tokenize_sentence(
-                line["text"], line["semantic"], mode=mode
+            sentence = group.sentences[idxs.pop()]
+            text, length = self.tokenize_sentence(
+                sentence.text, sentence.phones, mode=mode
             )
-            remaining_tokens -= length
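+            # Budget both the text tokens and this sentence's first-codebook
+            # semantic tokens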
+            remaining_tokens -= length + len(sentence.semantics[0].values)
             final_text.append(text)
-            final_semantic.append(semantic)
+            final_semantic.append(sentence.semantics)
+
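+        # Join the sampled sentences with <pad> separators and wrap them in an
+        # instruction-style prompt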
+        final_text = "[INST] " + "<pad>".join(final_text) + " [/INST]"
+        encoded = self.tokenizer.encode(
+            final_text, max_length=self.max_length, add_special_tokens=False
+        )
+        semantic_length = sum([len(i[0].values) for i in final_semantic])
+
+        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+        tokens = (
+            [self.tokenizer.bos_token_id]
+            + encoded
+            + [self.tokenizer.pad_token_id] * semantic_length
+            + [self.tokenizer.eos_token_id]
+        )
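+        # One row per codebook: 0 fills the text-prompt region, raw codes are
+        # shifted by +2 (keeping 0 and 1 reserved), and a trailing 1 marks the
+        # end, mirroring the <s>/</s> packing of the text row above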
+        codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
+        for segment in final_semantic:
+            for book_idx, book in enumerate(segment):
+                for j in book.values:
+                    codes[book_idx].append(int(j) + 2)
+
+        for book in codes:
+            book.append(1)

-        final_text = " ".join(final_text)
-        final_semantic = " ".join(final_semantic)
+        tokens = [tokens] + codes
+        tokens = torch.tensor(tokens, dtype=torch.long)

-        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
+        labels = tokens.clone()
+        labels[1:, : len(encoded) + 1] = -100  # Mask out the <s> tokens for semantic
+
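+        # Shift by one position for next-token prediction: inputs drop the last
+        # column, labels drop the first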
+        return {
+            "tokens": tokens[:, :-1],
+            "labels": labels[:, 1:],
+        }


 @dataclass
 class TextDataCollator:
     tokenizer: AutoTokenizer
-    max_length: int = 512
+    max_length: int = 1024

     def __call__(self, examples):
-        texts = [i["text"] for i in examples]
-
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        encoded_texts = self.tokenizer(
-            texts,
-            truncation=True,
-            padding=True,
-            max_length=self.max_length,
-            return_tensors="pt",
-            pad_to_multiple_of=8,
-        )
-
-        semantic = [i["semantic"] for i in examples]
-        max_semantic_length = max([len(i[0]) for i in semantic])
-
-        # Make xformers happy
-        if (max_semantic_length - 1) % 8 != 0:
-            max_semantic_length += 8 - (max_semantic_length - 1) % 8
-
-        if max_semantic_length > self.max_length + 1:
-            max_semantic_length = self.max_length + 1
-
-        codes, codes_mask = [], []
-
-        for i in semantic:
-            t = torch.tensor(i)
-            if t.shape[-1] >= max_semantic_length:
-                t = t[..., :max_semantic_length]
-
-            codes.append(
-                F.pad(
-                    t,
-                    (0, max_semantic_length - t.shape[-1]),
-                    value=1,
+        tokens, attention_masks, labels = [], [], []
+        for example in examples:
+            _tokens = example["tokens"][:, : self.max_length]
+            _labels = example["labels"][:, : self.max_length]
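+            # Attention mask is True over padding positions, False over real tokens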
+            _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
+            _attention_mask[: _tokens.size(1)] = False
+
+            assert _tokens.size(1) == _labels.size(
+                1
+            ), f"{_tokens.size(1)} != {_labels.size(1)}"
+
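+            # Pad inputs with eos and labels with -100 so padded positions are
+            # ignored by the loss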
+            if _tokens.size(1) < self.max_length:
+                _tokens = F.pad(
+                    _tokens,
+                    (0, self.max_length - _tokens.size(1)),
+                    value=self.tokenizer.eos_token_id,
+                )
+                _labels = F.pad(
+                    _labels, (0, self.max_length - _labels.size(1)), value=-100
                 )
-            )

-            mask = torch.zeros(max_semantic_length, dtype=torch.long)
-            mask[t.shape[-1] :] = 1
-            codes_mask.append(mask.bool())
+            tokens.append(_tokens)
+            attention_masks.append(_attention_mask)
+            labels.append(_labels)

-        codes = torch.stack(codes)
-        codes_mask = torch.stack(codes_mask)
+        tokens = torch.stack(tokens, dim=0)
+        attention_masks = torch.stack(attention_masks, dim=0)
+        labels = torch.stack(labels, dim=0)

-        data = {
-            "inputs": encoded_texts["input_ids"],
-            "input_mask": encoded_texts["attention_mask"] == 0,
-            "codes": codes,
-            "codes_mask": codes_mask,
+        return {
+            "inputs": tokens,
+            "attention_masks": attention_masks,
+            "labels": labels,
         }

-        return data
-

 class InterleaveDataset(IterableDataset):
     def __init__(
@@ -398,22 +397,18 @@ if __name__ == "__main__":
     # f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")

     ds = AutoAugTextDataset(
-        jsonl_files=["data/quantized-dataset-1205.json"],
-        order=["en"],
-        tokenizer=AutoTokenizer.from_pretrained(
-            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
-        ),
+        files=["data/quantized-dataset-1205.protos"],
+        tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
     )

     dm = TextDataModule(
         train_dataset=ds,
         val_dataset=ds,
         tokenizer=ds.tokenizer,
-        batch_size=2,
+        batch_size=16,
         max_length=1024,
         num_workers=0,
     )

     for batch in dm.train_dataloader():
         print(batch)
-        break