text.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. import json
  2. import random
  3. import re
  4. from dataclasses import dataclass
  5. from itertools import chain
  6. from pathlib import Path
  7. from random import Random
  8. from typing import Optional, Union
  9. import grpc
  10. import numpy as np
  11. import pyarrow.parquet as pq
  12. import torch
  13. import torch.nn.functional as F
  14. from datasets.download.streaming_download_manager import xopen
  15. from huggingface_hub import HfApi
  16. from lightning import LightningDataModule
  17. from torch.distributed import get_rank, get_world_size, is_initialized
  18. from torch.utils.data import DataLoader, IterableDataset, get_worker_info
  19. from transformers import AutoTokenizer
  20. from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
  21. from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
  22. from fish_speech.text.parser import clean_text
  23. from fish_speech.text.symbols import pad as pad_symbol
  24. from fish_speech.text.symbols import pu_symbols
  25. from fish_speech.utils import RankedLogger
  26. from fish_speech.utils.braceexpand import braceexpand
  27. log = RankedLogger(__name__, rank_zero_only=True)
  28. def split_by_rank_worker(files):
  29. # We need to know the total number of devices
  30. # to split the data properly
  31. total_devices = 1
  32. if is_initialized():
  33. total_devices = get_world_size()
  34. worker_info = get_worker_info()
  35. if worker_info is not None:
  36. total_devices *= worker_info.num_workers
  37. if len(files) < total_devices:
  38. # Repeat the files N times to match the number of devices
  39. files = files * (total_devices // len(files) + 1)
  40. # DDP
  41. if is_initialized():
  42. files = files[get_rank() :: get_world_size()]
  43. # Split by worker
  44. if worker_info is not None:
  45. files = files[worker_info.id :: worker_info.num_workers]
  46. return files
  47. class StreamTextDataset(IterableDataset):
  48. def __init__(
  49. self,
  50. files: Optional[Union[list[str], str]] = None,
  51. prefix: Optional[str] = None,
  52. seed: int = 42,
  53. parquet_batch_size: int = 10000,
  54. repo: str = "uonlp/CulturaX",
  55. ):
  56. super().__init__()
  57. self.seed = seed
  58. self.parquet_batch_size = parquet_batch_size
  59. self.repo = repo
  60. if files is None and prefix is None:
  61. raise ValueError("Either files or prefix must be specified")
  62. if prefix is not None:
  63. files = HfApi().list_repo_files(repo, repo_type="dataset")
  64. files = [
  65. f for f in files if f.startswith(prefix) and f.endswith(".parquet")
  66. ]
  67. log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
  68. else:
  69. if isinstance(files, str):
  70. files = [files]
  71. files = list(chain.from_iterable(map(braceexpand, files)))
  72. log.info(f"Expanded {len(files)} files in {repo}")
  73. # Get sharded files
  74. self.files = sorted(files)
  75. Random(seed).shuffle(self.files)
  76. def __iter__(self):
  77. files = split_by_rank_worker(self.files)
  78. random.shuffle(files)
  79. for filename in files:
  80. try:
  81. yield from self.parse_data(filename)
  82. except Exception as e:
  83. log.exception(f"Failed to parse {filename}: {e}")
  84. def parse_data(self, filename: str):
  85. for data in self.parse_data_internal(filename):
  86. text = data["text"]
  87. expression = re.compile(r"\[INST\] (.*) \[/INST\] (.*) </s>")
  88. match = expression.match(text)
  89. if match is None:
  90. continue
  91. text = match.group(1)
  92. semantic = match.group(2)
  93. # Convert semantic to ids
  94. expression = re.compile(r"<semantic_(\d+)>")
  95. # 0 and 1 are reserved for <s> and </s>
  96. semantic = [0] + [int(i) + 2 for i in expression.findall(semantic)] + [1]
  97. yield {"text": text, "semantic": [semantic]}
  98. def parse_data_internal(self, filename: str):
  99. url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
  100. with xopen(url, mode="rb") as stream:
  101. parquet_file = pq.ParquetFile(stream)
  102. for batch in parquet_file.iter_batches(
  103. batch_size=self.parquet_batch_size, columns=["text"]
  104. ):
  105. # In-batch shuffling
  106. texts = [{"text": text.as_py()} for text in batch["text"]]
  107. random.shuffle(texts)
  108. yield from texts
  109. class AutoAugTextDataset(IterableDataset):
  110. """
  111. Auto Augment Dataset by Speaker
  112. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  113. 2. Automatically normalize the text
  114. 3. Mix text and phones
  115. """
  116. def __init__(
  117. self,
  118. server: str = "localhost:50051",
  119. seed: int = 42,
  120. phones_prob: float = 0.3,
  121. repetition_prob: float = 0.0,
  122. max_length: int = 1024,
  123. tokenizer: AutoTokenizer = None,
  124. ):
  125. """
  126. Args:
  127. server: gRPC server address
  128. seed: random seed
  129. phones_prob: probability to use phones
  130. repetition_prob: probability to repeat the same sentence
  131. max_length: max length of the text
  132. tokenizer: tokenizer
  133. """
  134. super().__init__()
  135. self.seed = seed
  136. self.phones_prob = phones_prob
  137. self.max_length = max_length
  138. self.tokenizer = tokenizer
  139. self.repetition_prob = repetition_prob
  140. # Read all lines, and group by speaker
  141. self.channel = grpc.insecure_channel(server)
  142. self.stub = DataServiceStub(self.channel)
  143. def __iter__(self):
  144. while True:
  145. yield self.augment()
  146. def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
  147. if (
  148. mode == "sample" and (random.random() < self.phones_prob)
  149. ) or mode == "phones":
  150. sentence = " ".join(
  151. [
  152. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  153. for i in phones
  154. ]
  155. )
  156. else:
  157. sentence = clean_text(sentence)
  158. tokens = self.tokenizer.encode(
  159. f"{sentence}",
  160. max_length=10**6,
  161. add_special_tokens=False,
  162. truncation=False,
  163. )
  164. return sentence, len(tokens)
  165. def augment(self):
  166. # 50% to pure text or pure phones
  167. mode = "sample"
  168. if random.random() < 0.5:
  169. mode = random.choice(["text", "phones"])
  170. # Random sample based on speaker using a truncated normal distribution
  171. a = torch.tensor([0], dtype=torch.float32)
  172. torch.nn.init.trunc_normal_(
  173. a,
  174. mean=self.max_length // 2,
  175. std=self.max_length // 4,
  176. a=10,
  177. b=self.max_length,
  178. )
  179. remaining_tokens = a.long().item() - 4
  180. final_text, final_semantic = [], []
  181. # Shuffle unique lines, estimate that each sample is at least 20 tokens
  182. request = SampleDataRequest(num_samples=self.max_length // 20)
  183. response = self.stub.SampleData(request)
  184. if len(response.samples) == 0:
  185. # Invalid group
  186. return None
  187. samples = list(response.samples)
  188. while remaining_tokens > 0 and len(samples) > 0:
  189. if random.random() < self.repetition_prob:
  190. # Repeat the same sentence
  191. sentence = samples[-1]
  192. else:
  193. sentence = samples.pop()
  194. text, length = self.tokenize_sentence(
  195. sentence.text, sentence.phones, mode=mode
  196. )
  197. remaining_tokens -= length + len(sentence.semantics[0].values)
  198. final_text.append(text)
  199. final_semantic.append(sentence.semantics)
  200. final_text = "[INST] " + " ".join(final_text) + " [/INST]"
  201. encoded = self.tokenizer.encode(
  202. final_text,
  203. add_special_tokens=False,
  204. truncation=False,
  205. max_length=10**6,
  206. )
  207. semantic_length = sum([len(i[0].values) for i in final_semantic])
  208. # Single codebook
  209. if len(final_semantic[0]) == 1:
  210. semantic_tokens = [f"<s:{j}>" for i in final_semantic for j in i[0].values]
  211. tokenized = self.tokenizer.encode(
  212. f" ".join(semantic_tokens),
  213. add_special_tokens=False,
  214. truncation=False,
  215. max_length=10**6,
  216. )
  217. # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
  218. tokens = (
  219. [self.tokenizer.bos_token_id]
  220. + encoded
  221. + tokenized
  222. + [self.tokenizer.eos_token_id]
  223. )
  224. tokens = torch.tensor([tokens], dtype=torch.long)
  225. labels = tokens.clone()
  226. labels[0, : len(encoded) + 1] = -100 # Mask out the <s> and query tokens
  227. else:
  228. # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
  229. tokens = (
  230. [self.tokenizer.bos_token_id]
  231. + encoded
  232. + [self.tokenizer.pad_token_id] * semantic_length
  233. + [self.tokenizer.eos_token_id]
  234. )
  235. codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
  236. for segment in final_semantic:
  237. for book_idx, book in enumerate(segment):
  238. for j in book.values:
  239. codes[book_idx].append(int(j) + 2)
  240. for book in codes:
  241. book.append(1)
  242. tokens = [tokens] + codes
  243. tokens = torch.tensor(tokens, dtype=torch.long)
  244. labels = tokens.clone()
  245. labels[
  246. 1:, : len(encoded) + 1
  247. ] = -100 # Mask out the <s> tokens for semantic
  248. return {
  249. "tokens": tokens[:, :-1],
  250. "labels": labels[:, 1:],
  251. }
  252. @dataclass
  253. class TextDataCollator:
  254. tokenizer: AutoTokenizer
  255. max_length: int = 1024
  256. def __call__(self, examples):
  257. tokens, attention_masks, labels = [], [], []
  258. for example in examples:
  259. _tokens = example["tokens"][:, : self.max_length]
  260. _labels = example["labels"][:, : self.max_length]
  261. _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
  262. _attention_mask[: _tokens.size(1)] = False
  263. assert _tokens.size(1) == _labels.size(
  264. 1
  265. ), f"{_tokens.size(1)} != {_labels.size(1)}"
  266. if _tokens.size(1) < self.max_length:
  267. _tokens = F.pad(
  268. _tokens,
  269. (0, self.max_length - _tokens.size(1)),
  270. value=self.tokenizer.eos_token_id,
  271. )
  272. _labels = F.pad(
  273. _labels, (0, self.max_length - _labels.size(1)), value=-100
  274. )
  275. tokens.append(_tokens)
  276. attention_masks.append(_attention_mask)
  277. labels.append(_labels)
  278. tokens = torch.stack(tokens, dim=0)
  279. attention_masks = torch.stack(attention_masks, dim=0)
  280. labels = torch.stack(labels, dim=0)
  281. return {
  282. "inputs": tokens,
  283. "attention_masks": attention_masks,
  284. "labels": labels,
  285. }
  286. class InterleaveDataset(IterableDataset):
  287. def __init__(
  288. self,
  289. datasets: list[IterableDataset],
  290. probabilities: list[float],
  291. seed: int = 42,
  292. ):
  293. super().__init__()
  294. self.datasets = datasets
  295. self.probabilities = probabilities
  296. self.seed = seed
  297. def __iter__(self):
  298. rng = np.random.default_rng(self.seed)
  299. dataset_iterators = [iter(dataset) for dataset in self.datasets]
  300. while True:
  301. # Random choice one
  302. dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
  303. dataset_iterator = dataset_iterators[dataset_idx]
  304. try:
  305. yield next(dataset_iterator)
  306. except StopIteration:
  307. # Exhausted, create a new iterator
  308. dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
  309. yield next(dataset_iterators[dataset_idx])
  310. class TextDataModule(LightningDataModule):
  311. def __init__(
  312. self,
  313. train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  314. val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  315. batch_size: int = 32,
  316. tokenizer: AutoTokenizer = None,
  317. max_length: int = 1024,
  318. num_workers: int = 4,
  319. ):
  320. super().__init__()
  321. self.train_dataset = train_dataset
  322. self.val_dataset = val_dataset
  323. self.batch_size = batch_size
  324. self.tokenizer = tokenizer
  325. self.max_length = max_length
  326. self.num_workers = num_workers
  327. def train_dataloader(self):
  328. return DataLoader(
  329. self.train_dataset,
  330. batch_size=self.batch_size,
  331. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  332. num_workers=self.num_workers,
  333. )
  334. def val_dataloader(self):
  335. return DataLoader(
  336. self.val_dataset,
  337. batch_size=self.batch_size,
  338. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  339. num_workers=self.num_workers,
  340. )
  341. if __name__ == "__main__":
  342. import json
  343. from tqdm import tqdm
  344. ds = AutoAugTextDataset(
  345. tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
  346. )
  347. dm = TextDataModule(
  348. train_dataset=ds,
  349. val_dataset=ds,
  350. tokenizer=ds.tokenizer,
  351. batch_size=16,
  352. max_length=1024,
  353. num_workers=0,
  354. )
  355. for batch in tqdm(dm.train_dataloader()):
  356. pass