import json
import random
import re
import sys
from collections import defaultdict
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import orjson
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.text import clean_text, g2p
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)

def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
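
# Example of the resulting sharding: with 2 DDP ranks and 2 dataloader workers per
# rank, total_devices == 4. Given files [f0, ..., f7], rank 0 first keeps
# [f0, f2, f4, f6], and its worker 1 then keeps every second of those, i.e.
# [f2, f6], so each rank/worker pair reads a disjoint slice of the file list.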

class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            expression = re.compile(r"\[INST\] (.*) \[/INST\] (.*) </s>")
            match = expression.match(text)
            if match is None:
                continue

            text = match.group(1)
            semantic = match.group(2)

            # Convert semantic to ids
            expression = re.compile(r"<semantic_(\d+)>")
            # 0 and 1 are reserved for <s> and </s>
            semantic = [0] + [int(i) + 2 for i in expression.findall(semantic)] + [1]

            yield {"text": text, "semantic": [semantic]}

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts
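
# For reference, parse_data above accepts raw rows shaped like
#   "[INST] hello world [/INST] <semantic_5><semantic_9> </s>"
# and yields {"text": "hello world", "semantic": [[0, 7, 11, 1]]}: every
# <semantic_i> becomes i + 2 because ids 0 and 1 are reserved for <s> and </s>.
# (The sentence itself is an illustrative placeholder.)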

# @dataclass
# class DatasetLine:
#     text: str
#     semantic: str
#     speaker: str

class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Randomly concatenate multiple sentences from the same speaker into a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones
    """

    def __init__(
        self,
        jsonl_files: list[str],
        seed: int = 42,
        phones_prob: float = 0.5,
        max_length: int = 1024,
        order: Optional[list[str]] = None,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.jsonl_files = jsonl_files
        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.order = order
        self.tokenizer = tokenizer

        # Read all lines and group them by speaker
        self.lines = []
        self.speakers = defaultdict(list)

        for filename in self.jsonl_files:
            with open(filename, "r") as f:
                for json_line in tqdm(f):
                    if json_line.strip() == "":
                        continue

                    line = orjson.loads(json_line)
                    # for i in line["sentences"]:
                    #     # Save memory
                    #     i["semantics"] = np.array(i["semantics"], dtype=np.uint16)

                    self.lines.append(line)
                    speaker = line.get("speaker", None)
                    if speaker is not None:
                        self.speakers[speaker].append(line)

        # Rough size of the loaded line index, in MiB
        log.info(
            f"Loaded {len(self.lines)} lines "
            f"(~{sys.getsizeof(self.lines) / 1024 / 1024:.2f} MiB of list overhead)"
        )

        # Shuffle the lines
        Random(seed).shuffle(self.lines)

    def __iter__(self):
        lines = split_by_rank_worker(self.lines)
        random.shuffle(lines)

        for line in lines:
            yield self.augment(line)

    def tokenize_sentence(
        self, sentence: str, semantic: list[int], mode: str = "sample"
    ):
        sentence = clean_text(sentence)

        if (
            mode == "sample" and (random.random() < self.phones_prob)
        ) or mode == "phones":
            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])

        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
        tokens = self.tokenizer.encode(
            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
        )

        return sentence, semantic, len(tokens)

    def augment(self, line):
        speaker = line.get("speaker", None)

        # 20% chance of using pure text or pure phones
        mode = "sample"
        if random.random() < 0.2:
            mode = random.choice(["text", "phones"])

        if speaker is None:
            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
            return {"text": f"[INST] {a} [/INST] {b} </s>"}

        # Sample a token budget from a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=0,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []

        # Shuffle this speaker's lines and concatenate until the budget is spent
        idxs = list(range(len(self.speakers[speaker])))
        random.shuffle(idxs)

        while remaining_tokens > 0 and len(idxs) > 0:
            line = self.speakers[speaker][idxs.pop()]
            text, semantic, length = self.tokenize_sentence(
                line["text"], line["semantic"], mode=mode
            )
            remaining_tokens -= length

            final_text.append(text)
            final_semantic.append(semantic)

        final_text = " ".join(final_text)
        final_semantic = " ".join(final_semantic)

        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
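
# Each JSONL line consumed by AutoAugTextDataset is expected to carry at least the
# fields read above ("text", "semantic", and optionally "speaker"); an illustrative,
# made-up record would look like:
#   {"text": "Hello there.", "speaker": "Aabid", "semantic": [12, 857, 1033]}
# where "semantic" holds the quantized semantic token ids.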

@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 512

    def __call__(self, examples):
        texts = [i["text"] for i in examples]

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        encoded_texts = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
            pad_to_multiple_of=8,
        )

        semantic = [i["semantic"] for i in examples]
        max_semantic_length = max([len(i[0]) for i in semantic])

        # Make xformers happy
        if (max_semantic_length - 1) % 8 != 0:
            max_semantic_length += 8 - (max_semantic_length - 1) % 8

        if max_semantic_length > self.max_length + 1:
            max_semantic_length = self.max_length + 1

        codes, codes_mask = [], []
        for i in semantic:
            t = torch.tensor(i)
            if t.shape[-1] >= max_semantic_length:
                t = t[..., :max_semantic_length]

            codes.append(
                F.pad(
                    t,
                    (0, max_semantic_length - t.shape[-1]),
                    value=1,
                )
            )

            mask = torch.zeros(max_semantic_length, dtype=torch.long)
            mask[t.shape[-1] :] = 1
            codes_mask.append(mask.bool())

        codes = torch.stack(codes)
        codes_mask = torch.stack(codes_mask)

        data = {
            "inputs": encoded_texts["input_ids"],
            "input_mask": encoded_texts["attention_mask"] == 0,
            "codes": codes,
            "codes_mask": codes_mask,
        }

        return data
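
# Sketch of the collated batch produced above, assuming each example provides
# "text" and "semantic" in the format emitted by the datasets in this file:
#   inputs     -> (batch, text_len)            padded text token ids
#   input_mask -> (batch, text_len)            True where the text is padding
#   codes      -> (batch, codebooks, sem_len)  semantic ids, padded with 1 (</s>)
#   codes_mask -> (batch, sem_len)             True where the semantic ids are padding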

class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly pick one dataset according to the sampling probabilities
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
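
# Usage sketch (stream_ds and aug_ds are hypothetical dataset instances; the
# probabilities must sum to 1 for numpy's Generator.choice):
#   train_ds = InterleaveDataset(
#       datasets=[stream_ds, aug_ds],
#       probabilities=[0.3, 0.7],
#   )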

class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

if __name__ == "__main__":
    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
    # with open("test.jsonl", "w") as f:
    #     for i in all_files:
    #         wav_file = i.with_suffix(".wav")
    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
    #         eta_tokens = duration * 25
    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")

    ds = AutoAugTextDataset(
        jsonl_files=["data/quantized-dataset-1205.json"],
        order=["en"],
        tokenizer=AutoTokenizer.from_pretrained(
            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
        ),
    )

    dm = TextDataModule(
        train_dataset=ds,
        val_dataset=ds,
        tokenizer=ds.tokenizer,
        batch_size=2,
        max_length=1024,
        num_workers=0,
    )

    for batch in dm.train_dataloader():
        print(batch)
        break