Lengyue 2 года назад
Родитель
Commit
0384982870

+ 1 - 11
fish_speech/datasets/protos/text-data.proto

@@ -7,15 +7,13 @@ message Semantics {
 }
 
 message Sentence {
-    string text = 1;
-    repeated string phones = 2;
+    repeated string texts = 1;
     repeated Semantics semantics = 3;
 }
 
 message TextData {
     string source = 1;
     string name = 2;
-    repeated string languages = 3;
     repeated Sentence sentences = 4;
 }
 
@@ -24,11 +22,3 @@ message SampledData {
     string name = 2;
     repeated Sentence samples = 3;
 }
-
-message SampleDataRequest {
-    uint32 num_samples = 1;
-}
-
-service DataService {
-    rpc SampleData (SampleDataRequest) returns (SampledData) {}
-}

+ 325 - 25
fish_speech/datasets/protos/text_data_pb2.py

@@ -1,37 +1,337 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
 # source: text-data.proto
-# Protobuf Python Version: 4.25.0
-"""Generated protocol buffer code."""
+
 from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
 from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
 
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"Q\n\x08Sentence\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0e\n\x06phones\x18\x02 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"c\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tlanguages\x18\x03 \x03(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentence"(\n\x11SampleDataRequest\x12\x13\n\x0bnum_samples\x18\x01 \x01(\r2S\n\x0b\x44\x61taService\x12\x44\n\nSampleData\x12\x1c.text_data.SampleDataRequest\x1a\x16.text_data.SampledData"\x00\x62\x06proto3'
-)
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
-if _descriptor._USE_C_DESCRIPTORS == False:
-    DESCRIPTOR._options = None
-    _globals["_SEMANTICS"]._serialized_start = 30
-    _globals["_SEMANTICS"]._serialized_end = 57
-    _globals["_SENTENCE"]._serialized_start = 59
-    _globals["_SENTENCE"]._serialized_end = 140
-    _globals["_TEXTDATA"]._serialized_start = 142
-    _globals["_TEXTDATA"]._serialized_end = 241
-    _globals["_SAMPLEDDATA"]._serialized_start = 243
-    _globals["_SAMPLEDDATA"]._serialized_end = 324
-    _globals["_SAMPLEDATAREQUEST"]._serialized_start = 326
-    _globals["_SAMPLEDATAREQUEST"]._serialized_end = 366
-    _globals["_DATASERVICE"]._serialized_start = 368
-    _globals["_DATASERVICE"]._serialized_end = 451
+DESCRIPTOR = _descriptor.FileDescriptor(
+    name="text-data.proto",
+    package="text_data",
+    syntax="proto3",
+    serialized_options=None,
+    create_key=_descriptor._internal_create_key,
+    serialized_pb=b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3',
+)
+
+
+_SEMANTICS = _descriptor.Descriptor(
+    name="Semantics",
+    full_name="text_data.Semantics",
+    filename=None,
+    file=DESCRIPTOR,
+    containing_type=None,
+    create_key=_descriptor._internal_create_key,
+    fields=[
+        _descriptor.FieldDescriptor(
+            name="values",
+            full_name="text_data.Semantics.values",
+            index=0,
+            number=1,
+            type=13,
+            cpp_type=3,
+            label=3,
+            has_default_value=False,
+            default_value=[],
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+    ],
+    extensions=[],
+    nested_types=[],
+    enum_types=[],
+    serialized_options=None,
+    is_extendable=False,
+    syntax="proto3",
+    extension_ranges=[],
+    oneofs=[],
+    serialized_start=30,
+    serialized_end=57,
+)
+
+
+_SENTENCE = _descriptor.Descriptor(
+    name="Sentence",
+    full_name="text_data.Sentence",
+    filename=None,
+    file=DESCRIPTOR,
+    containing_type=None,
+    create_key=_descriptor._internal_create_key,
+    fields=[
+        _descriptor.FieldDescriptor(
+            name="texts",
+            full_name="text_data.Sentence.texts",
+            index=0,
+            number=1,
+            type=9,
+            cpp_type=9,
+            label=3,
+            has_default_value=False,
+            default_value=[],
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+        _descriptor.FieldDescriptor(
+            name="semantics",
+            full_name="text_data.Sentence.semantics",
+            index=1,
+            number=3,
+            type=11,
+            cpp_type=10,
+            label=3,
+            has_default_value=False,
+            default_value=[],
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+    ],
+    extensions=[],
+    nested_types=[],
+    enum_types=[],
+    serialized_options=None,
+    is_extendable=False,
+    syntax="proto3",
+    extension_ranges=[],
+    oneofs=[],
+    serialized_start=59,
+    serialized_end=125,
+)
+
+
+_TEXTDATA = _descriptor.Descriptor(
+    name="TextData",
+    full_name="text_data.TextData",
+    filename=None,
+    file=DESCRIPTOR,
+    containing_type=None,
+    create_key=_descriptor._internal_create_key,
+    fields=[
+        _descriptor.FieldDescriptor(
+            name="source",
+            full_name="text_data.TextData.source",
+            index=0,
+            number=1,
+            type=9,
+            cpp_type=9,
+            label=1,
+            has_default_value=False,
+            default_value=b"".decode("utf-8"),
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+        _descriptor.FieldDescriptor(
+            name="name",
+            full_name="text_data.TextData.name",
+            index=1,
+            number=2,
+            type=9,
+            cpp_type=9,
+            label=1,
+            has_default_value=False,
+            default_value=b"".decode("utf-8"),
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+        _descriptor.FieldDescriptor(
+            name="sentences",
+            full_name="text_data.TextData.sentences",
+            index=2,
+            number=4,
+            type=11,
+            cpp_type=10,
+            label=3,
+            has_default_value=False,
+            default_value=[],
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+    ],
+    extensions=[],
+    nested_types=[],
+    enum_types=[],
+    serialized_options=None,
+    is_extendable=False,
+    syntax="proto3",
+    extension_ranges=[],
+    oneofs=[],
+    serialized_start=127,
+    serialized_end=207,
+)
+
+
+_SAMPLEDDATA = _descriptor.Descriptor(
+    name="SampledData",
+    full_name="text_data.SampledData",
+    filename=None,
+    file=DESCRIPTOR,
+    containing_type=None,
+    create_key=_descriptor._internal_create_key,
+    fields=[
+        _descriptor.FieldDescriptor(
+            name="source",
+            full_name="text_data.SampledData.source",
+            index=0,
+            number=1,
+            type=9,
+            cpp_type=9,
+            label=1,
+            has_default_value=False,
+            default_value=b"".decode("utf-8"),
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+        _descriptor.FieldDescriptor(
+            name="name",
+            full_name="text_data.SampledData.name",
+            index=1,
+            number=2,
+            type=9,
+            cpp_type=9,
+            label=1,
+            has_default_value=False,
+            default_value=b"".decode("utf-8"),
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+        _descriptor.FieldDescriptor(
+            name="samples",
+            full_name="text_data.SampledData.samples",
+            index=2,
+            number=3,
+            type=11,
+            cpp_type=10,
+            label=3,
+            has_default_value=False,
+            default_value=[],
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            serialized_options=None,
+            file=DESCRIPTOR,
+            create_key=_descriptor._internal_create_key,
+        ),
+    ],
+    extensions=[],
+    nested_types=[],
+    enum_types=[],
+    serialized_options=None,
+    is_extendable=False,
+    syntax="proto3",
+    extension_ranges=[],
+    oneofs=[],
+    serialized_start=209,
+    serialized_end=290,
+)
+
+_SENTENCE.fields_by_name["semantics"].message_type = _SEMANTICS
+_TEXTDATA.fields_by_name["sentences"].message_type = _SENTENCE
+_SAMPLEDDATA.fields_by_name["samples"].message_type = _SENTENCE
+DESCRIPTOR.message_types_by_name["Semantics"] = _SEMANTICS
+DESCRIPTOR.message_types_by_name["Sentence"] = _SENTENCE
+DESCRIPTOR.message_types_by_name["TextData"] = _TEXTDATA
+DESCRIPTOR.message_types_by_name["SampledData"] = _SAMPLEDDATA
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+Semantics = _reflection.GeneratedProtocolMessageType(
+    "Semantics",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _SEMANTICS,
+        "__module__": "text_data_pb2"
+        # @@protoc_insertion_point(class_scope:text_data.Semantics)
+    },
+)
+_sym_db.RegisterMessage(Semantics)
+
+Sentence = _reflection.GeneratedProtocolMessageType(
+    "Sentence",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _SENTENCE,
+        "__module__": "text_data_pb2"
+        # @@protoc_insertion_point(class_scope:text_data.Sentence)
+    },
+)
+_sym_db.RegisterMessage(Sentence)
+
+TextData = _reflection.GeneratedProtocolMessageType(
+    "TextData",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _TEXTDATA,
+        "__module__": "text_data_pb2"
+        # @@protoc_insertion_point(class_scope:text_data.TextData)
+    },
+)
+_sym_db.RegisterMessage(TextData)
+
+SampledData = _reflection.GeneratedProtocolMessageType(
+    "SampledData",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _SAMPLEDDATA,
+        "__module__": "text_data_pb2"
+        # @@protoc_insertion_point(class_scope:text_data.SampledData)
+    },
+)
+_sym_db.RegisterMessage(SampledData)
+
+
 # @@protoc_insertion_point(module_scope)

+ 0 - 79
fish_speech/datasets/protos/text_data_pb2_grpc.py

@@ -1,79 +0,0 @@
-# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
-import grpc
-
-import fish_speech.datasets.protos.text_data_pb2 as text__data__pb2
-
-
-class DataServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def __init__(self, channel):
-        """Constructor.
-
-        Args:
-            channel: A grpc.Channel.
-        """
-        self.SampleData = channel.unary_unary(
-            "/text_data.DataService/SampleData",
-            request_serializer=text__data__pb2.SampleDataRequest.SerializeToString,
-            response_deserializer=text__data__pb2.SampledData.FromString,
-        )
-
-
-class DataServiceServicer(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def SampleData(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details("Method not implemented!")
-        raise NotImplementedError("Method not implemented!")
-
-
-def add_DataServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-        "SampleData": grpc.unary_unary_rpc_method_handler(
-            servicer.SampleData,
-            request_deserializer=text__data__pb2.SampleDataRequest.FromString,
-            response_serializer=text__data__pb2.SampledData.SerializeToString,
-        ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-        "text_data.DataService", rpc_method_handlers
-    )
-    server.add_generic_rpc_handlers((generic_handler,))
-
-
-# This class is part of an EXPERIMENTAL API.
-class DataService(object):
-    """Missing associated documentation comment in .proto file."""
-
-    @staticmethod
-    def SampleData(
-        request,
-        target,
-        options=(),
-        channel_credentials=None,
-        call_credentials=None,
-        insecure=False,
-        compression=None,
-        wait_for_ready=None,
-        timeout=None,
-        metadata=None,
-    ):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            "/text_data.DataService/SampleData",
-            text__data__pb2.SampleDataRequest.SerializeToString,
-            text__data__pb2.SampledData.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-        )

+ 2 - 4
fish_speech/datasets/text.py

@@ -163,7 +163,6 @@ class AutoAugTextDataset(IterableDataset):
 
     1. Random concatenate multiple sentences from the same speaker to form a longer sentence
     2. Automatically normalize the text
-    3. Mix text and phones
 
     For interactive mode, we use the following format (multiple sequences):
     <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
@@ -176,7 +175,6 @@ class AutoAugTextDataset(IterableDataset):
         self,
         proto_files: list[str],
         seed: int = 42,
-        phones_prob: float = 0.3,
         interactive_prob: float = 0.5,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
@@ -203,7 +201,6 @@ class AutoAugTextDataset(IterableDataset):
         assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
 
         self.seed = seed
-        self.phones_prob = phones_prob
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.interactive_prob = interactive_prob
@@ -324,7 +321,8 @@ class AutoAugTextDataset(IterableDataset):
         while remaining_tokens > 0 and len(samples) > 0:
             sentence = samples.pop()
 
-            text, length = self.tokenize_sentence(sentence.text)
+            text = random.choice(sentence.texts)
+            text, length = self.tokenize_sentence(text)
             remaining_tokens -= length + len(sentence.semantics[0].values)
 
             if use_interactive is False:

+ 25 - 15
tools/llama/build_dataset.py

@@ -27,7 +27,17 @@ def task_generator_folder(root: Path, text_extension: str):
     grouped_files = defaultdict(list)
     for file in tqdm(files, desc=f"Grouping {root}"):
         p = str(file.parent)
-        grouped_files[p].append((file, file.with_suffix(text_extension).read_text()))
+
+        try:
+            if isinstance(text_extension, str):
+                texts = [file.with_suffix(text_extension).read_text()]
+            else:
+                texts = [file.with_suffix(ext).read_text() for ext in text_extension]
+        except Exception as e:
+            logger.error(f"Failed to read text {file}: {e}")
+            continue
+
+        grouped_files[p].append((file, texts))
 
     logger.info(
         f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
@@ -39,31 +49,34 @@ def task_generator_folder(root: Path, text_extension: str):
 def task_generator_filelist(filelist):
     grouped_files = defaultdict(list)
     for filename, speaker, _, text in load_filelist(filelist):
-        grouped_files[speaker].append((Path(filename), text))
+        grouped_files[speaker].append((Path(filename), [text]))
 
     logger.info(f"Found {len(grouped_files)} groups in {filelist}")
     for speaker, values in grouped_files.items():
         yield speaker, values, "filelist"
 
 
-def run_task(task, output):
-    group_idx, task = task
+def run_task(task):
     name, subset, source = task
 
     # Parse the files
     sentences = []
     for file in subset:
-        file, text = file
+        file, texts = file
 
         np_file = file.with_suffix(".npy")
         if np_file.exists() is False:
             logger.warning(f"Can't find {np_file}")
             continue
 
-        # Simple cleaning: replace { xxx } and < xxx > with space
-        text = re.sub(r"\{.*?\}", " ", text)
-        text = re.sub(r"<.*?>", " ", text)
-        text = re.sub(r"\s+", " ", text)
+        new_texts = []
+
+        for text in texts:
+            # Simple cleaning: replace { xxx } and < xxx > with space
+            text = re.sub(r"\{.*?\}", " ", text)
+            text = re.sub(r"<.*?>", " ", text)
+            text = re.sub(r"\s+", " ", text)
+            new_texts.append(text)
 
         try:
             semantics = np.load(np_file)
@@ -76,8 +89,7 @@ def run_task(task, output):
 
         sentences.append(
             Sentence(
-                text=text,
-                phones=[],
+                texts=new_texts,
                 semantics=[Semantics(values=s) for s in semantics],
             )
         )
@@ -87,7 +99,6 @@ def run_task(task, output):
         TextData(
             source=source,
             name=name,
-            languages=[],
             sentences=sentences,
         )
     )
@@ -105,7 +116,7 @@ def run_task(task, output):
     "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
 )
 @click.option("--num-workers", type=int, default=16)
-@click.option("--text-extension", type=str, default=".txt")
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
 @click.option(
     "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
 )
@@ -123,7 +134,6 @@ def main(input, output, num_workers, text_extension, shard_size):
         generator_fns.append(generator_fn)
 
     generator_fn = itertools.chain(*generator_fns)
-    run_task_p = partial(run_task, output=output)
     output.mkdir(parents=True, exist_ok=True)
 
     dataset_fp = None
@@ -131,7 +141,7 @@ def main(input, output, num_workers, text_extension, shard_size):
     written_size = 0
 
     with Pool(num_workers) as p:
-        for result in tqdm(p.imap_unordered(run_task_p, enumerate(generator_fn))):
+        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
             if dataset_fp is None:
                 dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
 

+ 1 - 1
tools/vqgan/inference.py

@@ -33,7 +33,7 @@ OmegaConf.register_new_resolver("eval", eval)
 @click.option(
     "--checkpoint-path",
     "-ckpt",
-    default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
+    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
 )
 def main(input_path, output_path, config_name, checkpoint_path):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):