|
|
@@ -16,8 +16,9 @@ from torch.distributed import get_rank, get_world_size, is_initialized
|
|
|
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
-from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
|
|
|
+from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest, SampledData
|
|
|
from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
|
|
|
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
|
|
|
from fish_speech.text.parser import clean_text
|
|
|
from fish_speech.text.symbols import pad as pad_symbol
|
|
|
from fish_speech.text.symbols import pu_symbols
|
|
|
@@ -192,6 +193,9 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
max_length: int = 1024,
|
|
|
tokenizer: AutoTokenizer = None,
|
|
|
use_speaker: bool = True,
|
|
|
+ use_data_server: bool = True,
|
|
|
+ proto_files: str = "data",
|
|
|
+ causual: bool = True,
|
|
|
):
|
|
|
"""
|
|
|
Args:
|
|
|
@@ -202,6 +206,10 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
interactive_prob: probability to use interactive mode
|
|
|
max_length: max length of the text
|
|
|
tokenizer: tokenizer
|
|
|
+ use_speaker: include speaker information in the prompt
|
|
|
+ use_data_server: use data server or local data
|
|
|
+ proto_files: protobuf files to read when using local data
|
|
|
+ causual: when using local data, sample a contiguous run of sentences in their original order; if disabled, sentences are sampled randomly
|
|
|
"""
|
|
|
|
|
|
super().__init__()
|
|
|
@@ -213,10 +221,32 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
self.repetition_prob = repetition_prob
|
|
|
self.interactive_prob = interactive_prob
|
|
|
self.use_speaker = use_speaker
|
|
|
+ self.use_data_server = use_data_server
|
|
|
+ self.proto_files = proto_files
|
|
|
+ self.causual = causual
|
|
|
|
|
|
- # Read all lines, and group by speaker
|
|
|
- self.channel = grpc.insecure_channel(server)
|
|
|
- self.stub = DataServiceStub(self.channel)
|
|
|
+ if use_data_server is True:
|
|
|
+ self.channel = grpc.insecure_channel(server)
|
|
|
+ self.stub = DataServiceStub(self.channel)
|
|
|
+ else:
|
|
|
+ self.init_mock_data_server()
|
|
|
+
|
|
|
def init_mock_data_server(self):
    """
    Load all groups from the local protobuf files into memory.

    This is the local replacement for the gRPC data server: every message
    yielded by ``read_pb_stream`` for each file in ``self.proto_files`` is
    appended to ``self.groups``, and the resulting list is shuffled with a
    random generator seeded by ``self.seed``.

    ``self.proto_files`` may be a single path (``str``, ``bytes`` or
    ``os.PathLike``) or an iterable of paths.
    """
    import os

    # A bare string would otherwise be iterated character by character
    # (the declared default of ``proto_files`` is the plain string "data"),
    # so a single path is wrapped into a one-element list first.
    proto_files = self.proto_files
    if isinstance(proto_files, (str, bytes, os.PathLike)):
        proto_files = [proto_files]

    self.groups = []
    count = 0
    for filename in proto_files:
        with open(filename, "rb") as f:
            for text_data in read_pb_stream(f):
                self.groups.append(text_data)
                count += 1

                # Periodic progress report while reading large files
                if count % 1000 == 0:
                    log.info(f"Read {count} groups of data")

    log.info(f"Read total {count} groups of data")

    # Shuffle the groups with a seeded generator
    Random(self.seed).shuffle(self.groups)
|
|
|
|
|
|
def __iter__(self):
|
|
|
while True:
|
|
|
@@ -243,6 +273,37 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
)
|
|
|
return sentence, len(tokens)
|
|
|
|
|
|
def sample_data(self):
    """
    Fetch a batch of sentences belonging to one group.

    When ``self.use_data_server`` is set, the request is forwarded to the
    gRPC stub and its response is returned unchanged. Otherwise a group is
    drawn from ``self.groups`` (weighted by its number of sentences) and a
    ``SampledData`` message is built locally:

    * ``self.causual`` enabled: a contiguous run of at most ``num_samples``
      sentences, in their original order, starting at a random offset.
    * ``self.causual`` disabled: at most ``num_samples`` distinct sentences
      picked at random.

    Returns:
        The stub response or a locally built ``SampledData``; callers read
        its ``samples`` field.
    """
    # Estimate that each sample is at least 20 tokens
    num_samples = self.max_length // 20

    if self.use_data_server:
        request = SampleDataRequest(num_samples=num_samples)
        return self.stub.SampleData(request)

    # Choose a group with probability proportional to its number of sentences
    group = random.choices(
        self.groups, weights=[len(i.sentences) for i in self.groups], k=1
    )[0]
    sentences = group.sentences
    total = len(sentences)

    if self.causual:
        # Sample in order
        if num_samples >= total:
            samples = sentences
        else:
            # randint is inclusive on both ends, so the last valid start
            # offset is total - num_samples.
            begin = random.randint(0, total - num_samples)
            samples = sentences[begin : begin + num_samples]
    else:
        # Sample without replacement so that a sentence cannot appear twice
        # in the same batch; the min() cap keeps k within the population
        # size, which random.sample requires. The sentences are copied into
        # a list because random.sample needs a real sequence.
        samples = random.sample(list(sentences), k=min(num_samples, total))

    return SampledData(
        source=group.source,
        name=group.name,
        samples=samples,
    )
|
|
|
+
|
|
|
def augment(self):
|
|
|
# 50% to pure text or pure phones
|
|
|
mode = "sample"
|
|
|
@@ -261,10 +322,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
remaining_tokens = a.long().item() - 4
|
|
|
|
|
|
final_text, final_semantic = [], []
|
|
|
-
|
|
|
- # Shuffle unique lines, estimate that each sample is at least 20 tokens
|
|
|
- request = SampleDataRequest(num_samples=self.max_length // 20)
|
|
|
- response = self.stub.SampleData(request)
|
|
|
+ response = self.sample_data()
|
|
|
if len(response.samples) == 0:
|
|
|
# Invalid group
|
|
|
return None
|
|
|
@@ -500,10 +558,18 @@ class TextDataModule(LightningDataModule):
|
|
|
if __name__ == "__main__":
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
+ # ds = AutoAugTextDataset(
|
|
|
+ # tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
+ # use_speaker=True,
|
|
|
+ # interactive_prob=1.0,
|
|
|
+ # )
|
|
|
+
|
|
|
ds = AutoAugTextDataset(
|
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
use_speaker=True,
|
|
|
interactive_prob=1.0,
|
|
|
+ use_data_server=False,
|
|
|
+ proto_files=["data/wenet-speech.protos"],
|
|
|
)
|
|
|
|
|
|
dm = TextDataModule(
|