# text.py

import json
import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
import torch
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.text import clean_text, g2p
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
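
# Quick illustration (comment only, assuming 2 DDP ranks x 2 dataloader workers):
# with files = ["a", "b", "c", "d", "e", "f", "g", "h"], rank 0 keeps
# ["a", "c", "e", "g"] and its worker 0 then keeps ["a", "e"], so every
# (rank, worker) pair reads a disjoint shard. Lists shorter than
# rank_count * worker_count are repeated first so no shard ends up empty.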

class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts
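
# Illustrative usage sketch (hypothetical prefix, kept as a comment so nothing
# runs at import time):
#   dataset = StreamTextDataset(prefix="en/", parquet_batch_size=1000)
#   sample = next(iter(dataset))  # -> {"text": "..."}
# Each worker streams its shard of parquet files over HTTP via xopen and yields
# shuffled {"text": ...} dicts one batch at a time.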

# @dataclass
# class DatasetLine:
#     text: str
#     semantic: str
#     speaker: str


class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones
    """

    def __init__(
        self,
        jsonl_files: list[str],
        seed: int = 42,
        phones_prob: float = 0.5,
        max_length: int = 1024,
        order: Optional[list[str]] = None,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.jsonl_files = jsonl_files
        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.order = order
        self.tokenizer = tokenizer

        # Read all lines, and group by speaker
        self.speakers = {}
        self.lines = []

        for filename in self.jsonl_files:
            lines = Path(filename).read_text().splitlines()
            for json_line in lines:
                line = json.loads(json_line)
                speaker = line.get("speaker", None)

                if speaker not in self.speakers:
                    self.speakers[speaker] = []

                self.lines.append(line)
                self.speakers[speaker].append(line)

        # Shuffle the lines
        Random(seed).shuffle(self.lines)

    def __iter__(self):
        lines = split_by_rank_worker(self.lines)
        random.shuffle(lines)

        for line in lines:
            yield self.augment(line)

    def tokenize_sentence(
        self, sentence: str, semantic: list[int], mode: str = "sample"
    ):
        sentence = clean_text(sentence)

        if (
            mode == "sample" and (random.random() < self.phones_prob)
        ) or mode == "phones":
            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])

        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
        tokens = self.tokenizer.encode(
            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
        )

        return sentence, semantic, len(tokens)

    def augment(self, line):
        speaker = line.get("speaker", None)

        # 20% chance of pure text or pure phones
        mode = "sample"
        if random.random() < 0.2:
            mode = random.choice(["text", "phones"])

        if speaker is None:
            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
            return {"text": f"[INST] {a} [/INST] {b} </s>"}

        # Randomly sample a target length for this speaker using a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=0,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []

        # Shuffle unique lines
        idxs = list(range(len(self.speakers[speaker])))
        random.shuffle(idxs)

        while remaining_tokens > 0 and len(idxs) > 0:
            line = self.speakers[speaker][idxs.pop()]
            text, semantic, length = self.tokenize_sentence(
                line["text"], line["semantic"], mode=mode
            )
            remaining_tokens -= length

            final_text.append(text)
            final_semantic.append(semantic)

        final_text = " ".join(final_text)
        final_semantic = " ".join(final_semantic)

        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
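
# Illustrative sketch of a serialized sample (hypothetical values, comment only):
#   {"text": "[INST] hello there ... [/INST] <semantic_12> <semantic_7> ... </s>"}
# The cleaned text (or its phone sequence) sits between [INST] and [/INST],
# followed by the flattened <semantic_i> codes; several lines from the same
# speaker are concatenated until the sampled token budget is exhausted.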

@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 512

    def __call__(self, examples):
        texts = [i["text"] for i in examples]

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        data = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        data["labels"] = data["input_ids"].clone()
        data["labels"][data["attention_mask"] == 0] = -100

        return data
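
# Illustrative sketch of the label masking (hypothetical values, comment only):
#   input_ids      = [[12, 345,  92,   7,    0,    0]]
#   attention_mask = [[ 1,   1,   1,   1,    0,    0]]
#   labels         = [[12, 345,  92,   7, -100, -100]]
# Padding positions are set to -100 so the loss ignores them.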

class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly choose one dataset
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
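
# Illustrative usage sketch (hypothetical dataset names, comment only):
#   mixed = InterleaveDataset([web_ds, speaker_ds], probabilities=[0.3, 0.7])
# Roughly 30% of yielded samples come from web_ds and 70% from speaker_ds; the
# stream never ends because exhausted child iterators are restarted in place.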

class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

if __name__ == "__main__":
    import json

    # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
    # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
    # with open("test.jsonl", "w") as f:
    #     for i in all_files:
    #         wav_file = i.with_suffix(".wav")
    #         duration = float(Path(wav_file).stat().st_size) / 2 / 44100
    #         eta_tokens = duration * 25
    #         fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
    #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")

    ds = AutoAugTextDataset(
        jsonl_files=["test.jsonl"],
        order=["en"],
        tokenizer=AutoTokenizer.from_pretrained(
            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
        ),
    )

    dm = TextDataModule(
        train_dataset=ds,
        val_dataset=ds,
        tokenizer=ds.tokenizer,
        batch_size=2,
        max_length=1024,
        num_workers=0,
    )

    for batch in dm.train_dataloader():
        print(batch)
        break