Explorar el Código

Init webui & mix dataset & optimize spectrogram extractor

Lengyue hace 2 años
padre
commit
c693d63e57

+ 8 - 3
fish_speech/datasets/vqgan.py

@@ -40,9 +40,6 @@ class VQGANDataset(Dataset):
         audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
         features = np.load(file.with_suffix(".npy"))  # (T, 1024)
 
-        if len(audio) % self.hop_length != 0:
-            audio = np.pad(audio, (0, self.hop_length - (len(audio) % self.hop_length)))
-
         # Slice audio and features
         if self.slice_frames is not None and features.shape[0] > self.slice_frames:
             start = np.random.randint(0, features.shape[0] - self.slice_frames)
@@ -51,6 +48,14 @@ class VQGANDataset(Dataset):
                 start * self.hop_length : (start + self.slice_frames) * self.hop_length
             ]
 
+        if len(audio) < len(features) * self.hop_length:
+            audio = np.pad(
+                audio,
+                (0, len(features) * self.hop_length - len(audio)),
+                mode="constant",
+                constant_values=0,
+            )
+
         return {
             "audio": torch.from_numpy(audio),
             "features": torch.from_numpy(features),

+ 6 - 3
fish_speech/models/vqgan/spectrogram.py

@@ -96,9 +96,12 @@ class LogMelSpectrogram(nn.Module):
     def decompress(self, x: Tensor) -> Tensor:
         return torch.exp(x)
 
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.spectrogram(x)
-        x = self.mel_scale(x)
+    def forward(self, x: Tensor, return_linear: bool = False) -> "Tensor | tuple[Tensor, Tensor]":
+        linear = self.spectrogram(x)  # kept so it can be returned when return_linear is set
+        x = self.mel_scale(linear)
         x = self.compress(x)
 
+        if return_linear:  # caller asked for both representations
+            return x, self.compress(linear)  # (mel, linear), each passed through compress()
+
         return x

+ 4 - 0
fish_speech/webui/__main__.py

@@ -0,0 +1,4 @@
+from fish_speech.webui.app import app
+
+if __name__ == "__main__":  # reached when run as `python -m fish_speech.webui`
+    app.launch(show_api=False)

+ 63 - 0
fish_speech/webui/app.py

@@ -0,0 +1,63 @@
+import gradio as gr
+
+HEADER_MD = """
+# Fish Speech
+
+基于 VITS 和 GPT 的多语种语音合成. 项目很大程度上基于 Rcell 的 GPT-VITS.
+"""
+
+with gr.Blocks(theme=gr.themes.Base()) as app:  # `app` is the object launched below
+    gr.Markdown(HEADER_MD)
+
+    with gr.Row():
+        with gr.Column(scale=5):  # inputs: text box, option tabs, action buttons
+            text = gr.Textbox(lines=5, label="输入文本")
+
+            with gr.Row():
+                with gr.Tab(label="合成参数"):  # synthesis parameters
+                    gr.Markdown("配置常见的合成参数.")
+
+                    input_mode = gr.Dropdown(
+                        choices=["手动输入音素/文本", "自动音素转换"],
+                        value="手动输入音素/文本",
+                        label="输入模式",
+                    )
+
+                with gr.Tab(label="语言优先级"):  # language priority, three ranked dropdowns
+                    gr.Markdown("该参数只在自动音素转换时生效.")
+
+                    with gr.Column(scale=1):
+                        language0 = gr.Dropdown(
+                            choices=["中文", "日文", "英文", "无"],
+                            label="语言 1",
+                            value="中文",
+                        )
+
+                    with gr.Column(scale=1):
+                        language1 = gr.Dropdown(
+                            choices=["中文", "日文", "英文", "无"],
+                            label="语言 2",
+                            value="英文",
+                        )
+
+                    with gr.Column(scale=1):
+                        language2 = gr.Dropdown(
+                            choices=["中文", "日文", "英文", "无"],
+                            label="语言 3",
+                            value="无",
+                        )
+
+            with gr.Row():
+                with gr.Column(scale=2):
+                    generate = gr.Button(value="合成", variant="primary")
+                with gr.Column(scale=1):
+                    clear = gr.Button(value="清空")  # NOTE(review): no click handler is wired to this button
+
+        with gr.Column(scale=3):  # output: synthesized audio
+            audio = gr.Audio(label="合成音频")
+
+    # Placeholder callback: must accept one argument per listed input; returns None for `audio`.
+    generate.click(lambda _input_mode: None, [input_mode], [audio])
+
+if __name__ == "__main__":  # allow running this module directly
+    app.launch(show_api=False)

+ 1 - 0
pyproject.toml

@@ -25,6 +25,7 @@ dependencies = [
     "librosa>=0.10.1",
     "vector-quantize-pytorch>=1.9.18",
     "rich>=13.5.3",
+    "gradio>=4.0.0",
     "cn2an",
     "pypinyin",
     "jieba",