import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import grpc
import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest, SampledData
from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.parser import clean_text
from fish_speech.text.symbols import pad as pad_symbol
from fish_speech.text.symbols import pu_symbols
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)

CODEBOOK_PAD_TOKEN_ID = 0
CODEBOOK_EOS_TOKEN_ID = 1


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP: shard across ranks first
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Then split by dataloader worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
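
# Example of the resulting sharding (hypothetical numbers, not from the
# source): with files = [f0, ..., f7], world_size = 2, and 2 dataloader
# workers, rank 0 keeps files[0::2] = [f0, f2, f4, f6]; within that rank,
# worker 0 then reads [f0, f4] and worker 1 reads [f2, f6], so every
# (rank, worker) pair sees a disjoint slice of the file list.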


class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo
        self.max_length = max_length
        self.tokenizer = tokenizer

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            # 30% chance of modeling phones instead of raw text
            if random.random() < 0.3:
                text = " ".join(
                    [
                        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
                        for i in text
                    ]
                )

            # Encode
            tokens = self.tokenizer.encode(
                text,
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Randomly crop a window of at most self.max_length tokens
            if len(tokens) > self.max_length:
                start = random.randint(0, len(tokens) - self.max_length)
                tokens = tokens[start : start + self.max_length - 1]

            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )

            # Pad dims
            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
            tokens = torch.concat(
                [
                    torch.tensor([tokens], dtype=torch.long),
                    placeholder_multi_codebook,
                ],
                dim=0,
            )

            labels = tokens.clone()
            tokens = tokens[:, :-1]
            labels = labels[:, 1:]
            labels[1:] = -100  # remove all placeholders

            yield {"tokens": tokens, "labels": labels}
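
    # Shape note, derived from the code above: each yielded "tokens"/"labels"
    # tensor is (5, T - 1), one text row plus four placeholder codebook rows,
    # and the codebook label rows are set to -100 so only the text row
    # contributes to the loss.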

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts


class AutoAugTextDataset(IterableDataset):
    """
    Auto-augmented text dataset, grouped by speaker.

    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones

    For interactive mode, we use the following format (multiple sequences):
    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>

    For non-interactive mode, we use the following format (one long sequence):
    <s> [INST] text [/INST] ... </s>
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        phones_prob: float = 0.3,
        repetition_prob: float = 0.0,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool = True,
        causual: bool = True,
        mix_text_phone_prob: float = 0.5,
        use_negative_samples: bool = False,
        num_codebooks: Optional[int] = None,
    ):
        """
        Args:
            proto_files: protobuf files if using local data
            seed: random seed
            phones_prob: probability of using phones
            repetition_prob: probability of repeating the same sentence
            interactive_prob: probability of using interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt
            causual: use causal (in-order) sampling when using local data; disabling it falls back to random sampling
            mix_text_phone_prob: probability of mixing text and phones; if 0, each sample is pure text or pure phones
            use_negative_samples: generate negative samples
            num_codebooks: number of codebooks; if None, it is detected automatically
        """

        super().__init__()

        assert 0 <= phones_prob <= 1, "phones_prob must be in [0, 1]"
        assert 0 <= repetition_prob <= 1, "repetition_prob must be in [0, 1]"
        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
        assert 0 <= mix_text_phone_prob <= 1, "mix_text_phone_prob must be in [0, 1]"

        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.repetition_prob = repetition_prob
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causual = causual
        self.mix_text_phone_prob = mix_text_phone_prob
        self.use_negative_samples = use_negative_samples
        self.num_codebooks = num_codebooks
        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<s:0>")
        self.groups = None

    def init_mock_data_server(self):
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(self.groups)

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
        if (
            mode == "sample" and (random.random() < self.phones_prob)
        ) or mode == "phones":
            sentence = " ".join(
                [
                    (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
                    for i in phones
                ]
            )
        else:
            sentence = clean_text(sentence)
            sentence = " ".join([f"<s:{i:d}>" for i in sentence.encode("utf-8")])

        tokens = self.tokenizer.encode(
            sentence,
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        if self.groups is None:
            self.init_mock_data_server()

        # Estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # Choose a group, weighted by its number of sentences
        group = random.choices(
            self.groups, weights=[len(i.sentences) for i in self.groups], k=1
        )[0]

        if self.causual:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        # With probability (1 - mix_text_phone_prob), force pure text or pure phones
        mode = "sample"
        if random.random() > self.mix_text_phone_prob:
            mode = random.choices(
                ["text", "phones"],
                weights=[1 - self.phones_prob, self.phones_prob],
                k=1,
            )[0]

        # Draw a target length from a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob
        all_tokens, all_labels = [], []

        while remaining_tokens > 0 and len(samples) > 0:
            if random.random() < self.repetition_prob:
                # Repeat the same sentence
                sentence = samples[-1]
            else:
                sentence = samples.pop()

            text, length = self.tokenize_sentence(
                sentence.text, sentence.phones, mode=mode
            )
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if not use_interactive:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if (self.use_speaker and idx == 0) else None,
                    add_bos=idx == 0,
                )
                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if not use_interactive:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if self.use_speaker else None,
                add_bos=True,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        # Verify only one <s> token
        assert (tokens[:, 0] == self.tokenizer.bos_token_id).sum() == 1

        data = {"tokens": tokens, "labels": labels}

        if self.use_negative_samples:
            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
            data.update(negative_samples)

        return data

    def generate_negative_samples(self, all_tokens, all_labels):
        new_tokens, new_labels = [], []

        for tokens, labels in zip(all_tokens, all_labels):
            # Find the first position where the codebook label rows are not
            # all -100, i.e. where the semantic tokens begin
            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
            assert (labels[1:, start:] != -100).all()  # no masked labels after start

            mode = random.choice(["repeat", "lost", "noise"])
            begin = random.randint(start, labels.size(1) - 1)
            end = random.randint(begin, labels.size(1) - 1)

            if mode == "repeat":
                tokens = torch.cat(
                    [
                        tokens[:, :begin],
                        tokens[:, begin:end],
                        tokens[:, begin:end],
                        tokens[:, end:],
                    ],
                    dim=1,
                )
                labels = torch.cat(
                    [
                        labels[:, :begin],
                        labels[:, begin:end],
                        labels[:, begin:end],
                        labels[:, end:],
                    ],
                    dim=1,
                )
            elif mode == "lost":
                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
            elif mode == "noise":
                middle_tokens, middle_labels = (
                    tokens[:, begin:end],
                    labels[:, begin:end],
                )
                random_order0 = torch.randperm(middle_tokens.size(1))
                random_order1 = torch.randperm(middle_tokens.size(1))
                middle_tokens = middle_tokens[:, random_order0]
                middle_labels = middle_labels[:, random_order1]
                tokens = torch.cat(
                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
                )
                labels = torch.cat(
                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
                )

            new_tokens.append(tokens)
            new_labels.append(labels)

        tokens = torch.cat(new_tokens, dim=1)
        labels = torch.cat(new_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        return {"negative_tokens": tokens, "negative_labels": labels}

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        add_bos: bool = True,
    ):
        if speaker is not None:
            sentences = [f"[SPK: {speaker}]"] + sentences

        final_text = "[INST] " + " ".join(sentences) + " [/INST]"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )
        bos_bias = 1 if add_bos else 0

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + [self.tokenizer.eos_token_id]
        )

        if add_bos:
            tokens = [self.tokenizer.bos_token_id] + tokens

        # Codebook bos/padding: 0, eos: 1
        codes = [
            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
            for _ in range(num_codebooks)
        ]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 2)

        for book in codes:
            book.append(CODEBOOK_EOS_TOKEN_ID)

        tokens = [tokens] + codes
        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        # Mask out the <s> tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, : (prompt_length + bos_bias)] = -100

        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
        assert labels[0, -1] == self.tokenizer.eos_token_id
        assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()

        return tokens, labels
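
    # Layout of one packed sample before the final one-step shift, as an
    # illustration derived from the code above (not literal output):
    #
    #   row 0 (text):      <s> [INST] ... [/INST] <s:0> ... <s:0> </s>
    #   rows 1..N (codes):  0    0    ...    0     c+2  ...  c+2   1
    #
    # Raw codebook values c are offset by +2 so that 0 stays the pad id and
    # 1 the codebook EOS id; codebook labels over the prompt span are -100.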


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        if "negative_tokens" in examples[0]:
            positive_examples = []
            negative_examples = []

            for i in examples:
                positive_examples.append(
                    {
                        "tokens": i["tokens"],
                        "labels": i["labels"],
                    }
                )
                negative_examples.append(
                    {
                        "tokens": i["negative_tokens"],
                        "labels": i["negative_labels"],
                    }
                )

            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            # True marks padded positions, False marks real tokens
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
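
# Shape of a collated batch (illustrative, assuming num_codebooks = 4 and a
# batch of B examples): "inputs" and "labels" are (B, 5, T), i.e. one text row
# plus four codebook rows, and "attention_masks" is (B, T) with True at padded
# positions (note the inverted convention set in batchify above).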


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly choose one dataset
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
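
# Minimal usage sketch (hypothetical datasets and weights, not from the
# source):
#
#   interleaved = InterleaveDataset(
#       datasets=[text_ds, speech_ds],  # any IterableDatasets
#       probabilities=[0.3, 0.7],       # must sum to 1 for rng.choice
#   )
#
# Each draw picks a source with the given probability; exhausted sources are
# restarted, so iteration never ends on its own.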


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    from tqdm import tqdm

    ds = AutoAugTextDataset(
        # proto_files is required; this is a placeholder path, point it at
        # your own local protos
        proto_files=["data/protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
        use_speaker=True,
        interactive_prob=1.0,
        phones_prob=1.0,
        use_negative_samples=False,
        num_codebooks=4,
    )

    dm = TextDataModule(
        train_dataset=ds,
        val_dataset=ds,
        tokenizer=ds.tokenizer,
        batch_size=2,
        max_length=1024,
        num_workers=0,
    )

    for batch in tqdm(dm.train_dataloader()):
        print(batch)
        break