text.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. import random
  2. from dataclasses import dataclass
  3. from itertools import chain
  4. from random import Random
  5. from typing import Optional, Union
  6. import grpc
  7. import numpy as np
  8. import pyarrow.parquet as pq
  9. import torch
  10. import torch.nn.functional as F
  11. from datasets.download.streaming_download_manager import xopen
  12. from huggingface_hub import HfApi
  13. from lightning import LightningDataModule
  14. from torch.distributed import get_rank, get_world_size, is_initialized
  15. from torch.utils.data import DataLoader, IterableDataset, get_worker_info
  16. from transformers import AutoTokenizer
  17. from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest, SampledData
  18. from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
  19. from fish_speech.datasets.protos.text_data_stream import read_pb_stream
  20. from fish_speech.text.parser import clean_text
  21. from fish_speech.text.symbols import pad as pad_symbol
  22. from fish_speech.text.symbols import pu_symbols
  23. from fish_speech.utils import RankedLogger
  24. from fish_speech.utils.braceexpand import braceexpand
# Module-level logger. NOTE(review): rank_zero_only=True presumably restricts
# output to rank 0 in distributed runs -- confirm against RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)

# Reserved ids inside every codebook row: 0 doubles as BOS/padding, 1 marks EOS.
# AutoAugTextDataset.pack_sentences shifts real semantic codes by +2 so they
# never collide with these two ids.
CODEBOOK_PAD_TOKEN_ID = 0
CODEBOOK_EOS_TOKEN_ID = 1
  28. def split_by_rank_worker(files):
  29. # We need to know the total number of devices
  30. # to split the data properly
  31. total_devices = 1
  32. if is_initialized():
  33. total_devices = get_world_size()
  34. worker_info = get_worker_info()
  35. if worker_info is not None:
  36. total_devices *= worker_info.num_workers
  37. if len(files) < total_devices:
  38. # Repeat the files N times to match the number of devices
  39. files = files * (total_devices // len(files) + 1)
  40. # DDP
  41. if is_initialized():
  42. files = files[get_rank() :: get_world_size()]
  43. # Split by worker
  44. if worker_info is not None:
  45. files = files[worker_info.id :: worker_info.num_workers]
  46. return files
  47. class StreamTextDataset(IterableDataset):
  48. def __init__(
  49. self,
  50. files: Optional[Union[list[str], str]] = None,
  51. prefix: Optional[str] = None,
  52. seed: int = 42,
  53. parquet_batch_size: int = 10000,
  54. repo: str = "uonlp/CulturaX",
  55. max_length: int = 1024,
  56. tokenizer: AutoTokenizer = None,
  57. ):
  58. super().__init__()
  59. self.seed = seed
  60. self.parquet_batch_size = parquet_batch_size
  61. self.repo = repo
  62. self.max_length = max_length
  63. self.tokenizer = tokenizer
  64. if files is None and prefix is None:
  65. raise ValueError("Either files or prefix must be specified")
  66. if prefix is not None:
  67. files = HfApi().list_repo_files(repo, repo_type="dataset")
  68. files = [
  69. f for f in files if f.startswith(prefix) and f.endswith(".parquet")
  70. ]
  71. log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
  72. else:
  73. if isinstance(files, str):
  74. files = [files]
  75. files = list(chain.from_iterable(map(braceexpand, files)))
  76. log.info(f"Expanded {len(files)} files in {repo}")
  77. # Get sharded files
  78. self.files = sorted(files)
  79. Random(seed).shuffle(self.files)
  80. def __iter__(self):
  81. files = split_by_rank_worker(self.files)
  82. random.shuffle(files)
  83. for filename in files:
  84. try:
  85. yield from self.parse_data(filename)
  86. except Exception as e:
  87. log.exception(f"Failed to parse {filename}: {e}")
  88. def parse_data(self, filename: str):
  89. for data in self.parse_data_internal(filename):
  90. text = data["text"]
  91. # 30% modeling phones
  92. if random.random() < 0.3:
  93. text = " ".join(
  94. [
  95. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  96. for i in text
  97. ]
  98. )
  99. # encode
  100. tokens = self.tokenizer.encode(
  101. text,
  102. add_special_tokens=False,
  103. truncation=False,
  104. max_length=10**6,
  105. )
  106. # Random choice self.max_length
  107. if len(tokens) > self.max_length:
  108. start = random.randint(0, len(tokens) - self.max_length)
  109. tokens = tokens[start : start + self.max_length - 1]
  110. tokens = (
  111. [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
  112. )
  113. # Pad dims
  114. placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
  115. tokens = torch.concat(
  116. [
  117. torch.tensor([tokens], dtype=torch.long),
  118. placeholder_multi_codebook,
  119. ],
  120. dim=0,
  121. )
  122. labels = tokens.clone()
  123. tokens = tokens[:, :-1]
  124. labels = labels[:, 1:]
  125. labels[1:] = -100 # remove all placeholders
  126. yield {"tokens": tokens, "labels": labels}
  127. def parse_data_internal(self, filename: str):
  128. url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
  129. with xopen(url, mode="rb") as stream:
  130. parquet_file = pq.ParquetFile(stream)
  131. for batch in parquet_file.iter_batches(
  132. batch_size=self.parquet_batch_size, columns=["text"]
  133. ):
  134. # In-batch shuffling
  135. texts = [{"text": text.as_py()} for text in batch["text"]]
  136. random.shuffle(texts)
  137. yield from texts
  138. class AutoAugTextDataset(IterableDataset):
  139. """
  140. Auto Augment Dataset by Speaker
  141. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  142. 2. Automatically normalize the text
  143. 3. Mix text and phones
  144. For interactive mode, we use the following format (multiple sequences):
  145. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  146. For non-interactive mode, we use the following format (one long sequence):
  147. <s> [INST] text [/INST] ... </s>
  148. """
  149. def __init__(
  150. self,
  151. server: str = "localhost:50051",
  152. seed: int = 42,
  153. phones_prob: float = 0.3,
  154. repetition_prob: float = 0.0,
  155. interactive_prob: float = 0.5,
  156. max_length: int = 1024,
  157. tokenizer: AutoTokenizer = None,
  158. use_speaker: bool = True,
  159. use_data_server: bool = True,
  160. proto_files: str = "data",
  161. causual: bool = True,
  162. mix_text_phone_prob: float = 0.5,
  163. use_negative_samples: bool = False,
  164. num_codebooks: Optional[int] = None,
  165. use_delay_pattern: bool = True,
  166. ):
  167. """
  168. Args:
  169. server: gRPC server address
  170. seed: random seed
  171. phones_prob: probability to use phones
  172. repetition_prob: probability to repeat the same sentence
  173. interactive_prob: probability to use interactive mode
  174. max_length: max length of the text
  175. tokenizer: tokenizer
  176. use_speaker: include speaker information in the prompt
  177. use_data_server: use data server or local data
  178. proto_files: proto buf files if using local data
  179. causual: use causual sampling when using local data, disable will lead to random sampling
  180. mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
  181. use_negative_samples: generate negative samples
  182. num_codebooks: number of codebooks, if None, it will be automatically detected
  183. use_delay_pattern: use delay pattern for codebooks
  184. """
  185. super().__init__()
  186. assert 0 <= phones_prob <= 1, "phones_prob must be in [0, 1]"
  187. assert 0 <= repetition_prob <= 1, "repetition_prob must be in [0, 1]"
  188. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  189. assert 0 <= mix_text_phone_prob <= 1, "mix_text_phone_prob must be in [0, 1]"
  190. self.seed = seed
  191. self.phones_prob = phones_prob
  192. self.max_length = max_length
  193. self.tokenizer = tokenizer
  194. self.repetition_prob = repetition_prob
  195. self.interactive_prob = interactive_prob
  196. self.use_speaker = use_speaker
  197. self.use_data_server = use_data_server
  198. self.proto_files = proto_files
  199. self.causual = causual
  200. self.mix_text_phone_prob = mix_text_phone_prob
  201. self.use_negative_samples = use_negative_samples
  202. self.num_codebooks = num_codebooks
  203. self.use_delay_pattern = use_delay_pattern
  204. if use_data_server is True:
  205. self.channel = grpc.insecure_channel(server)
  206. self.stub = DataServiceStub(self.channel)
  207. else:
  208. self.init_mock_data_server()
  209. def init_mock_data_server(self):
  210. self.groups = []
  211. count = 0
  212. for filename in self.proto_files:
  213. with open(filename, "rb") as f:
  214. for text_data in read_pb_stream(f):
  215. self.groups.append(text_data)
  216. count += 1
  217. if count % 1000 == 0:
  218. log.info(f"Read {count} groups of data")
  219. log.info(f"Read total {count} groups of data")
  220. # Shuffle the lines
  221. Random(self.seed).shuffle(self.groups)
  222. def __iter__(self):
  223. while True:
  224. yield self.augment()
  225. def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
  226. if (
  227. mode == "sample" and (random.random() < self.phones_prob)
  228. ) or mode == "phones":
  229. sentence = " ".join(
  230. [
  231. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  232. for i in phones
  233. ]
  234. )
  235. else:
  236. sentence = clean_text(sentence)
  237. tokens = self.tokenizer.encode(
  238. f"{sentence}",
  239. max_length=10**6,
  240. add_special_tokens=False,
  241. truncation=False,
  242. )
  243. return sentence, len(tokens)
  244. def sample_data(self):
  245. # Shuffle unique lines, estimate that each sample is at least 20 tokens
  246. num_samples = self.max_length // 20
  247. if self.use_data_server:
  248. request = SampleDataRequest(num_samples=num_samples)
  249. return self.stub.SampleData(request)
  250. # choice group based on their number of samples
  251. group = random.choices(
  252. self.groups, weights=[len(i.sentences) for i in self.groups], k=1
  253. )[0]
  254. if self.causual:
  255. # Sample in order
  256. if num_samples >= len(group.sentences):
  257. samples = group.sentences
  258. else:
  259. begin = random.randint(0, len(group.sentences) - num_samples)
  260. samples = group.sentences[begin : begin + num_samples]
  261. else:
  262. samples = random.choices(
  263. group.sentences, k=min(num_samples, len(group.sentences))
  264. )
  265. return SampledData(
  266. source=group.source,
  267. name=group.name,
  268. samples=samples,
  269. )
  270. def augment(self):
  271. # 50% to pure text or pure phones
  272. mode = "sample"
  273. if random.random() > self.mix_text_phone_prob:
  274. mode = random.choices(
  275. ["text", "phones"],
  276. weights=[1 - self.phones_prob, self.phones_prob],
  277. k=1,
  278. )[0]
  279. # Random sample based on speaker using a truncated normal distribution
  280. a = torch.tensor([0], dtype=torch.float32)
  281. torch.nn.init.trunc_normal_(
  282. a,
  283. mean=self.max_length // 2,
  284. std=self.max_length // 4,
  285. a=10,
  286. b=self.max_length,
  287. )
  288. remaining_tokens = a.long().item() - 4
  289. final_text, final_semantic = [], []
  290. response = self.sample_data()
  291. if len(response.samples) == 0:
  292. # Invalid group
  293. return None
  294. samples = list(response.samples)
  295. idx = 0
  296. use_interactive = random.random() < self.interactive_prob
  297. all_tokens, all_labels = [], []
  298. while remaining_tokens > 0 and len(samples) > 0:
  299. if random.random() < self.repetition_prob:
  300. # Repeat the same sentence
  301. sentence = samples[-1]
  302. else:
  303. sentence = samples.pop()
  304. text, length = self.tokenize_sentence(
  305. sentence.text, sentence.phones, mode=mode
  306. )
  307. remaining_tokens -= length + len(sentence.semantics[0].values)
  308. if use_interactive is False:
  309. final_text.append(text)
  310. final_semantic.append(sentence.semantics)
  311. else:
  312. # For interactive mode, we only apply speaker for the first sentence
  313. # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
  314. tokens, labels = self.pack_sentences(
  315. sentences=[text],
  316. semantics=[sentence.semantics],
  317. speaker=response.name if (self.use_speaker and idx == 0) else None,
  318. add_bos=idx == 0,
  319. )
  320. all_tokens.append(tokens)
  321. all_labels.append(labels)
  322. idx += 1
  323. if use_interactive is False:
  324. tokens, labels = self.pack_sentences(
  325. final_text,
  326. semantics=final_semantic,
  327. speaker=None if self.use_speaker else response.name,
  328. add_bos=True,
  329. )
  330. all_tokens.append(tokens)
  331. all_labels.append(labels)
  332. tokens = torch.cat(all_tokens, dim=1)
  333. labels = torch.cat(all_labels, dim=1)
  334. # Verify that the length is correct
  335. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  336. # Verify only one <s> token
  337. assert (tokens[:, 0] == self.tokenizer.bos_token_id).sum() == 1
  338. data = {"tokens": tokens, "labels": labels}
  339. if self.use_negative_samples:
  340. negative_samples = self.generate_negative_samples(all_tokens, all_labels)
  341. data.update(negative_samples)
  342. return data
  343. def generate_negative_samples(self, all_tokens, all_labels):
  344. new_tokens, new_labels = [], []
  345. for tokens, labels in zip(all_tokens, all_labels):
  346. # If all codebooks are not -100, we find where it starts
  347. start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
  348. assert (labels[1:, start:] != -100).all() # This shouldn't happen
  349. mode = random.choice(["repeat", "lost", "noise"])
  350. begin = random.randint(start, labels.size(1) - 1)
  351. end = random.randint(begin, labels.size(1) - 1)
  352. if mode == "repeat":
  353. tokens = torch.cat(
  354. [
  355. tokens[:, :begin],
  356. tokens[:, begin:end],
  357. tokens[:, begin:end],
  358. tokens[:, end:],
  359. ],
  360. dim=1,
  361. )
  362. labels = torch.cat(
  363. [
  364. labels[:, :begin],
  365. labels[:, begin:end],
  366. labels[:, begin:end],
  367. labels[:, end:],
  368. ],
  369. dim=1,
  370. )
  371. elif mode == "lost":
  372. tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
  373. labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
  374. elif mode == "noise":
  375. middle_tokens, middle_labels = (
  376. tokens[:, begin:end],
  377. labels[:, begin:end],
  378. )
  379. random_order0 = torch.randperm(middle_tokens.size(1))
  380. random_order1 = torch.randperm(middle_tokens.size(1))
  381. middle_tokens = middle_tokens[:, random_order0]
  382. middle_labels = middle_labels[:, random_order1]
  383. tokens = torch.cat(
  384. [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
  385. )
  386. labels = torch.cat(
  387. [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
  388. )
  389. new_tokens.append(tokens)
  390. new_labels.append(labels)
  391. tokens = torch.cat(new_tokens, dim=1)
  392. labels = torch.cat(new_labels, dim=1)
  393. # Verify that the length is correct
  394. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  395. return {"negative_tokens": tokens, "negative_labels": labels}
  396. def pack_sentences(
  397. self,
  398. sentences: list[str],
  399. semantics=list,
  400. speaker: Optional[str] = None,
  401. add_bos: bool = True,
  402. ):
  403. if speaker is not None:
  404. sentences = [f"[SPK: {speaker}]"] + sentences
  405. final_text = "[INST] " + " ".join(sentences) + " [/INST]"
  406. encoded = self.tokenizer.encode(
  407. final_text,
  408. add_special_tokens=False,
  409. truncation=False,
  410. max_length=10**6,
  411. )
  412. semantic_length = sum([len(i[0].values) for i in semantics])
  413. prompt_length = len(encoded)
  414. num_codebooks = (
  415. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  416. )
  417. bos_bias = 1 if add_bos else 0
  418. # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
  419. pad_token_length = semantic_length + (
  420. num_codebooks - 1 if self.use_delay_pattern else 0
  421. )
  422. tokens = (
  423. encoded
  424. + [self.tokenizer.pad_token_id] * pad_token_length
  425. + [self.tokenizer.eos_token_id]
  426. )
  427. if add_bos:
  428. tokens = [self.tokenizer.bos_token_id] + tokens
  429. # Codebook bos/padding: 0, eos: 1
  430. # Implement delay pattern
  431. codes = [
  432. [CODEBOOK_PAD_TOKEN_ID]
  433. * (prompt_length + bos_bias + (i if self.use_delay_pattern else 0))
  434. for i in range(num_codebooks)
  435. ]
  436. for segment in semantics:
  437. for book_idx, book in zip(range(num_codebooks), segment):
  438. for j in book.values:
  439. codes[book_idx].append(int(j) + 2)
  440. for idx, book in enumerate(codes):
  441. if self.use_delay_pattern:
  442. book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
  443. book.append(CODEBOOK_EOS_TOKEN_ID)
  444. tokens = [tokens] + codes
  445. tokens = torch.tensor(tokens, dtype=torch.long)
  446. labels = tokens.clone()
  447. # Mask out the <s> tokens for semantic, predict semantic tokens only
  448. # Since we don't mask out the input tokens, the language modeling still works
  449. labels[1:, : (prompt_length + bos_bias)] = -100
  450. tokens = tokens[:, :-1]
  451. labels = labels[:, 1:]
  452. # Verify the padding is correct, and the last token is eos
  453. assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
  454. assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
  455. assert labels[0, -1] == self.tokenizer.eos_token_id
  456. assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()
  457. return tokens, labels
  458. @dataclass
  459. class TextDataCollator:
  460. tokenizer: AutoTokenizer
  461. max_length: int = 1024
  462. def __call__(self, examples):
  463. if "negative_tokens" in examples:
  464. positive_examples = []
  465. negative_examples = []
  466. for i in examples:
  467. positive_examples.append(
  468. {
  469. "tokens": i["tokens"],
  470. "labels": i["labels"],
  471. }
  472. )
  473. negative_examples.append(
  474. {
  475. "tokens": i["negative_tokens"],
  476. "labels": i["negative_labels"],
  477. }
  478. )
  479. examples = positive_examples + negative_examples
  480. return self.batchify(examples)
  481. def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
  482. tokens, attention_masks, labels = [], [], []
  483. for example in examples:
  484. _tokens = example[tokens_key][:, : self.max_length]
  485. _labels = example[labels_key][:, : self.max_length]
  486. _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
  487. tokens_length = _tokens.size(1)
  488. _attention_mask[:tokens_length] = False
  489. assert tokens_length == _labels.size(
  490. 1
  491. ), f"{tokens_length} != {_labels.size(1)}"
  492. if tokens_length < self.max_length:
  493. _tokens = F.pad(
  494. _tokens,
  495. (0, self.max_length - tokens_length),
  496. value=self.tokenizer.eos_token_id,
  497. )
  498. _tokens[1:, tokens_length:] = CODEBOOK_EOS_TOKEN_ID
  499. _labels = F.pad(
  500. _labels, (0, self.max_length - _labels.size(1)), value=-100
  501. )
  502. tokens.append(_tokens)
  503. attention_masks.append(_attention_mask)
  504. labels.append(_labels)
  505. tokens = torch.stack(tokens, dim=0)
  506. attention_masks = torch.stack(attention_masks, dim=0)
  507. labels = torch.stack(labels, dim=0)
  508. return {
  509. "inputs": tokens,
  510. "attention_masks": attention_masks,
  511. "labels": labels,
  512. }
  513. class InterleaveDataset(IterableDataset):
  514. def __init__(
  515. self,
  516. datasets: list[IterableDataset],
  517. probabilities: list[float],
  518. seed: int = 42,
  519. ):
  520. super().__init__()
  521. self.datasets = datasets
  522. self.probabilities = probabilities
  523. self.seed = seed
  524. def __iter__(self):
  525. rng = np.random.default_rng(self.seed)
  526. dataset_iterators = [iter(dataset) for dataset in self.datasets]
  527. while True:
  528. # Random choice one
  529. dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
  530. dataset_iterator = dataset_iterators[dataset_idx]
  531. try:
  532. yield next(dataset_iterator)
  533. except StopIteration:
  534. # Exhausted, create a new iterator
  535. dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
  536. yield next(dataset_iterators[dataset_idx])
  537. class TextDataModule(LightningDataModule):
  538. def __init__(
  539. self,
  540. train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  541. val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
  542. batch_size: int = 32,
  543. tokenizer: AutoTokenizer = None,
  544. max_length: int = 1024,
  545. num_workers: int = 4,
  546. ):
  547. super().__init__()
  548. self.train_dataset = train_dataset
  549. self.val_dataset = val_dataset
  550. self.batch_size = batch_size
  551. self.tokenizer = tokenizer
  552. self.max_length = max_length
  553. self.num_workers = num_workers
  554. def train_dataloader(self):
  555. return DataLoader(
  556. self.train_dataset,
  557. batch_size=self.batch_size,
  558. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  559. num_workers=self.num_workers,
  560. )
  561. def val_dataloader(self):
  562. return DataLoader(
  563. self.val_dataset,
  564. batch_size=self.batch_size,
  565. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  566. num_workers=self.num_workers,
  567. )
  568. if __name__ == "__main__":
  569. from tqdm import tqdm
  570. ds = AutoAugTextDataset(
  571. tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
  572. use_speaker=True,
  573. interactive_prob=1.0,
  574. phones_prob=1.0,
  575. use_negative_samples=False,
  576. num_codebooks=4,
  577. )
  578. # ds = AutoAugTextDataset(
  579. # tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
  580. # use_speaker=True,
  581. # interactive_prob=1.0,
  582. # use_data_server=False,
  583. # proto_files=["data/wenet-speech.protos"],
  584. # )
  585. dm = TextDataModule(
  586. train_dataset=ds,
  587. val_dataset=ds,
  588. tokenizer=ds.tokenizer,
  589. batch_size=2,
  590. max_length=1024,
  591. num_workers=0,
  592. )
  593. for batch in tqdm(dm.train_dataloader()):
  594. print(batch)
  595. break