text.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. import random
  2. from dataclasses import dataclass
  3. from itertools import chain
  4. from pathlib import Path
  5. from random import Random
  6. from typing import Optional, Union
  7. import grpc
  8. import numpy as np
  9. import pyarrow.parquet as pq
  10. import torch
  11. import torch.nn.functional as F
  12. from datasets.download.streaming_download_manager import xopen
  13. from huggingface_hub import HfApi
  14. from lightning import LightningDataModule
  15. from torch.distributed import get_rank, get_world_size, is_initialized
  16. from torch.utils.data import DataLoader, IterableDataset, get_worker_info
  17. from transformers import AutoTokenizer
  18. from fish_speech.datasets.protos.text_data_pb2 import SampledData
  19. from fish_speech.datasets.protos.text_data_stream import read_pb_stream
  20. from fish_speech.text.clean import clean_text
  21. from fish_speech.utils import RankedLogger
  22. from fish_speech.utils.braceexpand import braceexpand
  23. log = RankedLogger(__name__, rank_zero_only=True)
  24. CODEBOOK_PAD_TOKEN_ID = 0
  25. CODEBOOK_EOS_TOKEN_ID = 1
  26. def split_by_rank_worker(files):
  27. # We need to know the total number of devices
  28. # to split the data properly
  29. total_devices = 1
  30. if is_initialized():
  31. total_devices = get_world_size()
  32. worker_info = get_worker_info()
  33. if worker_info is not None:
  34. total_devices *= worker_info.num_workers
  35. if len(files) < total_devices:
  36. # Repeat the files N times to match the number of devices
  37. files = files * (total_devices // len(files) + 1)
  38. # DDP
  39. if is_initialized():
  40. files = files[get_rank() :: get_world_size()]
  41. # Split by worker
  42. if worker_info is not None:
  43. files = files[worker_info.id :: worker_info.num_workers]
  44. return files
  45. class StreamTextDataset(IterableDataset):
  46. def __init__(
  47. self,
  48. files: Optional[Union[list[str], str]] = None,
  49. prefix: Optional[str] = None,
  50. seed: int = 42,
  51. parquet_batch_size: int = 10000,
  52. repo: str = "uonlp/CulturaX",
  53. max_length: int = 1024,
  54. tokenizer: AutoTokenizer = None,
  55. ):
  56. super().__init__()
  57. self.seed = seed
  58. self.parquet_batch_size = parquet_batch_size
  59. self.repo = repo
  60. self.max_length = max_length
  61. self.tokenizer = tokenizer
  62. if files is None and prefix is None:
  63. raise ValueError("Either files or prefix must be specified")
  64. if prefix is not None:
  65. files = HfApi().list_repo_files(repo, repo_type="dataset")
  66. files = [
  67. f for f in files if f.startswith(prefix) and f.endswith(".parquet")
  68. ]
  69. log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
  70. else:
  71. if isinstance(files, str):
  72. files = [files]
  73. files = list(chain.from_iterable(map(braceexpand, files)))
  74. log.info(f"Expanded {len(files)} files in {repo}")
  75. # Get sharded files
  76. self.files = sorted(files)
  77. Random(seed).shuffle(self.files)
  78. def __iter__(self):
  79. files = split_by_rank_worker(self.files)
  80. random.shuffle(files)
  81. for filename in files:
  82. try:
  83. yield from self.parse_data(filename)
  84. except Exception as e:
  85. log.exception(f"Failed to parse {filename}: {e}")
  86. def parse_data(self, filename: str):
  87. for data in self.parse_data_internal(filename):
  88. text = data["text"]
  89. # encode
  90. tokens = self.tokenizer.encode(
  91. text,
  92. add_special_tokens=False,
  93. truncation=False,
  94. max_length=10**6,
  95. )
  96. # Random choice self.max_length
  97. if len(tokens) > self.max_length:
  98. start = random.randint(0, len(tokens) - self.max_length)
  99. tokens = tokens[start : start + self.max_length - 1]
  100. tokens = (
  101. [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
  102. )
  103. # Pad dims
  104. placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
  105. tokens = torch.concat(
  106. [
  107. torch.tensor([tokens], dtype=torch.long),
  108. placeholder_multi_codebook,
  109. ],
  110. dim=0,
  111. )
  112. labels = tokens.clone()
  113. tokens = tokens[:, :-1]
  114. labels = labels[:, 1:]
  115. labels[1:] = -100 # remove all placeholders
  116. yield {"tokens": tokens, "labels": labels}
  117. def parse_data_internal(self, filename: str):
  118. url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
  119. with xopen(url, mode="rb") as stream:
  120. parquet_file = pq.ParquetFile(stream)
  121. for batch in parquet_file.iter_batches(
  122. batch_size=self.parquet_batch_size, columns=["text"]
  123. ):
  124. # In-batch shuffling
  125. texts = [{"text": text.as_py()} for text in batch["text"]]
  126. random.shuffle(texts)
  127. yield from texts
  128. class AutoAugTextDataset(IterableDataset):
  129. """
  130. Auto Augment Dataset by Speaker
  131. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  132. 2. Automatically normalize the text
  133. For interactive mode, we use the following format (multiple sequences):
  134. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  135. For non-interactive mode, we use the following format (one long sequence):
  136. <s> [INST] text [/INST] ... </s>
  137. """
  138. def __init__(
  139. self,
  140. proto_files: list[str],
  141. seed: int = 42,
  142. interactive_prob: float = 0.5,
  143. max_length: int = 1024,
  144. tokenizer: AutoTokenizer = None,
  145. use_speaker: bool = True,
  146. causual: bool = True,
  147. use_negative_samples: bool = False,
  148. num_codebooks: Optional[int] = None,
  149. ):
  150. """
  151. Args:
  152. proto_files: proto buf files if using local data
  153. seed: random seed
  154. interactive_prob: probability to use interactive mode
  155. max_length: max length of the text
  156. tokenizer: tokenizer
  157. use_speaker: include speaker information in the prompt
  158. causual: use causual sampling when using local data, disable will lead to random sampling
  159. use_negative_samples: generate negative samples
  160. num_codebooks: number of codebooks, if None, it will be automatically detected
  161. """
  162. super().__init__()
  163. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  164. self.seed = seed
  165. self.max_length = max_length
  166. self.tokenizer = tokenizer
  167. self.interactive_prob = interactive_prob
  168. self.use_speaker = use_speaker
  169. self.proto_files = proto_files
  170. self.causual = causual
  171. self.use_negative_samples = use_negative_samples
  172. self.num_codebooks = num_codebooks
  173. self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
  174. self.groups = None
  175. def init_mock_data_server(self):
  176. if self.groups is not None:
  177. return
  178. # Expand the proto files
  179. expanded_proto_files = []
  180. for filename in self.proto_files:
  181. for i in braceexpand(filename):
  182. i = Path(i)
  183. if i.is_file():
  184. expanded_proto_files.append(i)
  185. elif i.is_dir():
  186. expanded_proto_files.extend(i.rglob("*.proto"))
  187. expanded_proto_files.extend(i.rglob("*.protos"))
  188. else:
  189. raise ValueError(f"{i} is not a file or directory")
  190. expanded_proto_files = sorted(expanded_proto_files)
  191. Random(self.seed).shuffle(expanded_proto_files)
  192. self.groups = []
  193. shard_proto_files = split_by_rank_worker(expanded_proto_files)
  194. log.info(
  195. f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
  196. )
  197. count = 0
  198. for filename in shard_proto_files:
  199. with open(filename, "rb") as f:
  200. for text_data in read_pb_stream(f):
  201. self.groups.append(text_data)
  202. count += 1
  203. log.info(f"Read total {count} groups of data")
  204. # Shuffle the lines
  205. Random(self.seed).shuffle(self.groups)
  206. def __iter__(self):
  207. while True:
  208. yield self.augment()
  209. def tokenize_sentence(self, sentence: str):
  210. sentence = clean_text(sentence)
  211. tokens = self.tokenizer.encode(
  212. f"{sentence}",
  213. max_length=10**6,
  214. add_special_tokens=False,
  215. truncation=False,
  216. )
  217. return sentence, len(tokens)
  218. def sample_data(self):
  219. if self.groups is None:
  220. self.init_mock_data_server()
  221. # Shuffle unique lines, estimate that each sample is at least 20 tokens
  222. num_samples = self.max_length // 20
  223. # choice group based on their number of samples
  224. group = random.choices(
  225. self.groups, weights=[len(i.sentences) for i in self.groups], k=1
  226. )[0]
  227. if self.causual:
  228. # Sample in order
  229. if num_samples >= len(group.sentences):
  230. samples = group.sentences
  231. else:
  232. begin = random.randint(0, len(group.sentences) - num_samples)
  233. samples = group.sentences[begin : begin + num_samples]
  234. else:
  235. samples = random.choices(
  236. group.sentences, k=min(num_samples, len(group.sentences))
  237. )
  238. return SampledData(
  239. source=group.source,
  240. name=group.name,
  241. samples=samples,
  242. )
  243. def augment(self):
  244. # Random sample based on speaker using a truncated normal distribution
  245. a = torch.tensor([0], dtype=torch.float32)
  246. torch.nn.init.trunc_normal_(
  247. a,
  248. mean=self.max_length // 2,
  249. std=self.max_length // 4,
  250. a=10,
  251. b=self.max_length,
  252. )
  253. remaining_tokens = a.long().item() - 4
  254. final_text, final_semantic = [], []
  255. response = self.sample_data()
  256. if len(response.samples) == 0:
  257. # Invalid group
  258. return None
  259. samples = list(response.samples)
  260. idx = 0
  261. use_interactive = random.random() < self.interactive_prob
  262. all_tokens, all_labels = [], []
  263. while remaining_tokens > 0 and len(samples) > 0:
  264. sentence = samples.pop()
  265. text = random.choice(sentence.texts)
  266. text, length = self.tokenize_sentence(text)
  267. remaining_tokens -= length + len(sentence.semantics[0].values)
  268. if use_interactive is False:
  269. final_text.append(text)
  270. final_semantic.append(sentence.semantics)
  271. else:
  272. # For interactive mode, we only apply speaker for the first sentence
  273. # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
  274. tokens, labels = self.pack_sentences(
  275. sentences=[text],
  276. semantics=[sentence.semantics],
  277. speaker=response.name if (self.use_speaker and idx == 0) else None,
  278. add_bos=idx == 0,
  279. )
  280. all_tokens.append(tokens)
  281. all_labels.append(labels)
  282. idx += 1
  283. if use_interactive is False:
  284. tokens, labels = self.pack_sentences(
  285. final_text,
  286. semantics=final_semantic,
  287. speaker=response.name if self.use_speaker else None,
  288. add_bos=True,
  289. )
  290. all_tokens.append(tokens)
  291. all_labels.append(labels)
  292. tokens = torch.cat(all_tokens, dim=1)
  293. labels = torch.cat(all_labels, dim=1)
  294. # Verify that the length is correct
  295. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  296. # Verify bos token
  297. assert tokens[0, 0] == self.tokenizer.bos_token_id
  298. data = {"tokens": tokens, "labels": labels}
  299. if self.use_negative_samples:
  300. negative_samples = self.generate_negative_samples(all_tokens, all_labels)
  301. data.update(negative_samples)
  302. return data
  303. def generate_negative_samples(self, all_tokens, all_labels):
  304. new_tokens, new_labels = [], []
  305. for tokens, labels in zip(all_tokens, all_labels):
  306. # If all codebooks are not -100, we find where it starts
  307. start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
  308. assert (labels[1:, start:] != -100).all() # This shouldn't happen
  309. mode = random.choice(["repeat", "lost", "noise"])
  310. begin = random.randint(start, labels.size(1) - 1)
  311. end = random.randint(begin, labels.size(1) - 1)
  312. if mode == "repeat":
  313. tokens = torch.cat(
  314. [
  315. tokens[:, :begin],
  316. tokens[:, begin:end],
  317. tokens[:, begin:end],
  318. tokens[:, end:],
  319. ],
  320. dim=1,
  321. )
  322. labels = torch.cat(
  323. [
  324. labels[:, :begin],
  325. labels[:, begin:end],
  326. labels[:, begin:end],
  327. labels[:, end:],
  328. ],
  329. dim=1,
  330. )
  331. elif mode == "lost":
  332. tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
  333. labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
  334. elif mode == "noise":
  335. middle_tokens, middle_labels = (
  336. tokens[:, begin:end],
  337. labels[:, begin:end],
  338. )
  339. random_order0 = torch.randperm(middle_tokens.size(1))
  340. random_order1 = torch.randperm(middle_tokens.size(1))
  341. middle_tokens = middle_tokens[:, random_order0]
  342. middle_labels = middle_labels[:, random_order1]
  343. tokens = torch.cat(
  344. [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
  345. )
  346. labels = torch.cat(
  347. [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
  348. )
  349. new_tokens.append(tokens)
  350. new_labels.append(labels)
  351. tokens = torch.cat(new_tokens, dim=1)
  352. labels = torch.cat(new_labels, dim=1)
  353. # Verify that the length is correct
  354. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  355. return {"negative_tokens": tokens, "negative_labels": labels}
  356. def pack_sentences(
  357. self,
  358. sentences: list[str],
  359. semantics=list,
  360. speaker: Optional[str] = None,
  361. add_bos: bool = True,
  362. ):
  363. if speaker is not None:
  364. sentences = [f"[SPK: {speaker}]"] + sentences
  365. final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
  366. final_text = final_text + "<|im_start|>assistant<|im_sep|>"
  367. encoded = self.tokenizer.encode(
  368. final_text,
  369. add_special_tokens=False,
  370. truncation=False,
  371. max_length=10**6,
  372. )
  373. semantic_length = sum([len(i[0].values) for i in semantics])
  374. prompt_length = len(encoded)
  375. num_codebooks = (
  376. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  377. )
  378. bos_bias = 1 if add_bos else 0
  379. # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
  380. tokens = (
  381. encoded
  382. + [self.semantic_token_id] * semantic_length
  383. + self.tokenizer.convert_tokens_to_ids(
  384. ["<|im_end|>", "<|end_of_sequence|>"]
  385. )
  386. )
  387. if add_bos:
  388. tokens = [self.tokenizer.bos_token_id] + tokens
  389. # Codebook bos/padding: 0, eos: 1
  390. codes = [
  391. [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
  392. for _ in range(num_codebooks)
  393. ]
  394. for segment in semantics:
  395. for book_idx, book in zip(range(num_codebooks), segment):
  396. for j in book.values:
  397. codes[book_idx].append(int(j) + 2)
  398. for book in codes:
  399. book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)
  400. tokens = [tokens] + codes
  401. tokens = torch.tensor(tokens, dtype=torch.long)
  402. labels = tokens.clone()
  403. # Mask out the <s> tokens for semantic, predict semantic tokens only
  404. # Since we don't mask out the input tokens, the language modeling still works
  405. labels[1:, : (prompt_length + bos_bias)] = -100
  406. tokens = tokens[:, :-1]
  407. labels = labels[:, 1:]
  408. # Verify the padding is correct, and the last token is eos
  409. assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
  410. assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
  411. assert labels[0, -1] == self.tokenizer.eos_token_id
  412. assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()
  413. return tokens, labels
  414. @dataclass
  415. class TextDataCollator:
  416. tokenizer: AutoTokenizer
  417. max_length: int = 1024
  418. def __call__(self, examples):
  419. if "negative_tokens" in examples:
  420. positive_examples = []
  421. negative_examples = []
  422. for i in examples:
  423. positive_examples.append(
  424. {
  425. "tokens": i["tokens"],
  426. "labels": i["labels"],
  427. }
  428. )
  429. negative_examples.append(
  430. {
  431. "tokens": i["negative_tokens"],
  432. "labels": i["negative_labels"],
  433. }
  434. )
  435. examples = positive_examples + negative_examples
  436. return self.batchify(examples)
  437. def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
  438. tokens, attention_masks, labels = [], [], []
  439. # Calculate the max length
  440. max_tokens_length = 0
  441. for example in examples:
  442. max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
  443. max_tokens_length = min(max_tokens_length, self.max_length)
  444. for example in examples:
  445. _tokens = example[tokens_key][:, :max_tokens_length]
  446. _labels = example[labels_key][:, :max_tokens_length]
  447. _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
  448. tokens_length = _tokens.size(1)
  449. _attention_mask[:tokens_length] = False
  450. assert tokens_length == _labels.size(
  451. 1
  452. ), f"{tokens_length} != {_labels.size(1)}"
  453. if tokens_length < max_tokens_length:
  454. _tokens = F.pad(
  455. _tokens,
  456. (0, max_tokens_length - tokens_length),
  457. value=self.tokenizer.eos_token_id,
  458. )
  459. _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
  460. _labels = F.pad(
  461. _labels, (0, max_tokens_length - _labels.size(1)), value=-100
  462. )
  463. tokens.append(_tokens)
  464. attention_masks.append(_attention_mask)
  465. labels.append(_labels)
  466. tokens = torch.stack(tokens, dim=0)
  467. attention_masks = torch.stack(attention_masks, dim=0)
  468. labels = torch.stack(labels, dim=0)
  469. return {
  470. "inputs": tokens,
  471. "attention_masks": attention_masks,
  472. "labels": labels,
  473. }
  474. class InterleaveDataset(IterableDataset):
  475. def __init__(
  476. self,
  477. datasets: list[IterableDataset],
  478. probabilities: list[float],
  479. seed: int = 42,
  480. ):
  481. super().__init__()
  482. self.datasets = datasets
  483. self.probabilities = probabilities
  484. self.seed = seed
  485. def __iter__(self):
  486. rng = np.random.default_rng(self.seed)
  487. dataset_iterators = [iter(dataset) for dataset in self.datasets]
  488. while True:
  489. # Random choice one
  490. dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
  491. dataset_iterator = dataset_iterators[dataset_idx]
  492. try:
  493. yield next(dataset_iterator)
  494. except StopIteration:
  495. # Exhausted, create a new iterator
  496. dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
  497. yield next(dataset_iterators[dataset_idx])
  498. class TextDataModule(LightningDataModule):
  499. def __init__(
  500. self,
  501. train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  502. val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  503. batch_size: int = 32,
  504. tokenizer: AutoTokenizer = None,
  505. max_length: int = 1024,
  506. num_workers: int = 4,
  507. ):
  508. super().__init__()
  509. self.train_dataset = train_dataset
  510. self.val_dataset = val_dataset
  511. self.batch_size = batch_size
  512. self.tokenizer = tokenizer
  513. self.max_length = max_length
  514. self.num_workers = num_workers
  515. def train_dataloader(self):
  516. return DataLoader(
  517. self.train_dataset,
  518. batch_size=self.batch_size,
  519. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  520. num_workers=self.num_workers,
  521. )
  522. def val_dataloader(self):
  523. return DataLoader(
  524. self.val_dataset,
  525. batch_size=self.batch_size,
  526. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  527. num_workers=self.num_workers,
  528. )
  529. if __name__ == "__main__":
  530. from tqdm import tqdm
  531. ds = AutoAugTextDataset(
  532. ["data/protos/test"],
  533. tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
  534. use_speaker=False,
  535. interactive_prob=1.0,
  536. use_negative_samples=False,
  537. )
  538. for i in ds:
  539. print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
  540. # i["labels"][0][i["labels"][0] == -100] = 0
  541. # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
  542. break