|
|
@@ -7,6 +7,7 @@ from pathlib import Path
|
|
|
from random import Random
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
+import grpc
|
|
|
import numpy as np
|
|
|
import pyarrow.parquet as pq
|
|
|
import torch
|
|
|
@@ -18,8 +19,8 @@ from torch.distributed import get_rank, get_world_size, is_initialized
|
|
|
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
-from fish_speech.datasets.protos.text_data_pb2 import Semantics
|
|
|
-from fish_speech.datasets.protos.text_data_stream import read_pb_stream
|
|
|
+from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
|
|
|
+from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
|
|
|
from fish_speech.text.symbols import pad as pad_symbol
|
|
|
from fish_speech.text.symbols import pu_symbols
|
|
|
from fish_speech.utils import RankedLogger
|
|
|
@@ -145,49 +146,26 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- files: list[str],
|
|
|
+ server: str = "localhost:50051",
|
|
|
seed: int = 42,
|
|
|
phones_prob: float = 0.3,
|
|
|
max_length: int = 1024,
|
|
|
tokenizer: AutoTokenizer = None,
|
|
|
- split: Optional[str] = None,
|
|
|
):
|
|
|
super().__init__()
|
|
|
|
|
|
- self.files = files
|
|
|
self.seed = seed
|
|
|
self.phones_prob = phones_prob
|
|
|
self.max_length = max_length
|
|
|
self.tokenizer = tokenizer
|
|
|
|
|
|
# Read all lines, and group by speaker
|
|
|
- self.groups = []
|
|
|
- count = 0
|
|
|
- for filename in self.files:
|
|
|
- with open(filename, "rb") as f:
|
|
|
- for text_data in read_pb_stream(f):
|
|
|
- self.groups.append(text_data)
|
|
|
- count += 1
|
|
|
-
|
|
|
- if count % 10000 == 0:
|
|
|
- log.info(f"Read {count} groups of text data")
|
|
|
-
|
|
|
- # Shuffle the lines
|
|
|
- Random(seed).shuffle(self.groups)
|
|
|
-
|
|
|
- if split == "train":
|
|
|
- self.groups = self.groups[:-500]
|
|
|
- elif split == "val":
|
|
|
- self.groups = self.groups[-500:]
|
|
|
+ self.channel = grpc.insecure_channel(server)
|
|
|
+ self.stub = DataServiceStub(self.channel)
|
|
|
|
|
|
def __iter__(self):
|
|
|
- groups = split_by_rank_worker(self.groups)
|
|
|
- random.shuffle(groups)
|
|
|
-
|
|
|
- for group in groups:
|
|
|
- x = self.augment(group)
|
|
|
- if x is not None:
|
|
|
- yield x
|
|
|
+ while True:
|
|
|
+ yield self.augment()
|
|
|
|
|
|
def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
|
|
|
if (
|
|
|
@@ -208,12 +186,11 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
)
|
|
|
return sentence, len(tokens)
|
|
|
|
|
|
- def augment(self, group):
|
|
|
+ def augment(self):
|
|
|
# 50% to pure text or pure phones
|
|
|
- # mode = "sample"
|
|
|
- # if random.random() < 0.5:
|
|
|
- # mode = random.choice(["text", "phones"])
|
|
|
- mode = "phones"
|
|
|
+ mode = "sample"
|
|
|
+ if random.random() < 0.5:
|
|
|
+ mode = random.choice(["text", "phones"])
|
|
|
|
|
|
# Random sample based on speaker using a truncated normal distribution
|
|
|
a = torch.tensor([0], dtype=torch.float32)
|
|
|
@@ -229,15 +206,15 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
final_text, final_semantic = [], []
|
|
|
|
|
|
# Shuffle unique lines
|
|
|
- idxs = list(range(len(group.sentences)))
|
|
|
- random.shuffle(idxs)
|
|
|
-
|
|
|
- if len(idxs) == 0:
|
|
|
+ request = SampleDataRequest(num_samples=50)
|
|
|
+ response = self.stub.SampleData(request)
|
|
|
+ if len(response.samples) == 0:
|
|
|
# Invalid group
|
|
|
return None
|
|
|
|
|
|
- while remaining_tokens > 0 and len(idxs) > 0:
|
|
|
- sentence = group.sentences[idxs.pop()]
|
|
|
+ samples = list(response.samples)
|
|
|
+ while remaining_tokens > 0 and len(samples) > 0:
|
|
|
+ sentence = samples.pop()
|
|
|
text, length = self.tokenize_sentence(
|
|
|
sentence.text, sentence.phones, mode=mode
|
|
|
)
|
|
|
@@ -393,18 +370,9 @@ class TextDataModule(LightningDataModule):
|
|
|
if __name__ == "__main__":
|
|
|
import json
|
|
|
|
|
|
- # data/Genshin/English/Aabid/vo_KVCOP001_1907808_aabid_01.lab
|
|
|
- # all_files = [i for i in Path("data/Genshin/English").rglob("*.lab")]
|
|
|
- # with open("test.jsonl", "w") as f:
|
|
|
- # for i in all_files:
|
|
|
- # wav_file = i.with_suffix(".wav")
|
|
|
- # duration = float(Path(wav_file).stat().st_size) / 2 / 44100
|
|
|
- # eta_tokens = duration * 25
|
|
|
- # fake_tokens = [random.randint(0, 2048) for _ in range(int(eta_tokens))]
|
|
|
- # f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")
|
|
|
+ from tqdm import tqdm
|
|
|
|
|
|
ds = AutoAugTextDataset(
|
|
|
- files=["data/quantized-dataset-1205.protos"],
|
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
)
|
|
|
|
|
|
@@ -417,5 +385,5 @@ if __name__ == "__main__":
|
|
|
num_workers=0,
|
|
|
)
|
|
|
|
|
|
- for batch in dm.train_dataloader():
|
|
|
- print(batch)
|
|
|
+ for batch in tqdm(dm.train_dataloader()):
|
|
|
+ pass
|