text.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. import random
  2. from dataclasses import dataclass
  3. from itertools import chain
  4. from pathlib import Path
  5. from random import Random
  6. from typing import Optional, Union
  7. import numpy as np
  8. import pyarrow.parquet as pq
  9. import torch
  10. import torch.nn.functional as F
  11. from datasets.download.streaming_download_manager import xopen
  12. from huggingface_hub import HfApi
  13. from lightning import LightningDataModule
  14. from torch.distributed import get_rank, get_world_size, is_initialized
  15. from torch.utils.data import DataLoader, IterableDataset, get_worker_info
  16. from transformers import AutoTokenizer
  17. from fish_speech.datasets.protos.text_data_pb2 import SampledData
  18. from fish_speech.datasets.protos.text_data_stream import read_pb_stream
  19. from fish_speech.text.clean import clean_text
  20. from fish_speech.utils import RankedLogger
  21. from fish_speech.utils.braceexpand import braceexpand
  22. log = RankedLogger(__name__, rank_zero_only=True)
  23. CODEBOOK_PAD_TOKEN_ID = 0
  24. CODEBOOK_EOS_TOKEN_ID = 1
  25. SKIP_TEXT_STRING = "<|skip_text|>"
  26. def split_by_rank_worker(files):
  27. # We need to know the total number of devices
  28. # to split the data properly
  29. total_devices = 1
  30. if is_initialized():
  31. total_devices = get_world_size()
  32. worker_info = get_worker_info()
  33. if worker_info is not None:
  34. total_devices *= worker_info.num_workers
  35. if len(files) < total_devices:
  36. # Repeat the files N times to match the number of devices
  37. files = files * (total_devices // len(files) + 1)
  38. # DDP
  39. if is_initialized():
  40. files = files[get_rank() :: get_world_size()]
  41. # Split by worker
  42. if worker_info is not None:
  43. files = files[worker_info.id :: worker_info.num_workers]
  44. return files
  45. class StreamTextDataset(IterableDataset):
  46. def __init__(
  47. self,
  48. files: Optional[Union[list[str], str]] = None,
  49. prefix: Optional[str] = None,
  50. seed: int = 42,
  51. parquet_batch_size: int = 10000,
  52. repo: str = "uonlp/CulturaX",
  53. max_length: int = 1024,
  54. tokenizer: AutoTokenizer = None,
  55. ):
  56. super().__init__()
  57. self.seed = seed
  58. self.parquet_batch_size = parquet_batch_size
  59. self.repo = repo
  60. self.max_length = max_length
  61. self.tokenizer = tokenizer
  62. if files is None and prefix is None:
  63. raise ValueError("Either files or prefix must be specified")
  64. if prefix is not None:
  65. files = HfApi().list_repo_files(repo, repo_type="dataset")
  66. files = [
  67. f for f in files if f.startswith(prefix) and f.endswith(".parquet")
  68. ]
  69. log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
  70. else:
  71. if isinstance(files, str):
  72. files = [files]
  73. files = list(chain.from_iterable(map(braceexpand, files)))
  74. log.info(f"Expanded {len(files)} files in {repo}")
  75. # Get sharded files
  76. self.files = sorted(files)
  77. Random(seed).shuffle(self.files)
  78. def __iter__(self):
  79. files = split_by_rank_worker(self.files)
  80. random.shuffle(files)
  81. for filename in files:
  82. try:
  83. yield from self.parse_data(filename)
  84. except Exception as e:
  85. log.exception(f"Failed to parse {filename}: {e}")
  86. def parse_data(self, filename: str):
  87. for data in self.parse_data_internal(filename):
  88. text = data["text"]
  89. # encode
  90. tokens = self.tokenizer.encode(
  91. text,
  92. add_special_tokens=False,
  93. truncation=False,
  94. max_length=10**6,
  95. )
  96. # Random choice self.max_length
  97. if len(tokens) > self.max_length:
  98. start = random.randint(0, len(tokens) - self.max_length)
  99. tokens = tokens[start : start + self.max_length - 1]
  100. tokens = (
  101. [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
  102. )
  103. # Pad dims
  104. placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
  105. tokens = torch.concat(
  106. [
  107. torch.tensor([tokens], dtype=torch.long),
  108. placeholder_multi_codebook,
  109. ],
  110. dim=0,
  111. )
  112. labels = tokens.clone()
  113. tokens = tokens[:, :-1]
  114. labels = labels[:, 1:]
  115. labels[1:] = -100 # remove all placeholders
  116. yield {"tokens": tokens, "labels": labels}
  117. def parse_data_internal(self, filename: str):
  118. url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
  119. with xopen(url, mode="rb") as stream:
  120. parquet_file = pq.ParquetFile(stream)
  121. for batch in parquet_file.iter_batches(
  122. batch_size=self.parquet_batch_size, columns=["text"]
  123. ):
  124. # In-batch shuffling
  125. texts = [{"text": text.as_py()} for text in batch["text"]]
  126. random.shuffle(texts)
  127. yield from texts
  128. class AutoAugTextDataset(IterableDataset):
  129. """
  130. Auto Augment Dataset by Speaker
  131. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  132. 2. Automatically normalize the text
  133. For interactive mode, we use the following format (multiple sequences):
  134. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  135. For non-interactive mode, we use the following format (one long sequence):
  136. <s> [INST] text [/INST] ... </s>
  137. """
  138. def __init__(
  139. self,
  140. proto_files: list[str],
  141. seed: int = 42,
  142. interactive_prob: float = 0.5,
  143. max_length: int = 1024,
  144. tokenizer: AutoTokenizer = None,
  145. use_speaker: bool | float = True,
  146. causual: bool = True,
  147. use_negative_samples: bool = False,
  148. num_codebooks: Optional[int] = None,
  149. skip_text_prob: float = 0.0,
  150. ):
  151. """
  152. Args:
  153. proto_files: proto buf files if using local data
  154. seed: random seed
  155. interactive_prob: probability to use interactive mode
  156. max_length: max length of the text
  157. tokenizer: tokenizer
  158. use_speaker: include speaker information in the prompt
  159. causual: use causual sampling when using local data, disable will lead to random sampling
  160. use_negative_samples: generate negative samples
  161. num_codebooks: number of codebooks, if None, it will be automatically detected
  162. skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
  163. """
  164. super().__init__()
  165. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  166. self.seed = seed
  167. self.max_length = max_length
  168. self.tokenizer = tokenizer
  169. self.interactive_prob = interactive_prob
  170. self.use_speaker = use_speaker
  171. self.proto_files = proto_files
  172. self.causual = causual
  173. self.use_negative_samples = use_negative_samples
  174. self.num_codebooks = num_codebooks
  175. self.skip_text_prob = skip_text_prob
  176. self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
  177. self.groups = None
  178. def init_mock_data_server(self):
  179. if self.groups is not None:
  180. return
  181. # Expand the proto files
  182. expanded_proto_files = []
  183. for filename in self.proto_files:
  184. for i in braceexpand(filename):
  185. i = Path(i)
  186. if i.is_file():
  187. expanded_proto_files.append(i)
  188. elif i.is_dir():
  189. expanded_proto_files.extend(i.rglob("*.proto"))
  190. expanded_proto_files.extend(i.rglob("*.protos"))
  191. else:
  192. raise ValueError(f"{i} is not a file or directory")
  193. expanded_proto_files = sorted(expanded_proto_files)
  194. Random(self.seed).shuffle(expanded_proto_files)
  195. self.groups = []
  196. shard_proto_files = split_by_rank_worker(expanded_proto_files)
  197. log.info(
  198. f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
  199. )
  200. count = 0
  201. for filename in shard_proto_files:
  202. with open(filename, "rb") as f:
  203. for text_data in read_pb_stream(f):
  204. self.groups.append(text_data)
  205. count += 1
  206. log.info(f"Read total {count} groups of data")
  207. # Shuffle the lines
  208. Random(self.seed).shuffle(self.groups)
  209. self.group_weights = [len(i.sentences) for i in self.groups]
  210. def __iter__(self):
  211. while True:
  212. yield self.augment()
  213. def tokenize_sentence(self, sentence: str):
  214. sentence = clean_text(sentence)
  215. tokens = self.tokenizer.encode(
  216. f"{sentence}",
  217. max_length=10**6,
  218. add_special_tokens=False,
  219. truncation=False,
  220. )
  221. return sentence, len(tokens)
  222. def sample_data(self):
  223. if self.groups is None:
  224. self.init_mock_data_server()
  225. # Shuffle unique lines, estimate that each sample is at least 20 tokens
  226. num_samples = self.max_length // 20
  227. # choice group based on their number of samples
  228. group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
  229. if self.causual:
  230. # Sample in order
  231. if num_samples >= len(group.sentences):
  232. samples = group.sentences
  233. else:
  234. begin = random.randint(0, len(group.sentences) - num_samples)
  235. samples = group.sentences[begin : begin + num_samples]
  236. else:
  237. samples = random.choices(
  238. group.sentences, k=min(num_samples, len(group.sentences))
  239. )
  240. return SampledData(
  241. source=group.source,
  242. name=group.name,
  243. samples=samples,
  244. )
  245. def augment(self):
  246. final_text, final_semantic = [], []
  247. response = self.sample_data()
  248. if len(response.samples) == 0:
  249. # Invalid group
  250. return None
  251. samples = list(response.samples)
  252. idx = 0
  253. use_interactive = random.random() < self.interactive_prob
  254. if use_interactive is False:
  255. # Random sample based on speaker using a truncated normal distribution
  256. a = torch.tensor([0], dtype=torch.float32)
  257. torch.nn.init.trunc_normal_(
  258. a,
  259. mean=self.max_length // 2,
  260. std=self.max_length // 4,
  261. a=10,
  262. b=self.max_length,
  263. )
  264. remaining_tokens = a.long().item() - 4
  265. else:
  266. remaining_tokens = self.max_length
  267. # Use speaker
  268. if isinstance(self.use_speaker, float):
  269. use_speaker = random.random() < self.use_speaker
  270. else:
  271. use_speaker = self.use_speaker
  272. all_tokens, all_labels = [], []
  273. while remaining_tokens > 0 and len(samples) > 0:
  274. sentence = samples.pop(0)
  275. text = random.choice(sentence.texts)
  276. text, length = self.tokenize_sentence(text)
  277. remaining_tokens -= length + len(sentence.semantics[0].values)
  278. if use_interactive is False:
  279. final_text.append(text)
  280. final_semantic.append(sentence.semantics)
  281. else:
  282. # For interactive mode, we only apply speaker for the first sentence
  283. # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
  284. tokens, labels = self.pack_sentences(
  285. sentences=[text],
  286. semantics=[sentence.semantics],
  287. speaker=response.name if use_speaker else None,
  288. add_bos=idx == 0,
  289. skip_text=random.random() < self.skip_text_prob,
  290. )
  291. all_tokens.append(tokens)
  292. all_labels.append(labels)
  293. idx += 1
  294. if use_interactive is False:
  295. tokens, labels = self.pack_sentences(
  296. final_text,
  297. semantics=final_semantic,
  298. speaker=response.name if use_speaker else None,
  299. add_bos=True,
  300. )
  301. all_tokens.append(tokens)
  302. all_labels.append(labels)
  303. tokens = torch.cat(all_tokens, dim=1)
  304. labels = torch.cat(all_labels, dim=1)
  305. # Verify that the length is correct
  306. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  307. # Verify bos token
  308. assert tokens[0, 0] == self.tokenizer.bos_token_id
  309. data = {"tokens": tokens, "labels": labels}
  310. if self.use_negative_samples:
  311. negative_samples = self.generate_negative_samples(all_tokens, all_labels)
  312. data.update(negative_samples)
  313. return data
  314. def generate_negative_samples(self, all_tokens, all_labels):
  315. new_tokens, new_labels = [], []
  316. for tokens, labels in zip(all_tokens, all_labels):
  317. # If all codebooks are not -100, we find where it starts
  318. start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
  319. assert (labels[1:, start:] != -100).all() # This shouldn't happen
  320. mode = random.choice(["repeat", "lost", "noise"])
  321. begin = random.randint(start, labels.size(1) - 1)
  322. end = random.randint(begin, labels.size(1) - 1)
  323. if mode == "repeat":
  324. tokens = torch.cat(
  325. [
  326. tokens[:, :begin],
  327. tokens[:, begin:end],
  328. tokens[:, begin:end],
  329. tokens[:, end:],
  330. ],
  331. dim=1,
  332. )
  333. labels = torch.cat(
  334. [
  335. labels[:, :begin],
  336. labels[:, begin:end],
  337. labels[:, begin:end],
  338. labels[:, end:],
  339. ],
  340. dim=1,
  341. )
  342. elif mode == "lost":
  343. tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
  344. labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
  345. elif mode == "noise":
  346. middle_tokens, middle_labels = (
  347. tokens[:, begin:end],
  348. labels[:, begin:end],
  349. )
  350. random_order0 = torch.randperm(middle_tokens.size(1))
  351. random_order1 = torch.randperm(middle_tokens.size(1))
  352. middle_tokens = middle_tokens[:, random_order0]
  353. middle_labels = middle_labels[:, random_order1]
  354. tokens = torch.cat(
  355. [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
  356. )
  357. labels = torch.cat(
  358. [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
  359. )
  360. new_tokens.append(tokens)
  361. new_labels.append(labels)
  362. tokens = torch.cat(new_tokens, dim=1)
  363. labels = torch.cat(new_labels, dim=1)
  364. # Verify that the length is correct
  365. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  366. return {"negative_tokens": tokens, "negative_labels": labels}
  367. def pack_sentences(
  368. self,
  369. sentences: list[str],
  370. semantics: list,
  371. speaker: Optional[str] = None,
  372. add_bos: bool = True,
  373. skip_text: bool = False,
  374. ):
  375. if speaker is None:
  376. speaker = "assistant"
  377. cated_sentences = " ".join(sentences)
  378. if skip_text:
  379. cated_sentences = SKIP_TEXT_STRING
  380. final_text = "<|im_start|>user<|im_sep|>" + cated_sentences + "<|im_end|>"
  381. final_text = final_text + f"<|im_start|>{speaker}<|im_sep|>"
  382. encoded = self.tokenizer.encode(
  383. final_text,
  384. add_special_tokens=False,
  385. truncation=False,
  386. max_length=10**6,
  387. )
  388. semantic_length = sum([len(i[0].values) for i in semantics])
  389. prompt_length = len(encoded)
  390. num_codebooks = (
  391. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  392. )
  393. bos_bias = 1 if add_bos else 0
  394. # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
  395. tokens = (
  396. encoded
  397. + [self.semantic_token_id] * semantic_length
  398. + self.tokenizer.convert_tokens_to_ids(
  399. ["<|im_end|>", "<|end_of_sequence|>"]
  400. )
  401. )
  402. if add_bos:
  403. tokens = [self.tokenizer.bos_token_id] + tokens
  404. # Codebook bos/padding: 0, eos: 1
  405. codes = [
  406. [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
  407. for _ in range(num_codebooks)
  408. ]
  409. for segment in semantics:
  410. for book_idx, book in zip(range(num_codebooks), segment):
  411. for j in book.values:
  412. codes[book_idx].append(int(j) + 2)
  413. for book in codes:
  414. book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)
  415. tokens = [tokens] + codes
  416. tokens = torch.tensor(tokens, dtype=torch.long)
  417. labels = tokens.clone()
  418. if skip_text:
  419. # If text is not provided, the sentence is used for condition only, all labels are -100
  420. torch.fill_(labels, -100)
  421. return tokens, labels
  422. # Mask out the <s> tokens for semantic, predict semantic tokens only
  423. # Since we don't mask out the input tokens, the language modeling still works
  424. labels[1:, : (prompt_length + bos_bias)] = -100
  425. tokens = tokens[:, :-1]
  426. labels = labels[:, 1:]
  427. # Verify the padding is correct, and the last token is eos
  428. assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
  429. assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
  430. assert labels[0, -1] == self.tokenizer.eos_token_id
  431. assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()
  432. return tokens, labels
  433. @dataclass
  434. class TextDataCollator:
  435. tokenizer: AutoTokenizer
  436. max_length: int = 1024
  437. def __call__(self, examples):
  438. if "negative_tokens" in examples:
  439. positive_examples = []
  440. negative_examples = []
  441. for i in examples:
  442. positive_examples.append(
  443. {
  444. "tokens": i["tokens"],
  445. "labels": i["labels"],
  446. }
  447. )
  448. negative_examples.append(
  449. {
  450. "tokens": i["negative_tokens"],
  451. "labels": i["negative_labels"],
  452. }
  453. )
  454. examples = positive_examples + negative_examples
  455. return self.batchify(examples)
  456. def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
  457. tokens, attention_masks, labels = [], [], []
  458. # Calculate the max length
  459. max_tokens_length = 0
  460. for example in examples:
  461. max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
  462. max_tokens_length = min(max_tokens_length, self.max_length)
  463. for example in examples:
  464. _tokens = example[tokens_key][:, :max_tokens_length]
  465. _labels = example[labels_key][:, :max_tokens_length]
  466. _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
  467. tokens_length = _tokens.size(1)
  468. _attention_mask[:tokens_length] = False
  469. assert tokens_length == _labels.size(
  470. 1
  471. ), f"{tokens_length} != {_labels.size(1)}"
  472. if tokens_length < max_tokens_length:
  473. _tokens = F.pad(
  474. _tokens,
  475. (0, max_tokens_length - tokens_length),
  476. value=self.tokenizer.eos_token_id,
  477. )
  478. _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
  479. _labels = F.pad(
  480. _labels, (0, max_tokens_length - _labels.size(1)), value=-100
  481. )
  482. tokens.append(_tokens)
  483. attention_masks.append(_attention_mask)
  484. labels.append(_labels)
  485. tokens = torch.stack(tokens, dim=0)
  486. attention_masks = torch.stack(attention_masks, dim=0)
  487. labels = torch.stack(labels, dim=0)
  488. return {
  489. "inputs": tokens,
  490. "attention_masks": attention_masks,
  491. "labels": labels,
  492. }
  493. class InterleaveDataset(IterableDataset):
  494. def __init__(
  495. self,
  496. datasets: list[IterableDataset],
  497. probabilities: list[float],
  498. seed: int = 42,
  499. ):
  500. super().__init__()
  501. self.datasets = datasets
  502. self.probabilities = probabilities
  503. self.seed = seed
  504. def __iter__(self):
  505. rng = np.random.default_rng(self.seed)
  506. dataset_iterators = [iter(dataset) for dataset in self.datasets]
  507. while True:
  508. # Random choice one
  509. dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
  510. dataset_iterator = dataset_iterators[dataset_idx]
  511. try:
  512. yield next(dataset_iterator)
  513. except StopIteration:
  514. # Exhausted, create a new iterator
  515. dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
  516. yield next(dataset_iterators[dataset_idx])
  517. class TextDataModule(LightningDataModule):
  518. def __init__(
  519. self,
  520. train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  521. val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  522. batch_size: int = 32,
  523. tokenizer: AutoTokenizer = None,
  524. max_length: int = 1024,
  525. num_workers: int = 4,
  526. ):
  527. super().__init__()
  528. self.train_dataset = train_dataset
  529. self.val_dataset = val_dataset
  530. self.batch_size = batch_size
  531. self.tokenizer = tokenizer
  532. self.max_length = max_length
  533. self.num_workers = num_workers
  534. def train_dataloader(self):
  535. return DataLoader(
  536. self.train_dataset,
  537. batch_size=self.batch_size,
  538. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  539. num_workers=self.num_workers,
  540. persistent_workers=True,
  541. )
  542. def val_dataloader(self):
  543. return DataLoader(
  544. self.val_dataset,
  545. batch_size=self.batch_size,
  546. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  547. num_workers=self.num_workers,
  548. persistent_workers=True,
  549. )
  550. if __name__ == "__main__":
  551. from tqdm import tqdm
  552. ds = AutoAugTextDataset(
  553. ["data/protos"],
  554. tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
  555. use_speaker=False,
  556. interactive_prob=1.0,
  557. use_negative_samples=False,
  558. skip_text_prob=0.5,
  559. )
  560. for i in ds:
  561. print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
  562. # i["labels"][0][i["labels"][0] == -100] = 0
  563. # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
  564. break