|
@@ -1,21 +1,29 @@
|
|
|
|
|
+import gzip
|
|
|
|
|
+import io
|
|
|
|
|
+import json
|
|
|
import random
|
|
import random
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
-from itertools import chain
|
|
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from random import Random
|
|
from random import Random
|
|
|
from typing import Optional, Union
|
|
from typing import Optional, Union
|
|
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
-import pyarrow.parquet as pq
|
|
|
|
|
import torch
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
|
-from datasets.download.streaming_download_manager import xopen
|
|
|
|
|
-from huggingface_hub import HfApi
|
|
|
|
|
|
|
+import zstandard as zstd
|
|
|
from lightning import LightningDataModule
|
|
from lightning import LightningDataModule
|
|
|
from torch.distributed import get_rank, get_world_size, is_initialized
|
|
from torch.distributed import get_rank, get_world_size, is_initialized
|
|
|
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
|
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
|
|
+from fish_speech.conversation import (
|
|
|
|
|
+ CODEBOOK_PAD_TOKEN_ID,
|
|
|
|
|
+ SKIP_TEXT_STRING,
|
|
|
|
|
+ Conversation,
|
|
|
|
|
+ Message,
|
|
|
|
|
+ encode_conversation,
|
|
|
|
|
+)
|
|
|
|
|
+from fish_speech.datasets.prompts import asr_instructions, tts_instructions
|
|
|
from fish_speech.datasets.protos.text_data_pb2 import SampledData
|
|
from fish_speech.datasets.protos.text_data_pb2 import SampledData
|
|
|
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
|
|
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
|
|
|
from fish_speech.text.clean import clean_text
|
|
from fish_speech.text.clean import clean_text
|
|
@@ -24,9 +32,7 @@ from fish_speech.utils.braceexpand import braceexpand
|
|
|
|
|
|
|
|
log = RankedLogger(__name__, rank_zero_only=True)
|
|
log = RankedLogger(__name__, rank_zero_only=True)
|
|
|
|
|
|
|
|
-CODEBOOK_PAD_TOKEN_ID = 0
|
|
|
|
|
-CODEBOOK_EOS_TOKEN_ID = 1
|
|
|
|
|
-SKIP_TEXT_STRING = "<|skip_text|>"
|
|
|
|
|
|
|
+DCTX = zstd.ZstdDecompressor(max_window_size=2**31)
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_by_rank_worker(files):
|
|
def split_by_rank_worker(files):
|
|
@@ -56,43 +62,55 @@ def split_by_rank_worker(files):
|
|
|
return files
|
|
return files
|
|
|
|
|
|
|
|
|
|
|
|
|
-class StreamTextDataset(IterableDataset):
|
|
|
|
|
|
|
+def expand_split_proto_files(proto_files, seed: int = 42):
|
|
|
|
|
+ # Expand the proto files
|
|
|
|
|
+ expanded_proto_files = []
|
|
|
|
|
+ for filename in proto_files:
|
|
|
|
|
+ for i in braceexpand(filename):
|
|
|
|
|
+ i = Path(i)
|
|
|
|
|
+ if i.is_file():
|
|
|
|
|
+ expanded_proto_files.append(i)
|
|
|
|
|
+ elif i.is_dir():
|
|
|
|
|
+ expanded_proto_files.extend(i.rglob("*.proto"))
|
|
|
|
|
+ expanded_proto_files.extend(i.rglob("*.protos"))
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError(f"{i} is not a file or directory")
|
|
|
|
|
+
|
|
|
|
|
+ expanded_proto_files = sorted(expanded_proto_files)
|
|
|
|
|
+ Random(seed).shuffle(expanded_proto_files)
|
|
|
|
|
+ return split_by_rank_worker(expanded_proto_files)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TextPretrainDataset(IterableDataset):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
- files: Optional[Union[list[str], str]] = None,
|
|
|
|
|
- prefix: Optional[str] = None,
|
|
|
|
|
|
|
+ source: str,
|
|
|
seed: int = 42,
|
|
seed: int = 42,
|
|
|
- parquet_batch_size: int = 10000,
|
|
|
|
|
- repo: str = "uonlp/CulturaX",
|
|
|
|
|
max_length: int = 1024,
|
|
max_length: int = 1024,
|
|
|
tokenizer: AutoTokenizer = None,
|
|
tokenizer: AutoTokenizer = None,
|
|
|
|
|
+ num_codebooks: int = 2,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
+ self.source = Path(source)
|
|
|
self.seed = seed
|
|
self.seed = seed
|
|
|
- self.parquet_batch_size = parquet_batch_size
|
|
|
|
|
- self.repo = repo
|
|
|
|
|
self.max_length = max_length
|
|
self.max_length = max_length
|
|
|
self.tokenizer = tokenizer
|
|
self.tokenizer = tokenizer
|
|
|
|
|
+ self.num_codebooks = num_codebooks
|
|
|
|
|
|
|
|
- if files is None and prefix is None:
|
|
|
|
|
- raise ValueError("Either files or prefix must be specified")
|
|
|
|
|
-
|
|
|
|
|
- if prefix is not None:
|
|
|
|
|
- files = HfApi().list_repo_files(repo, repo_type="dataset")
|
|
|
|
|
|
|
+ if self.source.is_file():
|
|
|
|
|
+ with open(self.source, "r") as f:
|
|
|
|
|
+ files = f.read().strip().split("\n")
|
|
|
|
|
+ self.root = self.source.parent
|
|
|
|
|
+ else:
|
|
|
files = [
|
|
files = [
|
|
|
- f for f in files if f.startswith(prefix) and f.endswith(".parquet")
|
|
|
|
|
|
|
+ str(i.relative_to(self.source)) for i in self.source.rglob("*.jsonl")
|
|
|
]
|
|
]
|
|
|
- log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
|
|
|
|
|
- else:
|
|
|
|
|
- if isinstance(files, str):
|
|
|
|
|
- files = [files]
|
|
|
|
|
-
|
|
|
|
|
- files = list(chain.from_iterable(map(braceexpand, files)))
|
|
|
|
|
- log.info(f"Expanded {len(files)} files in {repo}")
|
|
|
|
|
|
|
+ self.root = self.source
|
|
|
|
|
|
|
|
# Get sharded files
|
|
# Get sharded files
|
|
|
self.files = sorted(files)
|
|
self.files = sorted(files)
|
|
|
|
|
+
|
|
|
Random(seed).shuffle(self.files)
|
|
Random(seed).shuffle(self.files)
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
def __iter__(self):
|
|
@@ -105,142 +123,147 @@ class StreamTextDataset(IterableDataset):
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
log.exception(f"Failed to parse {filename}: {e}")
|
|
log.exception(f"Failed to parse {filename}: {e}")
|
|
|
|
|
|
|
|
- def parse_data(self, filename: str):
|
|
|
|
|
- for data in self.parse_data_internal(filename):
|
|
|
|
|
- text = data["text"]
|
|
|
|
|
|
|
+ def read_jsonl(self, filename: str):
|
|
|
|
|
+ with open(self.root / filename, "rb") as f:
|
|
|
|
|
+ if filename.endswith(".zst"):
|
|
|
|
|
+ stream_reader = DCTX.stream_reader(f)
|
|
|
|
|
+ elif filename.endswith(".gz"):
|
|
|
|
|
+ stream_reader = gzip.open(f, "rb")
|
|
|
|
|
+ elif filename.endswith(".jsonl"):
|
|
|
|
|
+ stream_reader = f
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError(f"Unknown file type: {filename}")
|
|
|
|
|
|
|
|
|
|
+ stream = io.TextIOWrapper(stream_reader, encoding="utf-8")
|
|
|
|
|
+
|
|
|
|
|
+ # Parse jsonl
|
|
|
|
|
+ for line in stream:
|
|
|
|
|
+ line = json.loads(line)
|
|
|
|
|
+ yield line
|
|
|
|
|
+
|
|
|
|
|
+ def parse_data(self, filename: str):
|
|
|
|
|
+ for line in self.read_jsonl(filename):
|
|
|
# encode
|
|
# encode
|
|
|
tokens = self.tokenizer.encode(
|
|
tokens = self.tokenizer.encode(
|
|
|
- text,
|
|
|
|
|
|
|
+ line["text"],
|
|
|
add_special_tokens=False,
|
|
add_special_tokens=False,
|
|
|
truncation=False,
|
|
truncation=False,
|
|
|
max_length=10**6,
|
|
max_length=10**6,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # Random choice self.max_length
|
|
|
|
|
- if len(tokens) > self.max_length:
|
|
|
|
|
- start = random.randint(0, len(tokens) - self.max_length)
|
|
|
|
|
- tokens = tokens[start : start + self.max_length - 1]
|
|
|
|
|
-
|
|
|
|
|
tokens = (
|
|
tokens = (
|
|
|
[self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
|
|
[self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
|
|
|
)
|
|
)
|
|
|
- # Pad dims
|
|
|
|
|
- placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
|
|
|
|
|
-
|
|
|
|
|
- tokens = torch.concat(
|
|
|
|
|
- [
|
|
|
|
|
- torch.tensor([tokens], dtype=torch.long),
|
|
|
|
|
- placeholder_multi_codebook,
|
|
|
|
|
- ],
|
|
|
|
|
- dim=0,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+
|
|
|
|
|
+ if len(tokens) > self.max_length:
|
|
|
|
|
+ tokens = tokens[: self.max_length]
|
|
|
|
|
+
|
|
|
|
|
+ tokens = self.pad_codebooks(tokens)
|
|
|
labels = tokens.clone()
|
|
labels = tokens.clone()
|
|
|
tokens = tokens[:, :-1]
|
|
tokens = tokens[:, :-1]
|
|
|
labels = labels[:, 1:]
|
|
labels = labels[:, 1:]
|
|
|
- labels[1:] = -100 # remove all placeholders
|
|
|
|
|
|
|
+ labels[1:] = -100 # no loss on codebook
|
|
|
|
|
|
|
|
yield {"tokens": tokens, "labels": labels}
|
|
yield {"tokens": tokens, "labels": labels}
|
|
|
|
|
|
|
|
- def parse_data_internal(self, filename: str):
|
|
|
|
|
- url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
|
|
|
|
|
|
|
+ def pad_codebooks(self, tokens):
|
|
|
|
|
+ placeholder_multi_codebook = (
|
|
|
|
|
+ torch.zeros((self.num_codebooks, len(tokens)), dtype=torch.long)
|
|
|
|
|
+ + CODEBOOK_PAD_TOKEN_ID
|
|
|
|
|
+ )
|
|
|
|
|
+ return torch.concat(
|
|
|
|
|
+ [
|
|
|
|
|
+ torch.tensor([tokens], dtype=torch.long),
|
|
|
|
|
+ placeholder_multi_codebook,
|
|
|
|
|
+ ],
|
|
|
|
|
+ dim=0,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
|
|
|
- with xopen(url, mode="rb") as stream:
|
|
|
|
|
- parquet_file = pq.ParquetFile(stream)
|
|
|
|
|
|
|
+class TextInstructionDataset(TextPretrainDataset):
|
|
|
|
|
+ def parse_data(self, filename: str):
|
|
|
|
|
+ for line in self.read_jsonl(filename):
|
|
|
|
|
+ messages = []
|
|
|
|
|
+ for conversation in line["conversations"]:
|
|
|
|
|
+ role = {
|
|
|
|
|
+ "human": "user",
|
|
|
|
|
+ "gpt": "assistant",
|
|
|
|
|
+ "system": "system",
|
|
|
|
|
+ }[conversation["from"]]
|
|
|
|
|
+
|
|
|
|
|
+ message = Message(
|
|
|
|
|
+ role=role,
|
|
|
|
|
+ parts=[conversation["value"]],
|
|
|
|
|
+ )
|
|
|
|
|
+ messages.append(message)
|
|
|
|
|
+
|
|
|
|
|
+ conversation = Conversation(messages=messages)
|
|
|
|
|
+ tokens, labels = encode_conversation(
|
|
|
|
|
+ conversation,
|
|
|
|
|
+ self.tokenizer,
|
|
|
|
|
+ num_codebooks=self.num_codebooks,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- for batch in parquet_file.iter_batches(
|
|
|
|
|
- batch_size=self.parquet_batch_size, columns=["text"]
|
|
|
|
|
- ):
|
|
|
|
|
- # In-batch shuffling
|
|
|
|
|
- texts = [{"text": text.as_py()} for text in batch["text"]]
|
|
|
|
|
- random.shuffle(texts)
|
|
|
|
|
- yield from texts
|
|
|
|
|
|
|
+ yield {"tokens": tokens, "labels": labels}
|
|
|
|
|
|
|
|
|
|
|
|
|
-class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
- """
|
|
|
|
|
- Auto Augment Dataset by Speaker
|
|
|
|
|
|
|
+def semantic_to_tensor(semantics):
|
|
|
|
|
+ num_codebooks = len(semantics)
|
|
|
|
|
+ codes = [[] for _ in range(num_codebooks)]
|
|
|
|
|
|
|
|
- 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
|
|
|
|
|
- 2. Automatically normalize the text
|
|
|
|
|
|
|
+ for book_idx, book in zip(range(num_codebooks), semantics):
|
|
|
|
|
+ for j in book.values:
|
|
|
|
|
+ codes[book_idx].append(int(j))
|
|
|
|
|
|
|
|
- For interactive mode, we use the following format (multiple sequences):
|
|
|
|
|
- <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
|
|
|
|
|
|
|
+ return torch.tensor(codes, dtype=torch.int)
|
|
|
|
|
|
|
|
- For non-interactive mode, we use the following format (one long sequence):
|
|
|
|
|
- <s> [INST] text [/INST] ... </s>
|
|
|
|
|
- """
|
|
|
|
|
|
|
|
|
|
|
|
+class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
proto_files: list[str],
|
|
proto_files: list[str],
|
|
|
seed: int = 42,
|
|
seed: int = 42,
|
|
|
- interactive_prob: float = 0.5,
|
|
|
|
|
max_length: int = 1024,
|
|
max_length: int = 1024,
|
|
|
tokenizer: AutoTokenizer = None,
|
|
tokenizer: AutoTokenizer = None,
|
|
|
- use_speaker: bool | float = True,
|
|
|
|
|
- causual: bool = True,
|
|
|
|
|
- use_negative_samples: bool = False,
|
|
|
|
|
|
|
+ causual: Union[bool, float] = True,
|
|
|
num_codebooks: Optional[int] = None,
|
|
num_codebooks: Optional[int] = None,
|
|
|
skip_text_prob: float = 0.0,
|
|
skip_text_prob: float = 0.0,
|
|
|
|
|
+ asr_prob: float = 0.0,
|
|
|
):
|
|
):
|
|
|
"""
|
|
"""
|
|
|
Args:
|
|
Args:
|
|
|
proto_files: proto buf files if using local data
|
|
proto_files: proto buf files if using local data
|
|
|
seed: random seed
|
|
seed: random seed
|
|
|
- interactive_prob: probability to use interactive mode
|
|
|
|
|
max_length: max length of the text
|
|
max_length: max length of the text
|
|
|
tokenizer: tokenizer
|
|
tokenizer: tokenizer
|
|
|
- use_speaker: include speaker information in the prompt
|
|
|
|
|
causual: use causual sampling when using local data, disable will lead to random sampling
|
|
causual: use causual sampling when using local data, disable will lead to random sampling
|
|
|
- use_negative_samples: generate negative samples
|
|
|
|
|
num_codebooks: number of codebooks, if None, it will be automatically detected
|
|
num_codebooks: number of codebooks, if None, it will be automatically detected
|
|
|
skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
|
|
skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
|
|
|
|
|
+ asr_prob: probability to use ASR
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
|
- assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
|
|
|
|
|
|
|
+ assert 0 <= skip_text_prob <= 1, "skip_text_prob must be in [0, 1]"
|
|
|
|
|
+ assert 0 <= asr_prob <= 1, "asr_prob must be in [0, 1]"
|
|
|
|
|
|
|
|
self.seed = seed
|
|
self.seed = seed
|
|
|
self.max_length = max_length
|
|
self.max_length = max_length
|
|
|
self.tokenizer = tokenizer
|
|
self.tokenizer = tokenizer
|
|
|
- self.interactive_prob = interactive_prob
|
|
|
|
|
- self.use_speaker = use_speaker
|
|
|
|
|
self.proto_files = proto_files
|
|
self.proto_files = proto_files
|
|
|
self.causual = causual
|
|
self.causual = causual
|
|
|
- self.use_negative_samples = use_negative_samples
|
|
|
|
|
self.num_codebooks = num_codebooks
|
|
self.num_codebooks = num_codebooks
|
|
|
self.skip_text_prob = skip_text_prob
|
|
self.skip_text_prob = skip_text_prob
|
|
|
-
|
|
|
|
|
- self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
|
|
|
|
|
|
|
+ self.asr_prob = asr_prob
|
|
|
self.groups = None
|
|
self.groups = None
|
|
|
|
|
|
|
|
def init_mock_data_server(self):
|
|
def init_mock_data_server(self):
|
|
|
if self.groups is not None:
|
|
if self.groups is not None:
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
- # Expand the proto files
|
|
|
|
|
- expanded_proto_files = []
|
|
|
|
|
- for filename in self.proto_files:
|
|
|
|
|
- for i in braceexpand(filename):
|
|
|
|
|
- i = Path(i)
|
|
|
|
|
- if i.is_file():
|
|
|
|
|
- expanded_proto_files.append(i)
|
|
|
|
|
- elif i.is_dir():
|
|
|
|
|
- expanded_proto_files.extend(i.rglob("*.proto"))
|
|
|
|
|
- expanded_proto_files.extend(i.rglob("*.protos"))
|
|
|
|
|
- else:
|
|
|
|
|
- raise ValueError(f"{i} is not a file or directory")
|
|
|
|
|
-
|
|
|
|
|
- expanded_proto_files = sorted(expanded_proto_files)
|
|
|
|
|
- Random(self.seed).shuffle(expanded_proto_files)
|
|
|
|
|
-
|
|
|
|
|
self.groups = []
|
|
self.groups = []
|
|
|
- shard_proto_files = split_by_rank_worker(expanded_proto_files)
|
|
|
|
|
- log.info(
|
|
|
|
|
- f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
|
|
|
|
|
+ log.info(f"Reading {len(shard_proto_files)} files")
|
|
|
|
|
|
|
|
count = 0
|
|
count = 0
|
|
|
for filename in shard_proto_files:
|
|
for filename in shard_proto_files:
|
|
@@ -279,7 +302,11 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
# choice group based on their number of samples
|
|
# choice group based on their number of samples
|
|
|
group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
|
|
group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
|
|
|
|
|
|
|
|
- if self.causual:
|
|
|
|
|
|
|
+ causual = self.causual
|
|
|
|
|
+ if isinstance(self.causual, float):
|
|
|
|
|
+ causual = random.random() < self.causual
|
|
|
|
|
+
|
|
|
|
|
+ if causual:
|
|
|
# Sample in order
|
|
# Sample in order
|
|
|
if num_samples >= len(group.sentences):
|
|
if num_samples >= len(group.sentences):
|
|
|
samples = group.sentences
|
|
samples = group.sentences
|
|
@@ -298,7 +325,6 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def augment(self):
|
|
def augment(self):
|
|
|
- final_text, final_semantic = [], []
|
|
|
|
|
response = self.sample_data()
|
|
response = self.sample_data()
|
|
|
if len(response.samples) == 0:
|
|
if len(response.samples) == 0:
|
|
|
# Invalid group
|
|
# Invalid group
|
|
@@ -306,29 +332,9 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
|
|
|
samples = list(response.samples)
|
|
samples = list(response.samples)
|
|
|
idx = 0
|
|
idx = 0
|
|
|
- use_interactive = random.random() < self.interactive_prob
|
|
|
|
|
-
|
|
|
|
|
- if use_interactive is False:
|
|
|
|
|
- # Random sample based on speaker using a truncated normal distribution
|
|
|
|
|
- a = torch.tensor([0], dtype=torch.float32)
|
|
|
|
|
- torch.nn.init.trunc_normal_(
|
|
|
|
|
- a,
|
|
|
|
|
- mean=self.max_length // 2,
|
|
|
|
|
- std=self.max_length // 4,
|
|
|
|
|
- a=10,
|
|
|
|
|
- b=self.max_length,
|
|
|
|
|
- )
|
|
|
|
|
- remaining_tokens = a.long().item() - 4
|
|
|
|
|
- else:
|
|
|
|
|
- remaining_tokens = self.max_length
|
|
|
|
|
-
|
|
|
|
|
- # Use speaker
|
|
|
|
|
- if isinstance(self.use_speaker, float):
|
|
|
|
|
- use_speaker = random.random() < self.use_speaker
|
|
|
|
|
- else:
|
|
|
|
|
- use_speaker = self.use_speaker
|
|
|
|
|
|
|
+ remaining_tokens = self.max_length
|
|
|
|
|
|
|
|
- all_tokens, all_labels = [], []
|
|
|
|
|
|
|
+ all_messages = []
|
|
|
while remaining_tokens > 0 and len(samples) > 0:
|
|
while remaining_tokens > 0 and len(samples) > 0:
|
|
|
sentence = samples.pop(0)
|
|
sentence = samples.pop(0)
|
|
|
|
|
|
|
@@ -336,37 +342,52 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
text, length = self.tokenize_sentence(text)
|
|
text, length = self.tokenize_sentence(text)
|
|
|
remaining_tokens -= length + len(sentence.semantics[0].values)
|
|
remaining_tokens -= length + len(sentence.semantics[0].values)
|
|
|
|
|
|
|
|
- if use_interactive is False:
|
|
|
|
|
- final_text.append(text)
|
|
|
|
|
- final_semantic.append(sentence.semantics)
|
|
|
|
|
|
|
+ # For interactive mode, we only apply speaker for the first sentence
|
|
|
|
|
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
|
|
|
|
|
+
|
|
|
|
|
+ if random.random() < self.asr_prob:
|
|
|
|
|
+ all_messages.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="user",
|
|
|
|
|
+ parts=[
|
|
|
|
|
+ random.choice(asr_instructions),
|
|
|
|
|
+ semantic_to_tensor(sentence.semantics),
|
|
|
|
|
+ ],
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ all_messages.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="assistant",
|
|
|
|
|
+ parts=[text],
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
else:
|
|
else:
|
|
|
- # For interactive mode, we only apply speaker for the first sentence
|
|
|
|
|
- # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
|
|
|
|
|
- tokens, labels = self.pack_sentences(
|
|
|
|
|
- sentences=[text],
|
|
|
|
|
- semantics=[sentence.semantics],
|
|
|
|
|
- speaker=response.name if use_speaker else None,
|
|
|
|
|
- add_bos=idx == 0,
|
|
|
|
|
- skip_text=random.random() < self.skip_text_prob,
|
|
|
|
|
|
|
+ skip_text = random.random() < self.skip_text_prob
|
|
|
|
|
+ if skip_text:
|
|
|
|
|
+ text = SKIP_TEXT_STRING
|
|
|
|
|
+
|
|
|
|
|
+ all_messages.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="user",
|
|
|
|
|
+ parts=[random.choice(tts_instructions) + text],
|
|
|
|
|
+ mask_labels=skip_text,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ all_messages.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="assistant",
|
|
|
|
|
+ parts=[semantic_to_tensor(sentence.semantics)],
|
|
|
|
|
+ mask_labels=skip_text,
|
|
|
|
|
+ )
|
|
|
)
|
|
)
|
|
|
-
|
|
|
|
|
- all_tokens.append(tokens)
|
|
|
|
|
- all_labels.append(labels)
|
|
|
|
|
|
|
|
|
|
idx += 1
|
|
idx += 1
|
|
|
|
|
|
|
|
- if use_interactive is False:
|
|
|
|
|
- tokens, labels = self.pack_sentences(
|
|
|
|
|
- final_text,
|
|
|
|
|
- semantics=final_semantic,
|
|
|
|
|
- speaker=response.name if use_speaker else None,
|
|
|
|
|
- add_bos=True,
|
|
|
|
|
- )
|
|
|
|
|
- all_tokens.append(tokens)
|
|
|
|
|
- all_labels.append(labels)
|
|
|
|
|
-
|
|
|
|
|
- tokens = torch.cat(all_tokens, dim=1)
|
|
|
|
|
- labels = torch.cat(all_labels, dim=1)
|
|
|
|
|
|
|
+ tokens, labels = encode_conversation(
|
|
|
|
|
+ Conversation(messages=all_messages),
|
|
|
|
|
+ self.tokenizer,
|
|
|
|
|
+ num_codebooks=self.num_codebooks,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# Verify that the length is correct
|
|
# Verify that the length is correct
|
|
|
assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
|
assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
|
@@ -374,156 +395,71 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
# Verify bos token
|
|
# Verify bos token
|
|
|
assert tokens[0, 0] == self.tokenizer.bos_token_id
|
|
assert tokens[0, 0] == self.tokenizer.bos_token_id
|
|
|
|
|
|
|
|
- data = {"tokens": tokens, "labels": labels}
|
|
|
|
|
-
|
|
|
|
|
- if self.use_negative_samples:
|
|
|
|
|
- negative_samples = self.generate_negative_samples(all_tokens, all_labels)
|
|
|
|
|
- data.update(negative_samples)
|
|
|
|
|
-
|
|
|
|
|
- return data
|
|
|
|
|
-
|
|
|
|
|
- def generate_negative_samples(self, all_tokens, all_labels):
|
|
|
|
|
- new_tokens, new_labels = [], []
|
|
|
|
|
-
|
|
|
|
|
- for tokens, labels in zip(all_tokens, all_labels):
|
|
|
|
|
- # If all codebooks are not -100, we find where it starts
|
|
|
|
|
- start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
|
|
|
|
|
- assert (labels[1:, start:] != -100).all() # This shouldn't happen
|
|
|
|
|
|
|
+ return {"tokens": tokens, "labels": labels}
|
|
|
|
|
|
|
|
- mode = random.choice(["repeat", "lost", "noise"])
|
|
|
|
|
- begin = random.randint(start, labels.size(1) - 1)
|
|
|
|
|
- end = random.randint(begin, labels.size(1) - 1)
|
|
|
|
|
|
|
|
|
|
- if mode == "repeat":
|
|
|
|
|
- tokens = torch.cat(
|
|
|
|
|
- [
|
|
|
|
|
- tokens[:, :begin],
|
|
|
|
|
- tokens[:, begin:end],
|
|
|
|
|
- tokens[:, begin:end],
|
|
|
|
|
- tokens[:, end:],
|
|
|
|
|
- ],
|
|
|
|
|
- dim=1,
|
|
|
|
|
- )
|
|
|
|
|
- labels = torch.cat(
|
|
|
|
|
- [
|
|
|
|
|
- labels[:, :begin],
|
|
|
|
|
- labels[:, begin:end],
|
|
|
|
|
- labels[:, begin:end],
|
|
|
|
|
- labels[:, end:],
|
|
|
|
|
- ],
|
|
|
|
|
- dim=1,
|
|
|
|
|
- )
|
|
|
|
|
- elif mode == "lost":
|
|
|
|
|
- tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
|
|
|
|
|
- labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
|
|
|
|
|
- elif mode == "noise":
|
|
|
|
|
- middle_tokens, middle_labels = (
|
|
|
|
|
- tokens[:, begin:end],
|
|
|
|
|
- labels[:, begin:end],
|
|
|
|
|
- )
|
|
|
|
|
- random_order0 = torch.randperm(middle_tokens.size(1))
|
|
|
|
|
- random_order1 = torch.randperm(middle_tokens.size(1))
|
|
|
|
|
- middle_tokens = middle_tokens[:, random_order0]
|
|
|
|
|
- middle_labels = middle_labels[:, random_order1]
|
|
|
|
|
- tokens = torch.cat(
|
|
|
|
|
- [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
|
|
|
|
|
- )
|
|
|
|
|
- labels = torch.cat(
|
|
|
|
|
- [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- new_tokens.append(tokens)
|
|
|
|
|
- new_labels.append(labels)
|
|
|
|
|
-
|
|
|
|
|
- tokens = torch.cat(new_tokens, dim=1)
|
|
|
|
|
- labels = torch.cat(new_labels, dim=1)
|
|
|
|
|
-
|
|
|
|
|
- # Verify that the length is correct
|
|
|
|
|
- assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
|
|
|
|
-
|
|
|
|
|
- return {"negative_tokens": tokens, "negative_labels": labels}
|
|
|
|
|
-
|
|
|
|
|
- def pack_sentences(
|
|
|
|
|
|
|
+class SemanticInstructionDataset(IterableDataset):
|
|
|
|
|
+ def __init__(
|
|
|
self,
|
|
self,
|
|
|
- sentences: list[str],
|
|
|
|
|
- semantics: list,
|
|
|
|
|
- speaker: Optional[str] = None,
|
|
|
|
|
- add_bos: bool = True,
|
|
|
|
|
- skip_text: bool = False,
|
|
|
|
|
|
|
+ proto_files: list[str],
|
|
|
|
|
+ seed: int = 42,
|
|
|
|
|
+ max_length: int = 1024,
|
|
|
|
|
+ tokenizer: AutoTokenizer = None,
|
|
|
|
|
+ num_codebooks: Optional[int] = None,
|
|
|
):
|
|
):
|
|
|
- if speaker is None:
|
|
|
|
|
- speaker = "assistant"
|
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
|
|
- cated_sentences = " ".join(sentences)
|
|
|
|
|
- if skip_text:
|
|
|
|
|
- cated_sentences = SKIP_TEXT_STRING
|
|
|
|
|
|
|
+ self.seed = seed
|
|
|
|
|
+ self.max_length = max_length
|
|
|
|
|
+ self.tokenizer = tokenizer
|
|
|
|
|
+ self.proto_files = proto_files
|
|
|
|
|
+ self.num_codebooks = num_codebooks
|
|
|
|
|
|
|
|
- final_text = "<|im_start|>user<|im_sep|>" + cated_sentences + "<|im_end|>"
|
|
|
|
|
- final_text = final_text + f"<|im_start|>{speaker}<|im_sep|>"
|
|
|
|
|
|
|
+ def get_data_generator(self):
|
|
|
|
|
+ shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
|
|
|
|
|
+ random.shuffle(shard_proto_files)
|
|
|
|
|
+ log.info(f"Fetched {len(shard_proto_files)} files")
|
|
|
|
|
|
|
|
- encoded = self.tokenizer.encode(
|
|
|
|
|
- final_text,
|
|
|
|
|
- add_special_tokens=False,
|
|
|
|
|
- truncation=False,
|
|
|
|
|
- max_length=10**6,
|
|
|
|
|
- )
|
|
|
|
|
- semantic_length = sum([len(i[0].values) for i in semantics])
|
|
|
|
|
- prompt_length = len(encoded)
|
|
|
|
|
- num_codebooks = (
|
|
|
|
|
- len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ for filename in shard_proto_files:
|
|
|
|
|
+ with open(filename, "rb") as f:
|
|
|
|
|
+ for group in read_pb_stream(f):
|
|
|
|
|
+ yield group
|
|
|
|
|
|
|
|
- bos_bias = 1 if add_bos else 0
|
|
|
|
|
|
|
+ def pack_one_group(self, group):
|
|
|
|
|
+ sentences = group.sentences
|
|
|
|
|
|
|
|
- # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
|
|
|
|
|
- tokens = (
|
|
|
|
|
- encoded
|
|
|
|
|
- + [self.semantic_token_id] * semantic_length
|
|
|
|
|
- + self.tokenizer.convert_tokens_to_ids(
|
|
|
|
|
- ["<|im_end|>", "<|end_of_sequence|>"]
|
|
|
|
|
|
|
+ messages = []
|
|
|
|
|
+ for idx, sentence in enumerate(sentences):
|
|
|
|
|
+ role = "user" if idx % 2 == 0 else "assistant"
|
|
|
|
|
+ semantic = semantic_to_tensor(sentence.semantics)
|
|
|
|
|
+ text = random.choice(sentence.texts)
|
|
|
|
|
+ parts = [semantic]
|
|
|
|
|
+ if role == "assistant":
|
|
|
|
|
+ # Let model to predict the text first
|
|
|
|
|
+ prev_text = random.choice(sentences[idx - 1].texts)
|
|
|
|
|
+ # parts.insert(0, f"Q: {prev_text}\nA: {text}")
|
|
|
|
|
+ messages.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role=role,
|
|
|
|
|
+ parts=parts,
|
|
|
|
|
+ )
|
|
|
)
|
|
)
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- if add_bos:
|
|
|
|
|
- tokens = [self.tokenizer.bos_token_id] + tokens
|
|
|
|
|
-
|
|
|
|
|
- # Codebook bos/padding: 0, eos: 1
|
|
|
|
|
- codes = [
|
|
|
|
|
- [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
|
|
|
|
|
- for _ in range(num_codebooks)
|
|
|
|
|
- ]
|
|
|
|
|
- for segment in semantics:
|
|
|
|
|
- for book_idx, book in zip(range(num_codebooks), segment):
|
|
|
|
|
- for j in book.values:
|
|
|
|
|
- codes[book_idx].append(int(j) + 2)
|
|
|
|
|
|
|
|
|
|
- for book in codes:
|
|
|
|
|
- book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)
|
|
|
|
|
-
|
|
|
|
|
- tokens = [tokens] + codes
|
|
|
|
|
-
|
|
|
|
|
- tokens = torch.tensor(tokens, dtype=torch.long)
|
|
|
|
|
- labels = tokens.clone()
|
|
|
|
|
-
|
|
|
|
|
- if skip_text:
|
|
|
|
|
- # If text is not provided, the sentence is used for condition only, all labels are -100
|
|
|
|
|
- torch.fill_(labels, -100)
|
|
|
|
|
- return tokens, labels
|
|
|
|
|
-
|
|
|
|
|
- # Mask out the <s> tokens for semantic, predict semantic tokens only
|
|
|
|
|
- # Since we don't mask out the input tokens, the language modeling still works
|
|
|
|
|
- labels[1:, : (prompt_length + bos_bias)] = -100
|
|
|
|
|
-
|
|
|
|
|
- tokens = tokens[:, :-1]
|
|
|
|
|
- labels = labels[:, 1:]
|
|
|
|
|
|
|
+ conversation = Conversation(messages=messages)
|
|
|
|
|
+ tokens, labels = encode_conversation(
|
|
|
|
|
+ conversation,
|
|
|
|
|
+ self.tokenizer,
|
|
|
|
|
+ num_codebooks=self.num_codebooks,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- # Verify the padding is correct, and the last token is eos
|
|
|
|
|
- assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
|
|
|
|
|
- assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
|
|
|
|
|
- assert labels[0, -1] == self.tokenizer.eos_token_id
|
|
|
|
|
- assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()
|
|
|
|
|
|
|
+ return {"tokens": tokens, "labels": labels}
|
|
|
|
|
|
|
|
- return tokens, labels
|
|
|
|
|
|
|
+ def __iter__(self):
|
|
|
|
|
+ for group in self.get_data_generator():
|
|
|
|
|
+ try:
|
|
|
|
|
+ yield self.pack_one_group(group)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ log.exception(f"Failed to parse {group}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
@dataclass
|
|
@@ -633,8 +569,18 @@ class InterleaveDataset(IterableDataset):
|
|
|
class TextDataModule(LightningDataModule):
|
|
class TextDataModule(LightningDataModule):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
- train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
|
|
|
|
|
- val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
|
|
|
|
|
|
|
+ train_dataset: Union[
|
|
|
|
|
+ AutoTextSemanticInstructionDataset,
|
|
|
|
|
+ TextPretrainDataset,
|
|
|
|
|
+ TextInstructionDataset,
|
|
|
|
|
+ InterleaveDataset,
|
|
|
|
|
+ ],
|
|
|
|
|
+ val_dataset: Union[
|
|
|
|
|
+ AutoTextSemanticInstructionDataset,
|
|
|
|
|
+ TextPretrainDataset,
|
|
|
|
|
+ TextInstructionDataset,
|
|
|
|
|
+ InterleaveDataset,
|
|
|
|
|
+ ],
|
|
|
batch_size: int = 32,
|
|
batch_size: int = 32,
|
|
|
tokenizer: AutoTokenizer = None,
|
|
tokenizer: AutoTokenizer = None,
|
|
|
max_length: int = 1024,
|
|
max_length: int = 1024,
|
|
@@ -671,17 +617,36 @@ class TextDataModule(LightningDataModule):
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
- ds = AutoAugTextDataset(
|
|
|
|
|
- ["data/protos"],
|
|
|
|
|
- tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
|
|
|
|
|
- use_speaker=False,
|
|
|
|
|
- interactive_prob=1.0,
|
|
|
|
|
- use_negative_samples=False,
|
|
|
|
|
- skip_text_prob=0.5,
|
|
|
|
|
|
|
+ # ds = AutoTextSemanticInstructionDataset(
|
|
|
|
|
+ # ["data/protos/sft/val/11labs"],
|
|
|
|
|
+ # tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
|
|
|
|
|
+ # skip_text_prob=1.0,
|
|
|
|
|
+ # asr_prob=0.0,
|
|
|
|
|
+ # num_codebooks=2,
|
|
|
|
|
+ # )
|
|
|
|
|
+ # ds = TextInstructionDataset(
|
|
|
|
|
+ # source="data/openhermes2_5",
|
|
|
|
|
+ # tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
|
|
|
|
|
+ # )
|
|
|
|
|
+
|
|
|
|
|
+ ds = SemanticInstructionDataset(
|
|
|
|
|
+ proto_files=["data/protos/sft/val/ultrachat_200k_spoken_openai"],
|
|
|
|
|
+ tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
|
|
|
|
|
+ num_codebooks=2,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
for i in ds:
|
|
for i in ds:
|
|
|
- print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
|
|
|
|
|
|
|
+ # print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
|
|
|
# i["labels"][0][i["labels"][0] == -100] = 0
|
|
# i["labels"][0][i["labels"][0] == -100] = 0
|
|
|
# print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
|
|
# print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
|
|
|
|
|
+
|
|
|
|
|
+ length = i["tokens"].size(1)
|
|
|
|
|
+ print(i["tokens"].size(), i["tokens"].dtype)
|
|
|
|
|
+ for j in range(length):
|
|
|
|
|
+ print(
|
|
|
|
|
+ ds.tokenizer.decode(i["tokens"][0, j]),
|
|
|
|
|
+ i["tokens"][:, j],
|
|
|
|
|
+ i["labels"][:, j],
|
|
|
|
|
+ )
|
|
|
|
|
+ input()
|
|
|
break
|
|
break
|