semantic.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. import random
  2. from dataclasses import dataclass
  3. from itertools import chain
  4. from pathlib import Path
  5. from random import Random
  6. from typing import Optional, Union
  7. import numpy as np
  8. import pyarrow.parquet as pq
  9. import torch
  10. import torch.nn.functional as F
  11. from datasets.download.streaming_download_manager import xopen
  12. from huggingface_hub import HfApi
  13. from lightning import LightningDataModule
  14. from torch.distributed import get_rank, get_world_size, is_initialized
  15. from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
  16. from fish_speech.content_sequence import ContentSequence, TextPart, VQPart
  17. CODEBOOK_PAD_TOKEN_ID = 0
  18. from fish_speech.datasets.protos.text_data_pb2 import SampledData
  19. from fish_speech.datasets.protos.text_data_stream import read_pb_stream
  20. from fish_speech.text.clean import clean_text
  21. from fish_speech.tokenizer import FishTokenizer
  22. from fish_speech.utils import RankedLogger
  23. from fish_speech.utils.braceexpand import braceexpand
  24. log = RankedLogger(__name__, rank_zero_only=True)
  25. def split_by_rank_worker(files):
  26. # We need to know the total number of devices
  27. # to split the data properly
  28. total_devices = 1
  29. if is_initialized():
  30. total_devices = get_world_size()
  31. worker_info = get_worker_info()
  32. if worker_info is not None:
  33. total_devices *= worker_info.num_workers
  34. if len(files) < total_devices:
  35. # Repeat the files N times to match the number of devices
  36. files = files * (total_devices // len(files) + 1)
  37. # DDP
  38. if is_initialized():
  39. files = files[get_rank() :: get_world_size()]
  40. # Split by worker
  41. if worker_info is not None:
  42. files = files[worker_info.id :: worker_info.num_workers]
  43. return files
  44. class AutoTextSemanticInstructionIterableDataset(IterableDataset):
  45. """
  46. Auto Augment Dataset by Speaker
  47. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  48. 2. Automatically normalize the text
  49. For interactive mode, we use the following format (multiple sequences):
  50. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  51. For non-interactive mode, we use the following format (one long sequence):
  52. <s> [INST] text [/INST] ... </s>
  53. """
  54. def __init__(
  55. self,
  56. proto_files: list[str],
  57. seed: int = 42,
  58. interactive_prob: float = 0.5,
  59. max_length: int = 1024,
  60. tokenizer: FishTokenizer = None,
  61. use_speaker: bool | float = True,
  62. causal: bool = True,
  63. num_codebooks: Optional[int] = None,
  64. skip_text_prob: float = 0.0,
  65. ):
  66. """
  67. Args:
  68. proto_files: proto buf files if using local data
  69. seed: random seed
  70. interactive_prob: probability to use interactive mode
  71. max_length: max length of the text
  72. tokenizer: tokenizer
  73. use_speaker: include speaker information in the prompt
  74. causal: use causal sampling when using local data, disable will lead to random sampling
  75. num_codebooks: number of codebooks, if None, it will be automatically detected
  76. skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
  77. """
  78. super().__init__()
  79. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  80. self.seed = seed
  81. self.max_length = max_length
  82. self.tokenizer = tokenizer
  83. self.interactive_prob = interactive_prob
  84. self.use_speaker = use_speaker
  85. self.proto_files = proto_files
  86. self.causal = causal
  87. self.num_codebooks = num_codebooks
  88. self.skip_text_prob = skip_text_prob
  89. self.groups = None
  90. def __iter__(self):
  91. while True:
  92. yield self.augment()
  93. def init_mock_data_server(self):
  94. if self.groups is not None:
  95. return
  96. # Expand the proto files
  97. expanded_proto_files = []
  98. for filename in self.proto_files:
  99. for i in braceexpand(filename):
  100. i = Path(i)
  101. if i.is_file():
  102. expanded_proto_files.append(i)
  103. elif i.is_dir():
  104. expanded_proto_files.extend(i.rglob("*.proto"))
  105. expanded_proto_files.extend(i.rglob("*.protos"))
  106. else:
  107. raise ValueError(f"{i} is not a file or directory")
  108. expanded_proto_files = sorted(expanded_proto_files)
  109. Random(self.seed).shuffle(expanded_proto_files)
  110. self.groups = []
  111. shard_proto_files = split_by_rank_worker(expanded_proto_files)
  112. log.info(
  113. f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
  114. )
  115. count = 0
  116. for filename in shard_proto_files:
  117. with open(filename, "rb") as f:
  118. for text_data in read_pb_stream(f):
  119. self.groups.append(text_data)
  120. count += 1
  121. log.info(f"Read total {count} groups of data")
  122. # Shuffle the lines
  123. Random(self.seed).shuffle(self.groups)
  124. self.group_weights = [len(i.sentences) for i in self.groups]
  125. def sample_data(self):
  126. if self.groups is None:
  127. self.init_mock_data_server()
  128. # Shuffle unique lines, estimate that each sample is at least 20 tokens
  129. num_samples = self.max_length // 20
  130. # choice group based on their number of samples
  131. group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
  132. if self.causal:
  133. # Sample in order
  134. if num_samples >= len(group.sentences):
  135. samples = group.sentences
  136. else:
  137. begin = random.randint(0, len(group.sentences) - num_samples)
  138. samples = group.sentences[begin : begin + num_samples]
  139. else:
  140. samples = random.choices(
  141. group.sentences, k=min(num_samples, len(group.sentences))
  142. )
  143. return SampledData(
  144. source=group.source,
  145. name=group.name,
  146. samples=samples,
  147. )
  148. def pack_sentences(
  149. self,
  150. sentences: list[str],
  151. semantics: list,
  152. # speaker: Optional[str] = None, # speaker is now handled by tokens
  153. skip_text: bool = False,
  154. ):
  155. seq = ContentSequence()
  156. seq.append(TextPart(text="Speak out the provided text."))
  157. # User's turn
  158. cated_sentences = " ".join(sentences)
  159. if skip_text:
  160. cated_sentences = "<|skip_text|>"
  161. seq.append(
  162. TextPart(text=f"<|speaker:user|> {cated_sentences}"),
  163. add_end=True,
  164. )
  165. # Assistant's turn
  166. vq_codes = [x.values for x in semantics[0]]
  167. vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
  168. # 将 cal_loss=True 直接关联到 VQPart 上,这比之前更精确
  169. vq_part = VQPart(codes=vq_codes_tensor, cal_loss=True)
  170. # 将多个 parts 一起添加,最后也加上 <|im_end|>
  171. seq.append(
  172. [TextPart(text="<|speaker:assistant|> <|voice|>"), vq_part],
  173. add_end=True,
  174. )
  175. encoded = seq.encode(
  176. tokenizer=self.tokenizer,
  177. )
  178. num_codebooks = (
  179. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  180. )
  181. tokens_raw = encoded.tokens
  182. tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
  183. tokens[0] = tokens_raw
  184. vq_parts = encoded.vq_parts
  185. vq_parts = [part.to(tokens.device) for part in vq_parts]
  186. vq_parts = torch.cat(vq_parts, dim=1)
  187. tokens[1:, encoded.vq_mask_tokens] = vq_parts
  188. labels_raw = encoded.labels
  189. labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
  190. labels[0, :] = labels_raw
  191. labels[1:, encoded.vq_mask_labels] = vq_parts
  192. labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
  193. tokens = tokens.long()
  194. labels = labels.long()
  195. # Verify the padding is correct, and the last token is eos
  196. assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
  197. assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
  198. return tokens, labels
  199. def augment(self):
  200. response = self.sample_data()
  201. if len(response.samples) == 0:
  202. # Invalid group
  203. return None
  204. samples = list(response.samples)
  205. all_tokens, all_labels = [], []
  206. while len(samples) > 0:
  207. sentence = samples.pop(0)
  208. text = clean_text(random.choice(sentence.texts))
  209. tokens, labels = self.pack_sentences(
  210. sentences=[text],
  211. semantics=[sentence.semantics],
  212. # speaker=response.name if use_speaker else None,
  213. skip_text=random.random() < self.skip_text_prob,
  214. )
  215. all_tokens.append(tokens)
  216. all_labels.append(labels)
  217. tokens = torch.cat(all_tokens, dim=1)
  218. labels = torch.cat(all_labels, dim=1)
  219. # Verify that the length is correct
  220. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  221. data = {"tokens": tokens, "labels": labels}
  222. return data
  223. class AutoTextSemanticInstructionDataset(Dataset):
  224. """
  225. Auto Augment Dataset by Speaker
  226. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  227. 2. Automatically normalize the text
  228. For interactive mode, we use the following format (multiple sequences):
  229. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  230. For non-interactive mode, we use the following format (one long sequence):
  231. <s> [INST] text [/INST] ... </s>
  232. """
  233. def __init__(
  234. self,
  235. proto_files: list[str],
  236. seed: int = 42,
  237. interactive_prob: float = 0.5,
  238. max_length: int = 1024,
  239. tokenizer: FishTokenizer = None,
  240. use_speaker: bool | float = True,
  241. causal: bool = True,
  242. num_codebooks: Optional[int] = None,
  243. skip_text_prob: float = 0.0,
  244. ):
  245. """
  246. Args:
  247. proto_files: proto buf files if using local data
  248. seed: random seed
  249. interactive_prob: probability to use interactive mode
  250. max_length: max length of the text
  251. tokenizer: tokenizer
  252. use_speaker: include speaker information in the prompt
  253. causal: use causal sampling when using local data, disable will lead to random sampling
  254. num_codebooks: number of codebooks, if None, it will be automatically detected
  255. skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
  256. """
  257. super().__init__()
  258. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  259. self.seed = seed
  260. self.max_length = max_length
  261. self.tokenizer = tokenizer
  262. self.interactive_prob = interactive_prob
  263. self.use_speaker = use_speaker
  264. self.proto_files = proto_files
  265. self.causal = causal
  266. self.num_codebooks = num_codebooks
  267. self.skip_text_prob = skip_text_prob
  268. self.data = []
  269. self._init_data()
  270. def _init_data(self):
  271. expanded_proto_files = []
  272. for filename in self.proto_files:
  273. for i in braceexpand(filename):
  274. i = Path(i)
  275. if i.is_file():
  276. expanded_proto_files.append(i)
  277. elif i.is_dir():
  278. expanded_proto_files.extend(i.rglob("*.proto"))
  279. expanded_proto_files.extend(i.rglob("*.protos"))
  280. else:
  281. raise ValueError(f"{i} is not a file or directory")
  282. expanded_proto_files = sorted(expanded_proto_files)
  283. Random(self.seed).shuffle(expanded_proto_files)
  284. groups = []
  285. shard_proto_files = split_by_rank_worker(expanded_proto_files)
  286. log.info(
  287. f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
  288. )
  289. count = 0
  290. for filename in shard_proto_files:
  291. with open(filename, "rb") as f:
  292. for text_data in read_pb_stream(f):
  293. groups.append(text_data)
  294. count += 1
  295. log.info(f"Read total {count} groups of data")
  296. for group in groups:
  297. if len(group.sentences) == 0:
  298. continue
  299. samples = list(group.sentences)
  300. for sentence in samples:
  301. text = clean_text(random.choice(sentence.texts))
  302. tokens, labels = self.pack_sentences(
  303. sentences=[text],
  304. semantics=[sentence.semantics],
  305. skip_text=random.random() < self.skip_text_prob,
  306. )
  307. self.data.append({"tokens": tokens, "labels": labels})
  308. random.Random(self.seed).shuffle(self.data)
  309. def __len__(self):
  310. return len(self.data)
  311. def __getitem__(self, idx):
  312. return self.data[idx]
  313. def pack_sentences(
  314. self,
  315. sentences: list[str],
  316. semantics: list,
  317. skip_text: bool = False,
  318. ):
  319. messages = [
  320. Message(
  321. role="system",
  322. parts=[TextPart(text="Speak out the provided text.")],
  323. )
  324. ]
  325. cated_sentences = " ".join(sentences)
  326. if skip_text:
  327. cated_sentences = "<|skip_text|>"
  328. messages.append(
  329. Message(
  330. role="user",
  331. parts=[TextPart(text=cated_sentences)],
  332. )
  333. )
  334. vq_codes = [x.values for x in semantics[0]]
  335. vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
  336. vqpart = VQPart(codes=vq_codes_tensor)
  337. messages.append(
  338. Message(
  339. role="assistant",
  340. parts=[TextPart(text="<|voice|>"), vqpart],
  341. cal_loss=True,
  342. )
  343. )
  344. num_codebooks = (
  345. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  346. )
  347. conversation = Conversation(messages=messages)
  348. encoded = conversation.encode(
  349. tokenizer=self.tokenizer,
  350. )
  351. tokens_raw = encoded.tokens
  352. tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
  353. tokens[0] = tokens_raw
  354. vq_parts = encoded.vq_parts
  355. vq_parts = [part.to(tokens.device) for part in vq_parts]
  356. vq_parts = torch.cat(vq_parts, dim=1)
  357. tokens[1:, encoded.vq_mask_tokens] = vq_parts
  358. labels_raw = encoded.labels
  359. labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
  360. labels[0, :] = labels_raw
  361. labels[1:, encoded.vq_mask_labels] = vq_parts
  362. labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
  363. tokens = tokens.long()
  364. labels = labels.long()
  365. assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
  366. assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
  367. return tokens, labels
  368. class InterleaveDataset(IterableDataset):
  369. def __init__(
  370. self,
  371. datasets: list[IterableDataset],
  372. probabilities: list[float],
  373. seed: int = 42,
  374. ):
  375. super().__init__()
  376. self.datasets = datasets
  377. self.probabilities = probabilities
  378. self.seed = seed
  379. def __iter__(self):
  380. rng = np.random.default_rng(self.seed)
  381. dataset_iterators = [iter(dataset) for dataset in self.datasets]
  382. while True:
  383. # Random choice one
  384. dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
  385. dataset_iterator = dataset_iterators[dataset_idx]
  386. try:
  387. yield next(dataset_iterator)
  388. except StopIteration:
  389. # Exhausted, create a new iterator
  390. dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
  391. yield next(dataset_iterators[dataset_idx])
  392. @dataclass
  393. class TextDataCollator:
  394. tokenizer: FishTokenizer
  395. max_length: int = 1024
  396. def __call__(self, examples):
  397. if "negative_tokens" in examples:
  398. positive_examples = []
  399. negative_examples = []
  400. for i in examples:
  401. positive_examples.append(
  402. {
  403. "tokens": i["tokens"],
  404. "labels": i["labels"],
  405. }
  406. )
  407. negative_examples.append(
  408. {
  409. "tokens": i["negative_tokens"],
  410. "labels": i["negative_labels"],
  411. }
  412. )
  413. examples = positive_examples + negative_examples
  414. return self.batchify(examples)
  415. def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
  416. tokens, attention_masks, labels = [], [], []
  417. # Calculate the max length
  418. max_tokens_length = 0
  419. for example in examples:
  420. max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
  421. max_tokens_length = min(max_tokens_length, self.max_length)
  422. for example in examples:
  423. _tokens = example[tokens_key][:, :max_tokens_length]
  424. _labels = example[labels_key][:, :max_tokens_length]
  425. _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
  426. tokens_length = _tokens.size(1)
  427. _attention_mask[:tokens_length] = False
  428. assert tokens_length == _labels.size(
  429. 1
  430. ), f"{tokens_length} != {_labels.size(1)}"
  431. if tokens_length < max_tokens_length:
  432. _tokens = F.pad(
  433. _tokens,
  434. (0, max_tokens_length - tokens_length),
  435. value=self.tokenizer.get_token_id("<|end_of_text|>"),
  436. )
  437. _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
  438. _labels = F.pad(
  439. _labels, (0, max_tokens_length - _labels.size(1)), value=-100
  440. )
  441. tokens.append(_tokens)
  442. attention_masks.append(_attention_mask)
  443. labels.append(_labels)
  444. tokens = torch.stack(tokens, dim=0)
  445. attention_masks = torch.stack(attention_masks, dim=0)
  446. labels = torch.stack(labels, dim=0)
  447. return {
  448. "inputs": tokens,
  449. "attention_masks": attention_masks,
  450. "labels": labels,
  451. }
  452. class SemanticDataModule(LightningDataModule):
  453. def __init__(
  454. self,
  455. train_dataset: Union[
  456. AutoTextSemanticInstructionDataset,
  457. AutoTextSemanticInstructionIterableDataset,
  458. InterleaveDataset,
  459. ],
  460. val_dataset: Union[
  461. AutoTextSemanticInstructionDataset,
  462. AutoTextSemanticInstructionIterableDataset,
  463. InterleaveDataset,
  464. ],
  465. batch_size: int = 32,
  466. tokenizer: FishTokenizer = None,
  467. max_length: int = 1024,
  468. num_workers: int = 4,
  469. ):
  470. super().__init__()
  471. self.train_dataset = train_dataset
  472. self.val_dataset = val_dataset
  473. self.batch_size = batch_size
  474. self.tokenizer = tokenizer
  475. self.max_length = max_length
  476. self.num_workers = num_workers
  477. def train_dataloader(self):
  478. return DataLoader(
  479. self.train_dataset,
  480. batch_size=self.batch_size,
  481. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  482. num_workers=self.num_workers,
  483. persistent_workers=True,
  484. )
  485. def val_dataloader(self):
  486. return DataLoader(
  487. self.val_dataset,
  488. batch_size=self.batch_size,
  489. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  490. num_workers=self.num_workers,
  491. persistent_workers=True,
  492. )
  493. if __name__ == "__main__":
  494. from tqdm import tqdm
  495. ds = AutoTextSemanticInstructionDataset(
  496. ["data/protos"],
  497. tokenizer=FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken"),
  498. use_speaker=False,
  499. interactive_prob=1.0,
  500. skip_text_prob=0.5,
  501. )
  502. for i in range(100):
  503. # Please uncomment line 235 to visualize the tokenized message
  504. print(ds[i])