semantic.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. import random
  2. from dataclasses import dataclass
  3. from itertools import chain
  4. from pathlib import Path
  5. from random import Random
  6. from typing import Optional, Union
  7. import numpy as np
  8. import pyarrow.parquet as pq
  9. import torch
  10. import torch.nn.functional as F
  11. from datasets.download.streaming_download_manager import xopen
  12. from huggingface_hub import HfApi
  13. from lightning import LightningDataModule
  14. from torch.distributed import get_rank, get_world_size, is_initialized
  15. from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
  16. from fish_speech.conversation import (
  17. CODEBOOK_PAD_TOKEN_ID,
  18. Conversation,
  19. Message,
  20. TextPart,
  21. VQPart,
  22. )
  23. from fish_speech.datasets.protos.text_data_pb2 import SampledData
  24. from fish_speech.datasets.protos.text_data_stream import read_pb_stream
  25. from fish_speech.text.clean import clean_text
  26. from fish_speech.tokenizer import FishTokenizer
  27. from fish_speech.utils import RankedLogger
  28. from fish_speech.utils.braceexpand import braceexpand
# Module-level logger shared by every class in this file. rank_zero_only=True is
# passed to RankedLogger, presumably so that only global rank 0 emits messages under
# DDP — TODO confirm against fish_speech.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
  30. def split_by_rank_worker(files):
  31. # We need to know the total number of devices
  32. # to split the data properly
  33. total_devices = 1
  34. if is_initialized():
  35. total_devices = get_world_size()
  36. worker_info = get_worker_info()
  37. if worker_info is not None:
  38. total_devices *= worker_info.num_workers
  39. if len(files) < total_devices:
  40. # Repeat the files N times to match the number of devices
  41. files = files * (total_devices // len(files) + 1)
  42. # DDP
  43. if is_initialized():
  44. files = files[get_rank() :: get_world_size()]
  45. # Split by worker
  46. if worker_info is not None:
  47. files = files[worker_info.id :: worker_info.num_workers]
  48. return files
  49. class AutoTextSemanticInstructionIterableDataset(IterableDataset):
  50. """
  51. Auto Augment Dataset by Speaker
  52. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  53. 2. Automatically normalize the text
  54. For interactive mode, we use the following format (multiple sequences):
  55. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  56. For non-interactive mode, we use the following format (one long sequence):
  57. <s> [INST] text [/INST] ... </s>
  58. """
  59. def __init__(
  60. self,
  61. proto_files: list[str],
  62. seed: int = 42,
  63. interactive_prob: float = 0.5,
  64. max_length: int = 1024,
  65. tokenizer: FishTokenizer = None,
  66. use_speaker: bool | float = True,
  67. causal: bool = True,
  68. num_codebooks: Optional[int] = None,
  69. skip_text_prob: float = 0.0,
  70. ):
  71. """
  72. Args:
  73. proto_files: proto buf files if using local data
  74. seed: random seed
  75. interactive_prob: probability to use interactive mode
  76. max_length: max length of the text
  77. tokenizer: tokenizer
  78. use_speaker: include speaker information in the prompt
  79. causal: use causal sampling when using local data, disable will lead to random sampling
  80. num_codebooks: number of codebooks, if None, it will be automatically detected
  81. skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
  82. """
  83. super().__init__()
  84. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  85. self.seed = seed
  86. self.max_length = max_length
  87. self.tokenizer = tokenizer
  88. self.interactive_prob = interactive_prob
  89. self.use_speaker = use_speaker
  90. self.proto_files = proto_files
  91. self.causal = causal
  92. self.num_codebooks = num_codebooks
  93. self.skip_text_prob = skip_text_prob
  94. self.groups = None
  95. def __iter__(self):
  96. while True:
  97. yield self.augment()
  98. def init_mock_data_server(self):
  99. if self.groups is not None:
  100. return
  101. # Expand the proto files
  102. expanded_proto_files = []
  103. for filename in self.proto_files:
  104. for i in braceexpand(filename):
  105. i = Path(i)
  106. if i.is_file():
  107. expanded_proto_files.append(i)
  108. elif i.is_dir():
  109. expanded_proto_files.extend(i.rglob("*.proto"))
  110. expanded_proto_files.extend(i.rglob("*.protos"))
  111. else:
  112. raise ValueError(f"{i} is not a file or directory")
  113. expanded_proto_files = sorted(expanded_proto_files)
  114. Random(self.seed).shuffle(expanded_proto_files)
  115. self.groups = []
  116. shard_proto_files = split_by_rank_worker(expanded_proto_files)
  117. log.info(
  118. f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
  119. )
  120. count = 0
  121. for filename in shard_proto_files:
  122. with open(filename, "rb") as f:
  123. for text_data in read_pb_stream(f):
  124. self.groups.append(text_data)
  125. count += 1
  126. log.info(f"Read total {count} groups of data")
  127. # Shuffle the lines
  128. Random(self.seed).shuffle(self.groups)
  129. self.group_weights = [len(i.sentences) for i in self.groups]
  130. def sample_data(self):
  131. if self.groups is None:
  132. self.init_mock_data_server()
  133. # Shuffle unique lines, estimate that each sample is at least 20 tokens
  134. num_samples = self.max_length // 20
  135. # choice group based on their number of samples
  136. group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
  137. if self.causal:
  138. # Sample in order
  139. if num_samples >= len(group.sentences):
  140. samples = group.sentences
  141. else:
  142. begin = random.randint(0, len(group.sentences) - num_samples)
  143. samples = group.sentences[begin : begin + num_samples]
  144. else:
  145. samples = random.choices(
  146. group.sentences, k=min(num_samples, len(group.sentences))
  147. )
  148. return SampledData(
  149. source=group.source,
  150. name=group.name,
  151. samples=samples,
  152. )
  153. def pack_sentences(
  154. self,
  155. sentences: list[str],
  156. semantics: list,
  157. # speaker: Optional[str] = None,
  158. skip_text: bool = False,
  159. ):
  160. # if speaker is None:
  161. # speaker = "assistant"
  162. messages = [
  163. Message(
  164. role="system",
  165. parts=[TextPart(text="Speak out the provided text.")],
  166. # add_im_end=False,
  167. # cal_loss=True,
  168. )
  169. ]
  170. cated_sentences = " ".join(sentences)
  171. if skip_text:
  172. cated_sentences = "<|skip_text|>"
  173. messages.append(
  174. Message(
  175. role="user",
  176. parts=[TextPart(text=cated_sentences)],
  177. # cal_loss=True,
  178. )
  179. )
  180. vq_codes = [x.values for x in semantics[0]]
  181. vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
  182. vqpart = VQPart(codes=vq_codes_tensor)
  183. messages.append(
  184. Message(
  185. role="assistant",
  186. parts=[TextPart(text="<|voice|>"), vqpart],
  187. cal_loss=True,
  188. )
  189. )
  190. num_codebooks = (
  191. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  192. )
  193. conversation = Conversation(messages=messages)
  194. # conversation.visualize(tokenizer=self.tokenizer)
  195. encoded = conversation.encode(
  196. tokenizer=self.tokenizer,
  197. )
  198. tokens_raw = encoded.tokens
  199. tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
  200. tokens[0] = tokens_raw
  201. vq_parts = encoded.vq_parts
  202. vq_parts = [part.to(tokens.device) for part in vq_parts]
  203. vq_parts = torch.cat(vq_parts, dim=1)
  204. tokens[1:, encoded.vq_mask_tokens] = vq_parts
  205. labels_raw = encoded.labels
  206. labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
  207. labels[0, :] = labels_raw
  208. labels[1:, encoded.vq_mask_labels] = vq_parts
  209. labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
  210. tokens = tokens.long()
  211. labels = labels.long()
  212. # Verify the padding is correct, and the last token is eos
  213. assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
  214. assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
  215. return tokens, labels
  216. def augment(self):
  217. response = self.sample_data()
  218. if len(response.samples) == 0:
  219. # Invalid group
  220. return None
  221. samples = list(response.samples)
  222. all_tokens, all_labels = [], []
  223. while len(samples) > 0:
  224. sentence = samples.pop(0)
  225. text = clean_text(random.choice(sentence.texts))
  226. tokens, labels = self.pack_sentences(
  227. sentences=[text],
  228. semantics=[sentence.semantics],
  229. # speaker=response.name if use_speaker else None,
  230. skip_text=random.random() < self.skip_text_prob,
  231. )
  232. all_tokens.append(tokens)
  233. all_labels.append(labels)
  234. tokens = torch.cat(all_tokens, dim=1)
  235. labels = torch.cat(all_labels, dim=1)
  236. # Verify that the length is correct
  237. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  238. data = {"tokens": tokens, "labels": labels}
  239. return data
  240. class AutoTextSemanticInstructionDataset(Dataset):
  241. """
  242. Auto Augment Dataset by Speaker
  243. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  244. 2. Automatically normalize the text
  245. For interactive mode, we use the following format (multiple sequences):
  246. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  247. For non-interactive mode, we use the following format (one long sequence):
  248. <s> [INST] text [/INST] ... </s>
  249. """
  250. def __init__(
  251. self,
  252. proto_files: list[str],
  253. seed: int = 42,
  254. interactive_prob: float = 0.5,
  255. max_length: int = 1024,
  256. tokenizer: FishTokenizer = None,
  257. use_speaker: bool | float = True,
  258. causal: bool = True,
  259. num_codebooks: Optional[int] = None,
  260. skip_text_prob: float = 0.0,
  261. ):
  262. """
  263. Args:
  264. proto_files: proto buf files if using local data
  265. seed: random seed
  266. interactive_prob: probability to use interactive mode
  267. max_length: max length of the text
  268. tokenizer: tokenizer
  269. use_speaker: include speaker information in the prompt
  270. causal: use causal sampling when using local data, disable will lead to random sampling
  271. num_codebooks: number of codebooks, if None, it will be automatically detected
  272. skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
  273. """
  274. super().__init__()
  275. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  276. self.seed = seed
  277. self.max_length = max_length
  278. self.tokenizer = tokenizer
  279. self.interactive_prob = interactive_prob
  280. self.use_speaker = use_speaker
  281. self.proto_files = proto_files
  282. self.causal = causal
  283. self.num_codebooks = num_codebooks
  284. self.skip_text_prob = skip_text_prob
  285. self.data = []
  286. self._init_data()
  287. def _init_data(self):
  288. expanded_proto_files = []
  289. for filename in self.proto_files:
  290. for i in braceexpand(filename):
  291. i = Path(i)
  292. if i.is_file():
  293. expanded_proto_files.append(i)
  294. elif i.is_dir():
  295. expanded_proto_files.extend(i.rglob("*.proto"))
  296. expanded_proto_files.extend(i.rglob("*.protos"))
  297. else:
  298. raise ValueError(f"{i} is not a file or directory")
  299. expanded_proto_files = sorted(expanded_proto_files)
  300. Random(self.seed).shuffle(expanded_proto_files)
  301. groups = []
  302. shard_proto_files = split_by_rank_worker(expanded_proto_files)
  303. log.info(
  304. f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
  305. )
  306. count = 0
  307. for filename in shard_proto_files:
  308. with open(filename, "rb") as f:
  309. for text_data in read_pb_stream(f):
  310. groups.append(text_data)
  311. count += 1
  312. log.info(f"Read total {count} groups of data")
  313. for group in groups:
  314. if len(group.sentences) == 0:
  315. continue
  316. samples = list(group.sentences)
  317. for sentence in samples:
  318. text = clean_text(random.choice(sentence.texts))
  319. tokens, labels = self.pack_sentences(
  320. sentences=[text],
  321. semantics=[sentence.semantics],
  322. skip_text=random.random() < self.skip_text_prob,
  323. )
  324. self.data.append({"tokens": tokens, "labels": labels})
  325. random.Random(self.seed).shuffle(self.data)
  326. def __len__(self):
  327. return len(self.data)
  328. def __getitem__(self, idx):
  329. return self.data[idx]
  330. def pack_sentences(
  331. self,
  332. sentences: list[str],
  333. semantics: list,
  334. skip_text: bool = False,
  335. ):
  336. messages = [
  337. Message(
  338. role="system",
  339. parts=[TextPart(text="Speak out the provided text.")],
  340. )
  341. ]
  342. cated_sentences = " ".join(sentences)
  343. if skip_text:
  344. cated_sentences = "<|skip_text|>"
  345. messages.append(
  346. Message(
  347. role="user",
  348. parts=[TextPart(text=cated_sentences)],
  349. )
  350. )
  351. vq_codes = [x.values for x in semantics[0]]
  352. vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
  353. vqpart = VQPart(codes=vq_codes_tensor)
  354. messages.append(
  355. Message(
  356. role="assistant",
  357. parts=[TextPart(text="<|voice|>"), vqpart],
  358. cal_loss=True,
  359. )
  360. )
  361. num_codebooks = (
  362. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  363. )
  364. conversation = Conversation(messages=messages)
  365. encoded = conversation.encode(
  366. tokenizer=self.tokenizer,
  367. )
  368. tokens_raw = encoded.tokens
  369. tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
  370. tokens[0] = tokens_raw
  371. vq_parts = encoded.vq_parts
  372. vq_parts = [part.to(tokens.device) for part in vq_parts]
  373. vq_parts = torch.cat(vq_parts, dim=1)
  374. tokens[1:, encoded.vq_mask_tokens] = vq_parts
  375. labels_raw = encoded.labels
  376. labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
  377. labels[0, :] = labels_raw
  378. labels[1:, encoded.vq_mask_labels] = vq_parts
  379. labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
  380. tokens = tokens.long()
  381. labels = labels.long()
  382. assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
  383. assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
  384. return tokens, labels
  385. class InterleaveDataset(IterableDataset):
  386. def __init__(
  387. self,
  388. datasets: list[IterableDataset],
  389. probabilities: list[float],
  390. seed: int = 42,
  391. ):
  392. super().__init__()
  393. self.datasets = datasets
  394. self.probabilities = probabilities
  395. self.seed = seed
  396. def __iter__(self):
  397. rng = np.random.default_rng(self.seed)
  398. dataset_iterators = [iter(dataset) for dataset in self.datasets]
  399. while True:
  400. # Random choice one
  401. dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
  402. dataset_iterator = dataset_iterators[dataset_idx]
  403. try:
  404. yield next(dataset_iterator)
  405. except StopIteration:
  406. # Exhausted, create a new iterator
  407. dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
  408. yield next(dataset_iterators[dataset_idx])
  409. @dataclass
  410. class TextDataCollator:
  411. tokenizer: FishTokenizer
  412. max_length: int = 1024
  413. def __call__(self, examples):
  414. if "negative_tokens" in examples:
  415. positive_examples = []
  416. negative_examples = []
  417. for i in examples:
  418. positive_examples.append(
  419. {
  420. "tokens": i["tokens"],
  421. "labels": i["labels"],
  422. }
  423. )
  424. negative_examples.append(
  425. {
  426. "tokens": i["negative_tokens"],
  427. "labels": i["negative_labels"],
  428. }
  429. )
  430. examples = positive_examples + negative_examples
  431. return self.batchify(examples)
  432. def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
  433. tokens, attention_masks, labels = [], [], []
  434. # Calculate the max length
  435. max_tokens_length = 0
  436. for example in examples:
  437. max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
  438. max_tokens_length = min(max_tokens_length, self.max_length)
  439. for example in examples:
  440. _tokens = example[tokens_key][:, :max_tokens_length]
  441. _labels = example[labels_key][:, :max_tokens_length]
  442. _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
  443. tokens_length = _tokens.size(1)
  444. _attention_mask[:tokens_length] = False
  445. assert tokens_length == _labels.size(
  446. 1
  447. ), f"{tokens_length} != {_labels.size(1)}"
  448. if tokens_length < max_tokens_length:
  449. _tokens = F.pad(
  450. _tokens,
  451. (0, max_tokens_length - tokens_length),
  452. value=self.tokenizer.get_token_id("<|end_of_text|>"),
  453. )
  454. _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
  455. _labels = F.pad(
  456. _labels, (0, max_tokens_length - _labels.size(1)), value=-100
  457. )
  458. tokens.append(_tokens)
  459. attention_masks.append(_attention_mask)
  460. labels.append(_labels)
  461. tokens = torch.stack(tokens, dim=0)
  462. attention_masks = torch.stack(attention_masks, dim=0)
  463. labels = torch.stack(labels, dim=0)
  464. return {
  465. "inputs": tokens,
  466. "attention_masks": attention_masks,
  467. "labels": labels,
  468. }
  469. class SemanticDataModule(LightningDataModule):
  470. def __init__(
  471. self,
  472. train_dataset: Union[
  473. AutoTextSemanticInstructionDataset,
  474. AutoTextSemanticInstructionIterableDataset,
  475. InterleaveDataset,
  476. ],
  477. val_dataset: Union[
  478. AutoTextSemanticInstructionDataset,
  479. AutoTextSemanticInstructionIterableDataset,
  480. InterleaveDataset,
  481. ],
  482. batch_size: int = 32,
  483. tokenizer: FishTokenizer = None,
  484. max_length: int = 1024,
  485. num_workers: int = 4,
  486. ):
  487. super().__init__()
  488. self.train_dataset = train_dataset
  489. self.val_dataset = val_dataset
  490. self.batch_size = batch_size
  491. self.tokenizer = tokenizer
  492. self.max_length = max_length
  493. self.num_workers = num_workers
  494. def train_dataloader(self):
  495. return DataLoader(
  496. self.train_dataset,
  497. batch_size=self.batch_size,
  498. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  499. num_workers=self.num_workers,
  500. persistent_workers=True,
  501. )
  502. def val_dataloader(self):
  503. return DataLoader(
  504. self.val_dataset,
  505. batch_size=self.batch_size,
  506. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  507. num_workers=self.num_workers,
  508. persistent_workers=True,
  509. )
  510. if __name__ == "__main__":
  511. from tqdm import tqdm
  512. ds = AutoTextSemanticInstructionDataset(
  513. ["data/protos"],
  514. tokenizer=FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken"),
  515. use_speaker=False,
  516. interactive_prob=1.0,
  517. skip_text_prob=0.5,
  518. )
  519. for i in range(100):
  520. # Please uncomment line 235 to visualize the tokenized message
  521. print(ds[i])