text.py

import json
import random
import re
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.text import clean_text, g2p
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
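

# A hedged sketch of the resulting shard layout (assuming 2 DDP ranks with
# 2 dataloader workers each, so total_devices == 4) for
# files = [f0, f1, f2, f3, f4, f5]:
#   rank 0 keeps [f0, f2, f4]; its worker 0 reads [f0, f4], worker 1 reads [f2]
#   rank 1 keeps [f1, f3, f5]; its worker 0 reads [f1, f5], worker 1 reads [f3]
# so every (rank, worker) pair sees a disjoint slice of the file list.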


class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            # Expand brace patterns (e.g. "part-{000..099}.parquet") into paths
            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            # Match the "[INST] prompt [/INST] response </s>" template
            expression = re.compile(r"\[INST\] (.*) \[/INST\] (.*) </s>")
            match = expression.match(text)
            if match is None:
                continue

            text = match.group(1)
            semantic = match.group(2)

            # Convert semantic to ids
            expression = re.compile(r"<semantic_(\d+)>")
            # 0 and 1 are reserved for <s> and </s>
            semantic = [0] + [int(i) + 2 for i in expression.findall(semantic)] + [1]

            yield {"text": text, "semantic": [semantic]}
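
    # For example (hypothetical ids), the row
    #   "[INST] Hello world [/INST] <semantic_5><semantic_9> </s>"
    # yields {"text": "Hello world", "semantic": [[0, 7, 11, 1]]}:
    # each id is shifted by 2 to make room for the reserved <s>/</s> ids.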

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts
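

# A minimal usage sketch (the prefix is an assumption about the repo layout;
# only rows already in the "[INST] ... [/INST] ... </s>" template are yielded):
#   ds = StreamTextDataset(prefix="data/", repo="uonlp/CulturaX")
#   sample = next(iter(ds))  # -> {"text": ..., "semantic": [[0, ..., 1]]}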


# @dataclass
# class DatasetLine:
#     text: str
#     semantic: str
#     speaker: str


class AutoAugTextDataset(IterableDataset):
    """
    Auto-augmented dataset, grouped by speaker.

    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones
    """

    def __init__(
        self,
        jsonl_files: list[str],
        seed: int = 42,
        phones_prob: float = 0.5,
        max_length: int = 1024,
        order: Optional[list[str]] = None,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.jsonl_files = jsonl_files
        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.order = order
        self.tokenizer = tokenizer

        # Read all lines and group them by speaker
        self.speakers = {}
        self.lines = []

        for filename in self.jsonl_files:
            lines = Path(filename).read_text().splitlines()
            for json_line in lines:
                line = json.loads(json_line)
                speaker = line.get("speaker", None)

                if speaker not in self.speakers:
                    self.speakers[speaker] = []

                self.lines.append(line)
                self.speakers[speaker].append(line)

        # Shuffle the lines
        Random(seed).shuffle(self.lines)

    def __iter__(self):
        lines = split_by_rank_worker(self.lines)
        random.shuffle(lines)

        for line in lines:
            yield self.augment(line)

    def tokenize_sentence(
        self, sentence: str, semantic: list[int], mode: str = "sample"
    ):
        sentence = clean_text(sentence)

        if (
            mode == "sample" and (random.random() < self.phones_prob)
        ) or mode == "phones":
            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])

        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
        tokens = self.tokenizer.encode(
            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
        )

        return sentence, semantic, len(tokens)
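
    # For instance (assuming g2p maps "hi" to the phone sequence HH AY),
    # tokenize_sentence("hi", [3, 4], mode="phones") would return roughly
    # ("HH AY", "<semantic_3> <semantic_4>", <token count>); in "sample"
    # mode the phone conversion happens with probability phones_prob.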

    def augment(self, line):
        speaker = line.get("speaker", None)

        # 20% chance to use pure text or pure phones instead of sampling
        mode = "sample"
        if random.random() < 0.2:
            mode = random.choice(["text", "phones"])

        if speaker is None:
            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
            return {"text": f"[INST] {a} [/INST] {b} </s>"}

        # Sample a target length from a truncated normal distribution
        # over [0, max_length]
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=0,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []

        # Shuffle unique lines
        idxs = list(range(len(self.speakers[speaker])))
        random.shuffle(idxs)

        # Concatenate sentences from the same speaker until the budget is spent
        while remaining_tokens > 0 and len(idxs) > 0:
            line = self.speakers[speaker][idxs.pop()]
            text, semantic, length = self.tokenize_sentence(
                line["text"], line["semantic"], mode=mode
            )
            remaining_tokens -= length

            final_text.append(text)
            final_semantic.append(semantic)

        final_text = " ".join(final_text)
        final_semantic = " ".join(final_semantic)

        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
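
    # Note: the loop may overshoot by one sentence, since a sentence is
    # appended even when its length crosses remaining_tokens; the collator
    # truncates to max_length later. The "- 4" reserve above presumably
    # accounts for the [INST] / [/INST] / </s> template tokens (assumption).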


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 512

    def __call__(self, examples):
        texts = [i["text"] for i in examples]

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        encoded_texts = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
            pad_to_multiple_of=8,
        )

        semantic = [i["semantic"] for i in examples]
        max_semantic_length = max([len(i[0]) for i in semantic])

        # Make xformers happy: align (length - 1) to a multiple of 8
        if (max_semantic_length - 1) % 8 != 0:
            max_semantic_length += 8 - (max_semantic_length - 1) % 8

        if max_semantic_length > self.max_length + 1:
            max_semantic_length = self.max_length + 1

        codes, codes_mask = [], []
        for i in semantic:
            t = torch.tensor(i)
            if t.shape[-1] >= max_semantic_length:
                t = t[..., :max_semantic_length]

            # Pad the tail with 1 (the reserved </s> id)
            codes.append(
                F.pad(
                    t,
                    (0, max_semantic_length - t.shape[-1]),
                    value=1,
                )
            )

            mask = torch.zeros(max_semantic_length, dtype=torch.long)
            mask[t.shape[-1] :] = 1
            codes_mask.append(mask.bool())

        codes = torch.stack(codes)
        codes_mask = torch.stack(codes_mask)

        data = {
            "inputs": encoded_texts["input_ids"],
            "input_mask": encoded_texts["attention_mask"] == 0,
            "codes": codes,
            "codes_mask": codes_mask,
        }

        return data
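
    # Rough shape of a collated batch (hypothetical batch of 2):
    #   inputs:     (2, L_text) token ids, L_text padded to a multiple of 8
    #   input_mask: (2, L_text) bool, True where the tokenizer padded
    #   codes:      (2, 1, L_sem) semantic ids, tail padded with </s> (id 1)
    #   codes_mask: (2, L_sem) bool, True over the padded tail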


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly choose one dataset according to the given probabilities
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
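

# Example mix (the weights are hypothetical): draw ~90% of samples from a
# web-text stream and ~10% from the speaker-augmented dataset, restarting
# whichever stream runs dry, so iteration never terminates on its own:
#   ds = InterleaveDataset([stream_ds, aug_ds], probabilities=[0.9, 0.1])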


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    import json

    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
    # with open("test.jsonl", "w") as f:
    #     for i in all_files:
    #         wav_file = i.with_suffix(".wav")
    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
    #         eta_tokens = duration * 25
    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")

    ds = AutoAugTextDataset(
        jsonl_files=["test.jsonl"],
        order=["en"],
        tokenizer=AutoTokenizer.from_pretrained(
            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
        ),
    )

    dm = TextDataModule(
        train_dataset=ds,
        val_dataset=ds,
        tokenizer=ds.tokenizer,
        batch_size=2,
        max_length=1024,
        num_workers=0,
    )

    for batch in dm.train_dataloader():
        print(batch)
        break