Explorar el Código

Return error for empty input

Lengyue hace 1 año
padre
commit
9a18199a3e

+ 52 - 0
fish_speech/datasets/concat_repeat.py

@@ -0,0 +1,52 @@
+import bisect
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+    datasets: list[Dataset]
+    cumulative_sizes: list[int]
+    repeats: list[int]
+
+    @staticmethod
+    def cumsum(sequence, repeats):
+        r, s = [], 0
+        for dataset, repeat in zip(sequence, repeats):
+            l = len(dataset) * repeat
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+        super().__init__()
+
+        self.datasets = list(datasets)
+        self.repeats = repeats
+
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+        assert len(self.datasets) == len(
+            repeats
+        ), "datasets and repeats should have the same length"
+
+        for d in self.datasets:
+            assert not isinstance(
+                d, IterableDataset
+            ), "ConcatDataset does not support IterableDataset"
+
+        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+        dataset = self.datasets[dataset_idx]
+
+        return dataset[sample_idx % len(dataset)]

+ 1 - 0
fish_speech/i18n/locale/en_US.json

@@ -49,6 +49,7 @@
     "Model Size": "Model Size",
     "Move": "Move",
     "Move files successfully": "Move files successfully",
+    "No audio generated, please check the input text.": "No audio generated, please check the input text.",
     "No selected options": "No selected options",
     "Number of Workers": "Number of Workers",
     "Open Inference Server": "Open Inference Server",

+ 1 - 0
fish_speech/i18n/locale/es_ES.json

@@ -49,6 +49,7 @@
     "Model Size": "Tamaño del Modelo",
     "Move": "Mover",
     "Move files successfully": "Archivos movidos exitosamente",
+    "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
     "No selected options": "No hay opciones seleccionadas",
     "Number of Workers": "Número de Trabajadores",
     "Open Inference Server": "Abrir Servidor de Inferencia",

+ 1 - 0
fish_speech/i18n/locale/ja_JP.json

@@ -49,6 +49,7 @@
     "Model Size": "モデルサイズ",
     "Move": "移動",
     "Move files successfully": "ファイルの移動に成功しました",
+    "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
     "No selected options": "選択されたオプションはありません",
     "Number of Workers": "ワーカー数",
     "Open Inference Server": "推論サーバーを開く",

+ 1 - 0
fish_speech/i18n/locale/zh_CN.json

@@ -49,6 +49,7 @@
     "Model Size": "模型规模",
     "Move": "移动",
     "Move files successfully": "移动文件成功",
+    "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
     "No selected options": "没有选择的选项",
     "Number of Workers": "数据加载进程数",
     "Open Inference Server": "打开推理服务器",

+ 6 - 1
tools/api.py

@@ -162,7 +162,12 @@ def inference(req: InvokeRequest):
         else:
             segments.append(fake_audios)
 
-    if req.streaming is False:
+    if len(segments) == 0:
+        raise HTTPException(
+            HTTPStatus.INTERNAL_SERVER_ERROR,
+            content="No audio generated, please check the input text.",
+        )
+    elif req.streaming is False:
         fake_audios = np.concatenate(segments, axis=0)
         yield fake_audios
 

+ 5 - 1
tools/webui.py

@@ -155,7 +155,11 @@ def inference(
         else:
             segments.append(fake_audios)
 
-    if streaming is False:
+    if len(segments) == 0:
+        yield None, build_html_error_message(
+            i18n("No audio generated, please check the input text.")
+        )
+    elif streaming is False:
         audio = np.concatenate(segments, axis=0)
         yield (vqgan_model.sampling_rate, audio), None