Explorar o código

Update configs

Lengyue %!s(int64=2) %!d(string=hai) anos
pai
achega
2df19066d7

+ 9 - 0
fish_speech/configs/model/dual_ar_2_codebook_large.yaml

@@ -0,0 +1,9 @@
+defaults:
+  - dual_ar_2_codebook_small
+  - _self_
+
+config:
+  n_layer: 30
+  n_fast_layer: 6
+  n_head: 24
+  dim: 1536

+ 1 - 1
fish_speech/configs/model/dual_ar_8_codebook_medium.yaml → fish_speech/configs/model/dual_ar_2_codebook_medium.yaml

@@ -1,5 +1,5 @@
 defaults:
-  - dual_ar_8_codebook_small
+  - dual_ar_2_codebook_small
   - _self_
 
 config:

+ 0 - 0
fish_speech/configs/model/dual_ar_8_codebook_small.yaml → fish_speech/configs/model/dual_ar_2_codebook_small.yaml


+ 3 - 3
fish_speech/configs/model/naive_8_codebook_small.yaml → fish_speech/configs/model/naive_2_codebook_small.yaml

@@ -1,6 +1,6 @@
 _target_: fish_speech.models.text2semantic.llama.NaiveTransformer
 config:
-  _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
+  _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
   max_seq_len: ${max_length}
   vocab_size: 36408
   n_layer: 12
@@ -8,5 +8,5 @@ config:
   dim: 768
   rope_base: 10000
   norm_eps: 1e-5
-  num_codebooks: 8  # input/output codebook size
-  codebook_size: 264 # codebook size 256 + 2 special tokens
+  num_codebooks: 2  # input/output codebook size
+  codebook_size: 1032 # codebook size 1024 + 2 special tokens

+ 19 - 323
fish_speech/datasets/protos/text_data_pb2.py

@@ -1,337 +1,33 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
 # source: text-data.proto
-
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
 from google.protobuf import descriptor as _descriptor
-from google.protobuf import message as _message
-from google.protobuf import reflection as _reflection
+from google.protobuf import descriptor_pool as _descriptor_pool
 from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
 
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
 
 
-DESCRIPTOR = _descriptor.FileDescriptor(
-    name="text-data.proto",
-    package="text_data",
-    syntax="proto3",
-    serialized_options=None,
-    create_key=_descriptor._internal_create_key,
-    serialized_pb=b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3',
-)
-
-
-_SEMANTICS = _descriptor.Descriptor(
-    name="Semantics",
-    full_name="text_data.Semantics",
-    filename=None,
-    file=DESCRIPTOR,
-    containing_type=None,
-    create_key=_descriptor._internal_create_key,
-    fields=[
-        _descriptor.FieldDescriptor(
-            name="values",
-            full_name="text_data.Semantics.values",
-            index=0,
-            number=1,
-            type=13,
-            cpp_type=3,
-            label=3,
-            has_default_value=False,
-            default_value=[],
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-    ],
-    extensions=[],
-    nested_types=[],
-    enum_types=[],
-    serialized_options=None,
-    is_extendable=False,
-    syntax="proto3",
-    extension_ranges=[],
-    oneofs=[],
-    serialized_start=30,
-    serialized_end=57,
-)
-
-
-_SENTENCE = _descriptor.Descriptor(
-    name="Sentence",
-    full_name="text_data.Sentence",
-    filename=None,
-    file=DESCRIPTOR,
-    containing_type=None,
-    create_key=_descriptor._internal_create_key,
-    fields=[
-        _descriptor.FieldDescriptor(
-            name="texts",
-            full_name="text_data.Sentence.texts",
-            index=0,
-            number=1,
-            type=9,
-            cpp_type=9,
-            label=3,
-            has_default_value=False,
-            default_value=[],
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-        _descriptor.FieldDescriptor(
-            name="semantics",
-            full_name="text_data.Sentence.semantics",
-            index=1,
-            number=3,
-            type=11,
-            cpp_type=10,
-            label=3,
-            has_default_value=False,
-            default_value=[],
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-    ],
-    extensions=[],
-    nested_types=[],
-    enum_types=[],
-    serialized_options=None,
-    is_extendable=False,
-    syntax="proto3",
-    extension_ranges=[],
-    oneofs=[],
-    serialized_start=59,
-    serialized_end=125,
-)
-
-
-_TEXTDATA = _descriptor.Descriptor(
-    name="TextData",
-    full_name="text_data.TextData",
-    filename=None,
-    file=DESCRIPTOR,
-    containing_type=None,
-    create_key=_descriptor._internal_create_key,
-    fields=[
-        _descriptor.FieldDescriptor(
-            name="source",
-            full_name="text_data.TextData.source",
-            index=0,
-            number=1,
-            type=9,
-            cpp_type=9,
-            label=1,
-            has_default_value=False,
-            default_value=b"".decode("utf-8"),
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-        _descriptor.FieldDescriptor(
-            name="name",
-            full_name="text_data.TextData.name",
-            index=1,
-            number=2,
-            type=9,
-            cpp_type=9,
-            label=1,
-            has_default_value=False,
-            default_value=b"".decode("utf-8"),
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-        _descriptor.FieldDescriptor(
-            name="sentences",
-            full_name="text_data.TextData.sentences",
-            index=2,
-            number=4,
-            type=11,
-            cpp_type=10,
-            label=3,
-            has_default_value=False,
-            default_value=[],
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-    ],
-    extensions=[],
-    nested_types=[],
-    enum_types=[],
-    serialized_options=None,
-    is_extendable=False,
-    syntax="proto3",
-    extension_ranges=[],
-    oneofs=[],
-    serialized_start=127,
-    serialized_end=207,
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
 )
 
-
-_SAMPLEDDATA = _descriptor.Descriptor(
-    name="SampledData",
-    full_name="text_data.SampledData",
-    filename=None,
-    file=DESCRIPTOR,
-    containing_type=None,
-    create_key=_descriptor._internal_create_key,
-    fields=[
-        _descriptor.FieldDescriptor(
-            name="source",
-            full_name="text_data.SampledData.source",
-            index=0,
-            number=1,
-            type=9,
-            cpp_type=9,
-            label=1,
-            has_default_value=False,
-            default_value=b"".decode("utf-8"),
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-        _descriptor.FieldDescriptor(
-            name="name",
-            full_name="text_data.SampledData.name",
-            index=1,
-            number=2,
-            type=9,
-            cpp_type=9,
-            label=1,
-            has_default_value=False,
-            default_value=b"".decode("utf-8"),
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-        _descriptor.FieldDescriptor(
-            name="samples",
-            full_name="text_data.SampledData.samples",
-            index=2,
-            number=3,
-            type=11,
-            cpp_type=10,
-            label=3,
-            has_default_value=False,
-            default_value=[],
-            message_type=None,
-            enum_type=None,
-            containing_type=None,
-            is_extension=False,
-            extension_scope=None,
-            serialized_options=None,
-            file=DESCRIPTOR,
-            create_key=_descriptor._internal_create_key,
-        ),
-    ],
-    extensions=[],
-    nested_types=[],
-    enum_types=[],
-    serialized_options=None,
-    is_extendable=False,
-    syntax="proto3",
-    extension_ranges=[],
-    oneofs=[],
-    serialized_start=209,
-    serialized_end=290,
-)
-
-_SENTENCE.fields_by_name["semantics"].message_type = _SEMANTICS
-_TEXTDATA.fields_by_name["sentences"].message_type = _SENTENCE
-_SAMPLEDDATA.fields_by_name["samples"].message_type = _SENTENCE
-DESCRIPTOR.message_types_by_name["Semantics"] = _SEMANTICS
-DESCRIPTOR.message_types_by_name["Sentence"] = _SENTENCE
-DESCRIPTOR.message_types_by_name["TextData"] = _TEXTDATA
-DESCRIPTOR.message_types_by_name["SampledData"] = _SAMPLEDDATA
-_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-
-Semantics = _reflection.GeneratedProtocolMessageType(
-    "Semantics",
-    (_message.Message,),
-    {
-        "DESCRIPTOR": _SEMANTICS,
-        "__module__": "text_data_pb2"
-        # @@protoc_insertion_point(class_scope:text_data.Semantics)
-    },
-)
-_sym_db.RegisterMessage(Semantics)
-
-Sentence = _reflection.GeneratedProtocolMessageType(
-    "Sentence",
-    (_message.Message,),
-    {
-        "DESCRIPTOR": _SENTENCE,
-        "__module__": "text_data_pb2"
-        # @@protoc_insertion_point(class_scope:text_data.Sentence)
-    },
-)
-_sym_db.RegisterMessage(Sentence)
-
-TextData = _reflection.GeneratedProtocolMessageType(
-    "TextData",
-    (_message.Message,),
-    {
-        "DESCRIPTOR": _TEXTDATA,
-        "__module__": "text_data_pb2"
-        # @@protoc_insertion_point(class_scope:text_data.TextData)
-    },
-)
-_sym_db.RegisterMessage(TextData)
-
-SampledData = _reflection.GeneratedProtocolMessageType(
-    "SampledData",
-    (_message.Message,),
-    {
-        "DESCRIPTOR": _SAMPLEDDATA,
-        "__module__": "text_data_pb2"
-        # @@protoc_insertion_point(class_scope:text_data.SampledData)
-    },
-)
-_sym_db.RegisterMessage(SampledData)
-
-
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_SEMANTICS"]._serialized_start = 30
+    _globals["_SEMANTICS"]._serialized_end = 57
+    _globals["_SENTENCE"]._serialized_start = 59
+    _globals["_SENTENCE"]._serialized_end = 125
+    _globals["_TEXTDATA"]._serialized_start = 127
+    _globals["_TEXTDATA"]._serialized_end = 207
+    _globals["_SAMPLEDDATA"]._serialized_start = 209
+    _globals["_SAMPLEDDATA"]._serialized_end = 290
 # @@protoc_insertion_point(module_scope)