text.py

import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.datasets.protos.text_data_pb2 import SampledData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.clean import clean_text
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)

CODEBOOK_PAD_TOKEN_ID = 0
CODEBOOK_EOS_TOKEN_ID = 1


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
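
# Illustrative example (not executed): with 2 GPUs and 2 dataloader workers,
# total_devices = 4 and the files are striped so each (rank, worker) pair sees
# a disjoint slice:
#
#   files = ["a", "b", "c", "d", "e", "f", "g", "h"]
#   rank 0 -> ["a", "c", "e", "g"]; its worker 0 -> ["a", "e"]
#   rank 1 -> ["b", "d", "f", "h"]; its worker 1 -> ["d", "h"]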


class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo
        self.max_length = max_length
        self.tokenizer = tokenizer

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            # Tokenize without special tokens; bos/eos are added manually below
            tokens = self.tokenizer.encode(
                text,
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Randomly crop to self.max_length
            if len(tokens) > self.max_length:
                start = random.randint(0, len(tokens) - self.max_length)
                tokens = tokens[start : start + self.max_length - 1]

            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )

            # Pad dims: add placeholder rows for the codebooks
            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
            tokens = torch.cat(
                [
                    torch.tensor([tokens], dtype=torch.long),
                    placeholder_multi_codebook,
                ],
                dim=0,
            )

            labels = tokens.clone()
            tokens = tokens[:, :-1]
            labels = labels[:, 1:]
            labels[1:] = -100  # mask out all codebook placeholders

            yield {"tokens": tokens, "labels": labels}
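
    # Shape sketch (given the 4 placeholder codebook rows above): after
    # stacking, `tokens` is (5, T) with text ids in row 0 and zeros elsewhere;
    # the one-position shift then yields tokens/labels of shape (5, T - 1) for
    # next-token prediction, and rows 1+ of `labels` are -100 so only the text
    # row contributes to the loss.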

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts


class AutoAugTextDataset(IterableDataset):
    """
    Auto-augment dataset by speaker:

    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones

    For interactive mode, we use the following format (multiple sequences):
        <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>

    For non-interactive mode, we use the following format (one long sequence):
        <s> [INST] text [/INST] ... </s>
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        phones_prob: float = 0.3,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool = True,
        causal: bool = True,
        use_negative_samples: bool = False,
        num_codebooks: Optional[int] = None,
    ):
        """
        Args:
            proto_files: protobuf files if using local data
            seed: random seed
            phones_prob: probability of mixing in phones
            interactive_prob: probability of using interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt
            causal: sample sentences in order when using local data; disabling this leads to random sampling
            use_negative_samples: generate negative samples
            num_codebooks: number of codebooks; if None, it will be detected automatically
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.use_negative_samples = use_negative_samples
        self.num_codebooks = num_codebooks

        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        self.groups = None

    def init_mock_data_server(self):
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the groups
        Random(self.seed).shuffle(self.groups)

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            sentence,
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        if self.groups is None:
            self.init_mock_data_server()

        # Estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # Choose a group, weighted by its number of sentences
        group = random.choices(
            self.groups, weights=[len(i.sentences) for i in self.groups], k=1
        )[0]

        if self.causal:
            # Sample a contiguous window, in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )
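
    # Example (illustrative): with max_length=1024, num_samples = 1024 // 20 = 51;
    # causal sampling returns a contiguous window of up to 51 sentences, while
    # non-causal sampling draws up to 51 sentences with replacement via
    # random.choices.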

    def augment(self):
        # Sample a target length from a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob
        all_tokens, all_labels = [], []

        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop()

            text, length = self.tokenize_sentence(sentence.text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply the speaker to the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if (self.use_speaker and idx == 0) else None,
                    add_bos=idx == 0,
                )
                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if use_interactive is False:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if self.use_speaker else None,
                add_bos=True,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the lengths match
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        # Verify the bos token
        assert tokens[0, 0] == self.tokenizer.bos_token_id

        data = {"tokens": tokens, "labels": labels}

        if self.use_negative_samples:
            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
            data.update(negative_samples)

        return data

    def generate_negative_samples(self, all_tokens, all_labels):
        new_tokens, new_labels = [], []

        for tokens, labels in zip(all_tokens, all_labels):
            # Find the first position where all codebook rows are unmasked (not -100)
            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
            assert (labels[1:, start:] != -100).all()  # This shouldn't happen

            mode = random.choice(["repeat", "lost", "noise"])
            begin = random.randint(start, labels.size(1) - 1)
            end = random.randint(begin, labels.size(1) - 1)

            if mode == "repeat":
                tokens = torch.cat(
                    [
                        tokens[:, :begin],
                        tokens[:, begin:end],
                        tokens[:, begin:end],
                        tokens[:, end:],
                    ],
                    dim=1,
                )
                labels = torch.cat(
                    [
                        labels[:, :begin],
                        labels[:, begin:end],
                        labels[:, begin:end],
                        labels[:, end:],
                    ],
                    dim=1,
                )
            elif mode == "lost":
                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
            elif mode == "noise":
                middle_tokens, middle_labels = (
                    tokens[:, begin:end],
                    labels[:, begin:end],
                )
                random_order0 = torch.randperm(middle_tokens.size(1))
                random_order1 = torch.randperm(middle_tokens.size(1))
                middle_tokens = middle_tokens[:, random_order0]
                middle_labels = middle_labels[:, random_order1]
                tokens = torch.cat(
                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
                )
                labels = torch.cat(
                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
                )

            new_tokens.append(tokens)
            new_labels.append(labels)

        tokens = torch.cat(new_tokens, dim=1)
        labels = torch.cat(new_labels, dim=1)

        # Verify that the lengths match
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        return {"negative_tokens": tokens, "negative_labels": labels}
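
    # Corruption sketch over the slice [begin, end) chosen above:
    #   repeat -> [:begin] [begin:end] [begin:end] [end:]  (segment duplicated)
    #   lost   -> [:begin] [end:]                          (segment dropped)
    #   noise  -> tokens and labels shuffled with *different* permutations,
    #             so the token/label pairs no longer line up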

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        add_bos: bool = True,
    ):
        if speaker is not None:
            sentences = [f"[SPK: {speaker}]"] + sentences

        final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
        final_text = final_text + "<|im_start|>assistant<|im_sep|>"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )
        bos_bias = 1 if add_bos else 0

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(
                ["<|im_end|>", "<|end_of_sequence|>"]
            )
        )

        if add_bos:
            tokens = [self.tokenizer.bos_token_id] + tokens

        # Codebook bos/padding: 0, eos: 1
        codes = [
            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
            for _ in range(num_codebooks)
        ]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 2)

        for book in codes:
            book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)

        tokens = [tokens] + codes
        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        # Mask out the prompt region for the codebook rows: predict semantic tokens only
        # Since we don't mask out the input tokens, language modeling still works
        labels[1:, : (prompt_length + bos_bias)] = -100

        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
        assert labels[0, -1] == self.tokenizer.eos_token_id
        assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()

        return tokens, labels
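
    # Packed layout sketch, before the shift (S = total semantic length,
    # N = num_codebooks):
    #   row 0    : [<s>] prompt ids, <|semantic|> x S, <|im_end|>, <|end_of_sequence|>
    #   rows 1-N : PAD(0) x (prompt + bos), codebook values + 2, EOS(1), EOS(1)
    # After shifting, labels on rows 1-N are -100 over the prompt region, so the
    # codebook loss covers only the semantic span.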


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        # `examples` is a list of dicts, so check the first one for negatives
        if "negative_tokens" in examples[0]:
            positive_examples = []
            negative_examples = []

            for i in examples:
                positive_examples.append(
                    {
                        "tokens": i["tokens"],
                        "labels": i["labels"],
                    }
                )
                negative_examples.append(
                    {
                        "tokens": i["negative_tokens"],
                        "labels": i["negative_labels"],
                    }
                )

            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]

            # True marks padding positions; real tokens are set to False below
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
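
# Usage sketch (hypothetical shapes): given two examples with 5 rows and
# lengths 120 and 96, TextDataCollator pads both to 120 and returns
#   inputs          (2, 5, 120)
#   attention_masks (2, 120)   -- True marks padding positions
#   labels          (2, 5, 120), padded with -100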


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly choose a dataset
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]
            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
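
# Usage sketch (illustrative): mix two infinite streams 80/20.
#   ds = InterleaveDataset([dataset_a, dataset_b], probabilities=[0.8, 0.2])
# Exhausted child iterators are restarted transparently, so iteration never
# terminates on its own.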


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    ds = AutoAugTextDataset(
        ["data/protos/test"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
        use_speaker=False,
        interactive_prob=1.0,
        use_negative_samples=False,
    )

    for i in ds:
        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
        # i["labels"][0][i["labels"][0] == -100] = 0
        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
        break