import gzip
import io
import json
import random
from dataclasses import dataclass
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
import zstandard as zstd
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.conversation import (
    CODEBOOK_PAD_TOKEN_ID,
    SKIP_TEXT_STRING,
    Conversation,
    Message,
    encode_conversation,
)
from fish_speech.datasets.prompts import asr_instructions, tts_instructions
from fish_speech.datasets.protos.text_data_pb2 import SampledData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.clean import clean_text
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)
DCTX = zstd.ZstdDecompressor(max_window_size=2**31)


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
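
# Illustrative sketch (assumed setup, not part of the original file): with 3 files,
# a DDP world size of 2 and 2 dataloader workers per rank, total_devices is 4, so
# the list is first repeated to ["a", "b", "c", "a", "b", "c"]; rank 0 then keeps
# ["a", "c", "b"], and its worker 0 / worker 1 iterate ["a", "b"] / ["c"].
#
#   split_by_rank_worker(["a", "b", "c"])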


def expand_split_proto_files(proto_files, seed: int = 42):
    # Expand the proto files
    expanded_proto_files = []
    for filename in proto_files:
        for i in braceexpand(filename):
            i = Path(i)
            if i.is_file():
                expanded_proto_files.append(i)
            elif i.is_dir():
                expanded_proto_files.extend(i.rglob("*.proto"))
                expanded_proto_files.extend(i.rglob("*.protos"))
            else:
                raise ValueError(f"{i} is not a file or directory")

    expanded_proto_files = sorted(expanded_proto_files)
    Random(seed).shuffle(expanded_proto_files)

    return split_by_rank_worker(expanded_proto_files)
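
# Illustrative sketch (hypothetical shard paths; assumes the bundled braceexpand
# supports bash-style numeric ranges): a single pattern can cover many shards, and
# directories are walked recursively for *.proto / *.protos files.
#
#   expand_split_proto_files(["data/protos/train/part-{000..003}.proto"])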


class TextPretrainDataset(IterableDataset):
    def __init__(
        self,
        source: str,
        seed: int = 42,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        num_codebooks: int = 2,
    ):
        super().__init__()

        self.source = Path(source)
        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.num_codebooks = num_codebooks

        if self.source.is_file():
            with open(self.source, "r") as f:
                files = f.read().strip().split("\n")
            self.root = self.source.parent
        else:
            files = [
                str(i.relative_to(self.source)) for i in self.source.rglob("*.jsonl")
            ]
            self.root = self.source

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def read_jsonl(self, filename: str):
        with open(self.root / filename, "rb") as f:
            if filename.endswith(".zst"):
                stream_reader = DCTX.stream_reader(f)
            elif filename.endswith(".gz"):
                stream_reader = gzip.open(f, "rb")
            elif filename.endswith(".jsonl"):
                stream_reader = f
            else:
                raise ValueError(f"Unknown file type: {filename}")

            stream = io.TextIOWrapper(stream_reader, encoding="utf-8")

            # Parse jsonl
            for line in stream:
                line = json.loads(line)
                yield line

    def parse_data(self, filename: str):
        for line in self.read_jsonl(filename):
            # Encode the raw text without special tokens, then add bos/eos manually
            tokens = self.tokenizer.encode(
                line["text"],
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )
            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )

            if len(tokens) > self.max_length:
                tokens = tokens[: self.max_length]

            tokens = self.pad_codebooks(tokens)
            labels = tokens.clone()

            # Shift for next-token prediction
            tokens = tokens[:, :-1]
            labels = labels[:, 1:]
            labels[1:] = -100  # no loss on the codebook rows for pure text data

            yield {"tokens": tokens, "labels": labels}

    def pad_codebooks(self, tokens):
        placeholder_multi_codebook = (
            torch.zeros((self.num_codebooks, len(tokens)), dtype=torch.long)
            + CODEBOOK_PAD_TOKEN_ID
        )

        return torch.concat(
            [
                torch.tensor([tokens], dtype=torch.long),
                placeholder_multi_codebook,
            ],
            dim=0,
        )
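
# Illustrative note (shapes inferred from the code above, values hypothetical):
# pad_codebooks turns a flat list of N text token ids into a LongTensor of shape
# (1 + num_codebooks, N); row 0 holds the text tokens and the remaining rows are
# filled with CODEBOOK_PAD_TOKEN_ID. After the shift in parse_data, "tokens" and
# "labels" are both (1 + num_codebooks, N - 1).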


class TextInstructionDataset(TextPretrainDataset):
    def parse_data(self, filename: str):
        for line in self.read_jsonl(filename):
            messages = []
            for conversation in line["conversations"]:
                role = {
                    "human": "user",
                    "gpt": "assistant",
                    "system": "system",
                }[conversation["from"]]
                message = Message(
                    role=role,
                    parts=[conversation["value"]],
                )
                messages.append(message)

            conversation = Conversation(messages=messages)
            tokens, labels = encode_conversation(
                conversation,
                self.tokenizer,
                num_codebooks=self.num_codebooks,
            )

            yield {"tokens": tokens, "labels": labels}


def semantic_to_tensor(semantics):
    num_codebooks = len(semantics)
    codes = [[] for _ in range(num_codebooks)]

    for book_idx, book in enumerate(semantics):
        for j in book.values:
            codes[book_idx].append(int(j))

    return torch.tensor(codes, dtype=torch.int)
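
# Illustrative sketch (hypothetical protobuf data): each entry in `semantics` is one
# codebook with a repeated `values` field, so two codebooks with values=[1, 2, 3] and
# values=[4, 5, 6] become the (2, 3) int tensor [[1, 2, 3], [4, 5, 6]], which can be
# used directly as an audio-code part of a Message.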


class AutoTextSemanticInstructionDataset(IterableDataset):
    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        causual: Union[bool, float] = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
        asr_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: proto buf files if using local data
            seed: random seed
            max_length: max length of the text
            tokenizer: tokenizer
            causual: use causal (in-order) sampling when using local data; a float is
                treated as the probability of sampling in order, and disabling it
                leads to random sampling
            num_codebooks: number of codebooks; if None, it will be detected automatically
            skip_text_prob: probability of skipping the text (audio only); this only
                applies to interactive mode
            asr_prob: probability of building an ASR-style sample instead of TTS
        """

        super().__init__()

        assert 0 <= skip_text_prob <= 1, "skip_text_prob must be in [0, 1]"
        assert 0 <= asr_prob <= 1, "asr_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.proto_files = proto_files
        self.causual = causual
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob
        self.asr_prob = asr_prob

        self.groups = None

    def init_mock_data_server(self):
        if self.groups is not None:
            return

        self.groups = []
        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
        log.info(f"Reading {len(shard_proto_files)} files")

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read {count} groups of data in total")

        # Shuffle the groups
        Random(self.seed).shuffle(self.groups)
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        if self.groups is None:
            self.init_mock_data_server()

        # Estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # Choose a group, weighted by its number of sentences
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        causual = self.causual
        if isinstance(self.causual, float):
            causual = random.random() < self.causual

        if causual:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        remaining_tokens = self.max_length
        all_messages = []

        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)
            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            # For interactive mode, we only apply the speaker to the first sentence
            # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
            if random.random() < self.asr_prob:
                # ASR-style sample: audio codes in, text out
                all_messages.append(
                    Message(
                        role="user",
                        parts=[
                            random.choice(asr_instructions),
                            semantic_to_tensor(sentence.semantics),
                        ],
                    )
                )
                all_messages.append(
                    Message(
                        role="assistant",
                        parts=[text],
                    )
                )
            else:
                # TTS-style sample: text in, audio codes out
                skip_text = random.random() < self.skip_text_prob
                if skip_text:
                    text = SKIP_TEXT_STRING

                all_messages.append(
                    Message(
                        role="user",
                        parts=[random.choice(tts_instructions) + text],
                        mask_labels=skip_text,
                    )
                )
                all_messages.append(
                    Message(
                        role="assistant",
                        parts=[semantic_to_tensor(sentence.semantics)],
                        mask_labels=skip_text,
                    )
                )

            idx += 1

        tokens, labels = encode_conversation(
            Conversation(messages=all_messages),
            self.tokenizer,
            num_codebooks=self.num_codebooks,
        )

        # Verify that the lengths match
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
        # Verify the bos token
        assert tokens[0, 0] == self.tokenizer.bos_token_id

        return {"tokens": tokens, "labels": labels}
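
# Illustrative sketch (hypothetical sentence "Hello world" with 2 codebooks): one
# augmented conversation alternates user/assistant messages. For the TTS branch,
#   user:      random tts_instruction + "Hello world"
#   assistant: (2, T) tensor of semantic codes
# while the ASR branch reverses the roles (codes in, text out); skip_text_prob
# replaces the user text with SKIP_TEXT_STRING and masks the labels of that exchange.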


class SemanticInstructionDataset(IterableDataset):
    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        num_codebooks: Optional[int] = None,
    ):
        super().__init__()

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.proto_files = proto_files
        self.num_codebooks = num_codebooks

    def get_data_generator(self):
        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
        random.shuffle(shard_proto_files)
        log.info(f"Fetched {len(shard_proto_files)} files")

        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for group in read_pb_stream(f):
                    yield group

    def pack_one_group(self, group):
        sentences = group.sentences
        messages = []

        for idx, sentence in enumerate(sentences):
            role = "user" if idx % 2 == 0 else "assistant"
            semantic = semantic_to_tensor(sentence.semantics)
            text = random.choice(sentence.texts)
            parts = [semantic]

            if role == "assistant":
                # Let the model predict the text first
                prev_text = random.choice(sentences[idx - 1].texts)
                # parts.insert(0, f"Q: {prev_text}\nA: {text}")

            messages.append(
                Message(
                    role=role,
                    parts=parts,
                )
            )

        conversation = Conversation(messages=messages)
        tokens, labels = encode_conversation(
            conversation,
            self.tokenizer,
            num_codebooks=self.num_codebooks,
        )

        return {"tokens": tokens, "labels": labels}

    def __iter__(self):
        for group in self.get_data_generator():
            try:
                yield self.pack_one_group(group)
            except Exception as e:
                log.exception(f"Failed to parse {group}: {e}")


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        if len(examples) > 0 and "negative_tokens" in examples[0]:
            positive_examples = []
            negative_examples = []

            for i in examples:
                positive_examples.append(
                    {
                        "tokens": i["tokens"],
                        "labels": i["labels"],
                    }
                )
                negative_examples.append(
                    {
                        "tokens": i["negative_tokens"],
                        "labels": i["negative_labels"],
                    }
                )

            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]

            # Mask starts as all True; real token positions are set to False (True = padding)
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(1), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
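
# Illustrative note (shapes inferred from batchify, values hypothetical): collating a
# batch of B examples whose "tokens" are (1 + num_codebooks, T_i) yields
#   inputs:          (B, 1 + num_codebooks, T_max)  padded with eos / CODEBOOK_PAD_TOKEN_ID
#   attention_masks: (B, T_max)                     True on padded positions
#   labels:          (B, 1 + num_codebooks, T_max)  padded with -100
# where T_max = min(max_i T_i, max_length).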


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly pick one dataset according to its probability
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
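
# Illustrative sketch (hypothetical mix): interleave a text pretraining stream with a
# semantic instruction stream, drawing 80% of the samples from the first. The
# probabilities are passed to numpy's rng.choice, so they should sum to 1.
#
#   interleaved = InterleaveDataset(
#       datasets=[text_pretrain_ds, semantic_ds],
#       probabilities=[0.8, 0.2],
#   )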


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[
            AutoTextSemanticInstructionDataset,
            TextPretrainDataset,
            TextInstructionDataset,
            InterleaveDataset,
        ],
        val_dataset: Union[
            AutoTextSemanticInstructionDataset,
            TextPretrainDataset,
            TextInstructionDataset,
            InterleaveDataset,
        ],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            persistent_workers=True,
        )
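
# Illustrative sketch (hypothetical data paths; the tokenizer checkpoint mirrors the
# examples in __main__ below): wiring a pretraining stream into Lightning.
#
#   tokenizer = AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1")
#   datamodule = TextDataModule(
#       train_dataset=TextPretrainDataset("data/pretrain", tokenizer=tokenizer),
#       val_dataset=TextPretrainDataset("data/pretrain_val", tokenizer=tokenizer),
#       tokenizer=tokenizer,
#       batch_size=8,
#       max_length=1024,
#   )
#   # trainer.fit(model, datamodule=datamodule)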


if __name__ == "__main__":
    from tqdm import tqdm

    # ds = AutoTextSemanticInstructionDataset(
    #     ["data/protos/sft/val/11labs"],
    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
    #     skip_text_prob=1.0,
    #     asr_prob=0.0,
    #     num_codebooks=2,
    # )

    # ds = TextInstructionDataset(
    #     source="data/openhermes2_5",
    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
    # )

    ds = SemanticInstructionDataset(
        proto_files=["data/protos/sft/val/ultrachat_200k_spoken_openai"],
        tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
        num_codebooks=2,
    )

    for i in ds:
        # print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
        # i["labels"][0][i["labels"][0] == -100] = 0
        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))

        length = i["tokens"].size(1)
        print(i["tokens"].size(), i["tokens"].dtype)

        for j in range(length):
            print(
                ds.tokenizer.decode(i["tokens"][0, j]),
                i["tokens"][:, j],
                i["labels"][:, j],
            )

        input()
        break