Browse Source

Add dataset protobuf & parallel building

Lengyue 2 years ago
parent
commit
5ae08bd43b

+ 18 - 0
fish_speech/datasets/protos/text-data.proto

@@ -0,0 +1,18 @@
syntax = "proto3";

// One row of quantized semantic tokens for a sentence.
// A Sentence carries one Semantics entry per codebook, so multiple
// codebooks are represented as repeated Semantics messages.
message Semantics {
    repeated uint32 values = 1;
}

// A single training sample: raw text, its phoneme sequence, and the
// quantized semantic tokens (one Semantics row per codebook).
message Sentence {
    string text = 1;
    repeated string phones = 2;
    repeated Semantics semantics = 3;
}

// A group of sentences sharing one origin — `source` is the dataset root,
// `name` the group key (e.g. a speaker directory), `languages` the g2p order.
message TextData {
    string source = 1;
    string name = 2;
    repeated string languages = 3;
    repeated Sentence sentences = 4;
}

+ 31 - 0
fish_speech/datasets/protos/text_data_pb2.py

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: text-data.proto
# Protobuf Python Version: 4.25.1
# NOTE(review): this file is machine-generated from text-data.proto; any
# change to the .proto must be followed by regenerating this module with
# `protoc --python_out=. text-data.proto` — never edit the serialized
# descriptor bytes below by hand, they must stay in sync with the schema.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x0ftext-data.proto"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"G\n\x08Sentence\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0e\n\x06phones\x18\x02 \x03(\t\x12\x1d\n\tsemantics\x18\x03 \x03(\x0b\x32\n.Semantics"Y\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tlanguages\x18\x03 \x03(\t\x12\x1c\n\tsentences\x18\x04 \x03(\x0b\x32\t.Sentenceb\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._options = None
    _globals["_SEMANTICS"]._serialized_start = 19
    _globals["_SEMANTICS"]._serialized_end = 46
    _globals["_SENTENCE"]._serialized_start = 48
    _globals["_SENTENCE"]._serialized_end = 119
    _globals["_TEXTDATA"]._serialized_start = 121
    _globals["_TEXTDATA"]._serialized_end = 210
# @@protoc_insertion_point(module_scope)

+ 26 - 0
fish_speech/datasets/protos/text_data_stream.py

@@ -0,0 +1,26 @@
import struct

from .text_data_pb2 import TextData

# Fixed little-endian, 4-byte unsigned length prefix.  The original code used
# the platform-native format "I", which makes stream files written on a
# big-endian machine unreadable on a little-endian one (and vice versa).
# An explicit "<I" pins the on-disk format; on little-endian hosts (the
# common case) the bytes are identical to before, so existing files still read.
_LENGTH = struct.Struct("<I")


def read_pb_stream(f):
    """Yield TextData messages from the length-prefixed binary stream *f*.

    *f* must be a binary file object opened for reading.  A clean EOF between
    records ends iteration; a truncated prefix or body raises EOFError instead
    of surfacing a confusing struct/parse error.
    """
    while True:
        header = f.read(_LENGTH.size)
        if not header:
            break  # clean EOF on a record boundary
        if len(header) < _LENGTH.size:
            raise EOFError("truncated length prefix in protobuf stream")
        (size,) = _LENGTH.unpack(header)
        payload = f.read(size)
        if len(payload) < size:
            raise EOFError("truncated message body in protobuf stream")
        text_data = TextData()
        text_data.ParseFromString(payload)
        yield text_data


def write_pb_stream(f, text_data):
    """Write *text_data* to binary file *f* as one length-prefixed record."""
    buf = text_data.SerializeToString()
    f.write(_LENGTH.pack(len(buf)))
    f.write(buf)


def pack_pb_stream(text_data):
    """Return *text_data* serialized as a length-prefixed bytes record."""
    buf = text_data.SerializeToString()
    return _LENGTH.pack(len(buf)) + buf

+ 15 - 11
fish_speech/datasets/text.py

@@ -8,6 +8,7 @@ from random import Random
 from typing import Optional, Union
 
 import numpy as np
+import orjson
 import pyarrow.parquet as pq
 import torch
 import torch.nn.functional as F
@@ -166,21 +167,24 @@ class AutoAugTextDataset(IterableDataset):
         self.tokenizer = tokenizer
 
         # Read all lines, and group by speaker
-        self.speakers = {}
-        self.lines = []
+        self.groups = []
+        from tqdm import tqdm
 
         for filename in self.jsonl_files:
-            lines = Path(filename).read_text().splitlines()
-            for json_line in lines:
-                line = json.loads(json_line)
-                speaker = line.get("speaker", None)
+            with open(filename, "r") as f:
+                for json_line in tqdm(f):
+                    if json_line.strip() == "":
+                        continue
 
-                if speaker not in self.speakers:
-                    self.speakers[speaker] = []
+                    line = orjson.loads(json_line)
+                    # for i in line["sentences"]:
+                    #     # Save memory
+                    #     i["semantics"] = np.array(i["semantics"], dtype=np.uint16)
+                    self.groups.append(line)
 
-                self.lines.append(line)
-                self.speakers[speaker].append(line)
+        import sys
 
+        print(sys.getsizeof(self.groups) / 1024 / 1024)
         # Shuffle the lines
         Random(seed).shuffle(self.lines)
 
@@ -394,7 +398,7 @@ if __name__ == "__main__":
     #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")
 
     ds = AutoAugTextDataset(
-        jsonl_files=["test.jsonl"],
+        jsonl_files=["data/quantized-dataset-1205.json"],
         order=["en"],
         tokenizer=AutoTokenizer.from_pretrained(
             "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"

+ 73 - 78
tools/llama/build_dataset.py

@@ -1,14 +1,13 @@
-import json
 import re
 from collections import defaultdict
-from dataclasses import asdict, dataclass
-from pathlib import Path
-from typing import Union
+from multiprocessing import Pool
 
 import numpy as np
 from loguru import logger
 from tqdm import tqdm
 
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
 from fish_speech.text import g2p
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 
@@ -25,87 +24,83 @@ DATASETS = [
 ]
 
 
-@dataclass
-class Sentence:
-    text: str
-    phones: list[str]
-    # Support multiple codebooks
-    semantics: Union[list[int], list[list[int]]]
-
-
-@dataclass
-class PackedSentences:
-    source: str
-    name: str
-    languages: list[str]
-    sentences: list[Sentence]
-
-
-dataset_fp = open("data/quantized-dataset-1205.json", "w")
-
-for root, source, languages, extension, parent_level in DATASETS:
-    # Load the files
-    exts = extension.split(".")
-    files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
-    logger.info(f"Found {len(files)} files in {root}")
-
-    grouped_files = defaultdict(list)
-    for file in files:
-        if parent_level == 1:
-            p = file.parent.name
-        elif parent_level == 2:
-            p = file.parent.parent.name
-        else:
-            raise ValueError(f"Invalid parent level {parent_level}")
-
-        grouped_files[p].append(file)
-
-    for name, subset in tqdm(grouped_files.items()):
-        # Parse the files
-        sentences = []
-        for file in subset:
-            np_file = file.with_suffix(".npy")
-            txt_file = file.with_suffix(extension)
-            if np_file.exists() is False or txt_file.exists() is False:
-                continue
-
-            with open(txt_file, "r") as f:
-                text = f.read().strip()
-
-            # Simple cleaning: replace { xxx } and < xxx > with space
-            text = re.sub(r"\{.*?\}", " ", text)
-            text = re.sub(r"<.*?>", " ", text)
-            text = re.sub(r"\s+", " ", text)
-
-            try:
-                phones = [v for _, v in g2p(text, order=languages)]
-                semantics = np.load(np_file)
-            except Exception as e:
-                logger.error(f"Failed to parse {file}: {e}")
-                continue
-
-            if isinstance(semantics, np.ndarray):
-                semantics = semantics.tolist()
-
-            sentences.append(
-                Sentence(
-                    text=text,
-                    phones=phones,
-                    semantics=semantics,
-                )
def task_generator():
    """Yield one (name, files, source, languages, extension) task per group.

    Walks every dataset root in DATASETS, buckets the audio files by their
    parent (or grandparent) directory name, and emits one task per bucket so
    the groups can be processed independently by worker processes.
    """
    for root, source, languages, extension, parent_level in DATASETS:
        grouped_files = defaultdict(list)
        for path in list_files(root, AUDIO_EXTENSIONS, recursive=True):
            if parent_level == 1:
                key = path.parent.name
            elif parent_level == 2:
                key = path.parent.parent.name
            else:
                raise ValueError(f"Invalid parent level {parent_level}")
            grouped_files[key].append(path)

        logger.info(f"Found {len(grouped_files)} groups in {root}")
        yield from (
            (name, subset, source, languages, extension)
            for name, subset in grouped_files.items()
        )
+
+
# Regexes are constant; compile once instead of re-parsing them per file.
_BRACES_RE = re.compile(r"\{.*?\}")
_ANGLES_RE = re.compile(r"<.*?>")
_SPACES_RE = re.compile(r"\s+")


def run_task(task):
    """Build one packed TextData record for a group of audio files.

    *task* is a (name, files, source, languages, extension) tuple as produced
    by task_generator.  Files missing either their quantized-token .npy or
    their transcript are silently skipped; files whose g2p or np.load fails
    are logged and skipped.  Returns the length-prefixed serialized bytes,
    ready to be appended to the stream file.
    """
    name, subset, source, languages, extension = task

    # Parse the files
    sentences = []
    for file in subset:
        np_file = file.with_suffix(".npy")
        txt_file = file.with_suffix(extension)
        if not np_file.exists() or not txt_file.exists():
            continue

        with open(txt_file, "r") as f:
            text = f.read().strip()

        # Simple cleaning: replace { xxx } and < xxx > with space
        text = _BRACES_RE.sub(" ", text)
        text = _ANGLES_RE.sub(" ", text)
        text = _SPACES_RE.sub(" ", text)

        try:
            phones = [v for _, v in g2p(text, order=languages)]
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        # NOTE(review): assumes semantics is 2-D (codebooks x tokens); a 1-D
        # array would make each `s` an int and Semantics(values=s) would fail
        # — confirm the quantizer always writes a 2-D .npy.
        sentences.append(
            Sentence(
                text=text,
                phones=phones,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            languages=languages,
            sentences=sentences,
        )
    )
 
-        dataset_fp.write(
-            json.dumps(asdict(packed_sentences), ensure_ascii=False) + "\n"
-        )
+
def main():
    """Build the length-prefixed protobuf dataset with a process pool.

    Fans the per-group tasks out to 16 worker processes and appends each
    returned record to the output file.  imap_unordered is deliberate: record
    order in the stream does not matter, and it keeps workers busy.
    """
    # Context managers guarantee the pool is torn down and the file is
    # flushed/closed even if a worker raises (the original leaked the file
    # handle on error).
    with Pool(16) as pool, open("data/quantized-dataset-1205.protos", "wb") as dataset_fp:
        for record in tqdm(pool.imap_unordered(run_task, task_generator())):
            dataset_fp.write(record)


if __name__ == "__main__":
    main()