text.py

import json
import random
import re
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.datasets.protos.text_data_pb2 import Semantics
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.symbols import pad as pad_symbol
from fish_speech.text.symbols import pu_symbols
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
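
# Illustrative example (not in the original file): with 2 DDP ranks and
# 2 dataloader workers per rank, total_devices == 4, so rank 0 / worker 1
# receives files[0::2][1::2], i.e. every 4th file starting at index 2.
# If there are fewer files than rank/worker slots, the list is repeated first.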

class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            expression = re.compile(r"\[INST\] (.*) \[/INST\] (.*) </s>")
            match = expression.match(text)
            if match is None:
                continue

            text = match.group(1)
            semantic = match.group(2)

            # Convert semantic to ids
            expression = re.compile(r"<semantic_(\d+)>")
            # 0 and 1 are reserved for <s> and </s>
            semantic = [0] + [int(i) + 2 for i in expression.findall(semantic)] + [1]

            yield {"text": text, "semantic": [semantic]}

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts
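
# Illustrative example (not in the original file): parse_data turns a row like
#   "[INST] hello [/INST] <semantic_5><semantic_7> </s>"
# into {"text": "hello", "semantic": [[0, 7, 9, 1]]} -- semantic ids are shifted
# by 2 because 0 and 1 are reserved for <s> and </s>.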

class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones
    """

    def __init__(
        self,
        files: list[str],
        seed: int = 42,
        phones_prob: float = 0.3,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        split: Optional[str] = None,
    ):
        super().__init__()

        self.files = files
        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.tokenizer = tokenizer

        # Read all lines, and group by speaker
        self.groups = []
        count = 0
        for filename in self.files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

                    if count % 10000 == 0:
                        log.info(f"Read {count} groups of text data")

        # Shuffle the lines
        Random(seed).shuffle(self.groups)

        if split == "train":
            self.groups = self.groups[:-500]
        elif split == "val":
            self.groups = self.groups[-500:]

    def __iter__(self):
        groups = split_by_rank_worker(self.groups)
        random.shuffle(groups)

        for group in groups:
            x = self.augment(group)
            if x is not None:
                yield x

    def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
        if (
            mode == "sample" and (random.random() < self.phones_prob)
        ) or mode == "phones":
            sentence = " ".join(
                [
                    (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
                    for i in phones
                ]
            )

        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def augment(self, group):
        # 50% to pure text or pure phones
        # mode = "sample"
        # if random.random() < 0.5:
        #     mode = random.choice(["text", "phones"])
        mode = "phones"

        # Random sample based on speaker using a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []

        # Shuffle unique lines
        idxs = list(range(len(group.sentences)))
        random.shuffle(idxs)

        if len(idxs) == 0:
            # Invalid group
            return None

        while remaining_tokens > 0 and len(idxs) > 0:
            sentence = group.sentences[idxs.pop()]

            text, length = self.tokenize_sentence(
                sentence.text, sentence.phones, mode=mode
            )
            remaining_tokens -= length + len(sentence.semantics[0].values)

            final_text.append(text)
            final_semantic.append(sentence.semantics)

        final_text = "[INST] " + "<pad>".join(final_text) + " [/INST]"
        encoded = self.tokenizer.encode(
            final_text,
            max_length=self.max_length,
            add_special_tokens=False,
            truncation=False,
        )
        semantic_length = sum([len(i[0].values) for i in final_semantic])

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            [self.tokenizer.bos_token_id]
            + encoded
            + [self.tokenizer.pad_token_id] * semantic_length
            + [self.tokenizer.eos_token_id]
        )

        codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
        for segment in final_semantic:
            for book_idx, book in enumerate(segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 2)

        for book in codes:
            book.append(1)

        tokens = [tokens] + codes
        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()
        labels[1:, : len(encoded) + 1] = -100  # Mask out the <s> tokens for semantic

        return {
            "tokens": tokens[:, :-1],
            "labels": labels[:, 1:],
        }
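
# Illustrative layout of one packed example produced by augment() (sketch only,
# not in the original file):
#   row 0 (text stream):    <s>  t1 ... tn  <pad> ... <pad>  </s>
#   rows 1..k (codebooks):   0    0 ...  0   c1+2 ...  cm+2    1
# labels is the same tensor with the <s> + text region of the codebook rows set
# to -100, and the returned pair is shifted by one position for next-token
# prediction.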

@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        tokens, attention_masks, labels = [], [], []

        for example in examples:
            _tokens = example["tokens"][:, : self.max_length]
            _labels = example["labels"][:, : self.max_length]
            _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
            _attention_mask[: _tokens.size(1)] = False

            assert _tokens.size(1) == _labels.size(
                1
            ), f"{_tokens.size(1)} != {_labels.size(1)}"

            if _tokens.size(1) < self.max_length:
                _tokens = F.pad(
                    _tokens,
                    (0, self.max_length - _tokens.size(1)),
                    value=self.tokenizer.eos_token_id,
                )
                _labels = F.pad(
                    _labels, (0, self.max_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
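
# Illustrative shapes (not in the original file): for a batch of B packed
# examples and max_length=1024, the collator returns "inputs" and "labels" of
# shape [B, 1 + num_codebooks, 1024] and "attention_masks" of shape [B, 1024],
# where True marks padded positions.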

class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly choose one dataset
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
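
# Illustrative usage (not in the original file):
#   InterleaveDataset([ds_a, ds_b], probabilities=[0.9, 0.1])
# draws roughly 90% of samples from ds_a and 10% from ds_b; an exhausted child
# iterator is recreated on the fly, so the combined stream never terminates.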

class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

if __name__ == "__main__":
    import json

    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
    # with open("test.jsonl", "w") as f:
    #     for i in all_files:
    #         wav_file = i.with_suffix(".wav")
    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
    #         eta_tokens = duration * 25
    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")

    ds = AutoAugTextDataset(
        files=["data/quantized-dataset-1205.protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
    )

    dm = TextDataModule(
        train_dataset=ds,
        val_dataset=ds,
        tokenizer=ds.tokenizer,
        batch_size=16,
        max_length=1024,
        num_workers=0,
    )

    for batch in dm.train_dataloader():
        print(batch)