Ver Fonte

From whisper to sensevoice (#482)

* Fix win deps

* Fix win docs

* Fix indent = 4

* Fix zh docs indent

* Update whisper asr: no split

* Add sensevoice labeling pipeline

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* allow api accept any ref audio

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update mel-band-roformer

* revert to bs-roformer

* Add option to save emo

* fsmnvad -> silerovad

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update req

* max single seg time

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama há 1 ano atrás
pai
commit
8702c61100

+ 18 - 18
docs/zh/index.md

@@ -13,8 +13,8 @@
 </div>
 
 !!! warning
-   我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规. <br/>
-   此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
+    我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规. <br/>
+    此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
 
 <p align="center">
    <img src="../assets/figs/diagram.png" width="75%">
@@ -33,23 +33,23 @@ Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法
 
 1. 解压项目压缩包。
 2. 点击 `install_env.bat` 安装环境。
-   - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。
-   - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。
-   - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。
-   - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。
+    - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。
+    - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。
+    - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。
+    - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。
 3. 若第 2 步 `INSTALL_TYPE=preview` 则执行这一步(可跳过,此步为激活编译模型环境)
-   1. 使用如下链接下载 LLVM 编译器。
-      - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
-      - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
-      - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。
-      - 确认安装完成。
-   2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。
-      - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe)
-   3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。
-      - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
-      - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022
-      - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
-   4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+    1. 使用如下链接下载 LLVM 编译器。
+        - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+        - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+        - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。
+        - 确认安装完成。
+    2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。
+        - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+    3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。
+        - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
+        - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022
+        - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
+    4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
 4. 双击 `start.bat` 打开训练推理 WebUI 管理界面. 如有需要,可照下列提示修改`API_FLAGS`.
 
 !!! info "可选"

+ 0 - 103
fish_speech/utils/file.py

@@ -1,55 +1,5 @@
 import os
-from glob import glob
 from pathlib import Path
-from typing import Union
-
-from loguru import logger
-from natsort import natsorted
-
-AUDIO_EXTENSIONS = {
-    ".mp3",
-    ".wav",
-    ".flac",
-    ".ogg",
-    ".m4a",
-    ".wma",
-    ".aac",
-    ".aiff",
-    ".aif",
-    ".aifc",
-}
-
-
-def list_files(
-    path: Union[Path, str],
-    extensions: set[str] = None,
-    recursive: bool = False,
-    sort: bool = True,
-) -> list[Path]:
-    """List files in a directory.
-
-    Args:
-        path (Path): Path to the directory.
-        extensions (set, optional): Extensions to filter. Defaults to None.
-        recursive (bool, optional): Whether to search recursively. Defaults to False.
-        sort (bool, optional): Whether to sort the files. Defaults to True.
-
-    Returns:
-        list: List of files.
-    """
-
-    if isinstance(path, str):
-        path = Path(path)
-
-    if not path.exists():
-        raise FileNotFoundError(f"Directory {path} does not exist.")
-
-    files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
-
-    if sort:
-        files = natsorted(files)
-
-    return files
 
 
 def get_latest_checkpoint(path: Path | str) -> Path | None:
@@ -64,56 +14,3 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
         return None
 
     return ckpts[-1]
-
-
-def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
-    """
-    Load a Bert-VITS2 style filelist.
-    """
-
-    files = set()
-    results = []
-    count_duplicated, count_not_found = 0, 0
-
-    LANGUAGE_TO_LANGUAGES = {
-        "zh": ["zh", "en"],
-        "jp": ["jp", "en"],
-        "en": ["en"],
-    }
-
-    with open(path, "r", encoding="utf-8") as f:
-        for line in f.readlines():
-            splits = line.strip().split("|", maxsplit=3)
-            if len(splits) != 4:
-                logger.warning(f"Invalid line: {line}")
-                continue
-
-            filename, speaker, language, text = splits
-            file = Path(filename)
-            language = language.strip().lower()
-
-            if language == "ja":
-                language = "jp"
-
-            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
-            languages = LANGUAGE_TO_LANGUAGES[language]
-
-            if file in files:
-                logger.warning(f"Duplicated file: {file}")
-                count_duplicated += 1
-                continue
-
-            if not file.exists():
-                logger.warning(f"File not found: {file}")
-                count_not_found += 1
-                continue
-
-            results.append((file, speaker, languages, text))
-
-    if count_duplicated > 0:
-        logger.warning(f"Total duplicated files: {count_duplicated}")
-
-    if count_not_found > 0:
-        logger.warning(f"Total files not found: {count_not_found}")
-
-    return results

+ 4 - 2
pyproject.toml

@@ -38,9 +38,11 @@ dependencies = [
     "zstandard>=0.22.0",
     "pydub",
     "faster_whisper",
-    "modelscope==1.16.1",
-    "funasr==1.1.2",
+    "modelscope==1.17.1",
+    "funasr==1.1.5",
     "opencc-python-reimplemented==0.1.7",
+    "audio-seperator[gpu]==0.18.3",
+    "silero-vad",
 ]
 
 [project.optional-dependencies]

+ 1 - 1
run_cmd.bat

@@ -29,7 +29,7 @@ set INSTALL_ENV_DIR=%cd%\fishenv\env
 
 
 set PYTHONNOUSERSITE=1
-set PYTHONPATH=
+set PYTHONPATH=%~dp0
 set PYTHONHOME=
 
 

+ 14 - 2
tools/api.py

@@ -3,6 +3,7 @@ import io
 import json
 import queue
 import random
+import sys
 import traceback
 import wave
 from argparse import ArgumentParser
@@ -10,11 +11,11 @@ from http import HTTPStatus
 from pathlib import Path
 from typing import Annotated, Literal, Optional
 
-import librosa
 import numpy as np
 import pyrootutils
 import soundfile as sf
 import torch
+import torchaudio
 from kui.asgi import (
     Body,
     HTTPException,
@@ -87,7 +88,18 @@ def load_audio(reference_audio, sr):
         except base64.binascii.Error:
             raise ValueError("Invalid path or base64 string")
 
-    audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
+    waveform, original_sr = torchaudio.load(
+        reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
+    )
+
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    if original_sr != sr:
+        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
+        waveform = resampler(waveform)
+
+    audio = waveform.squeeze().numpy()
     return audio
 
 

+ 108 - 0
tools/file.py

@@ -0,0 +1,108 @@
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+AUDIO_EXTENSIONS = {
+    ".mp3",
+    ".wav",
+    ".flac",
+    ".ogg",
+    ".m4a",
+    ".wma",
+    ".aac",
+    ".aiff",
+    ".aif",
+    ".aifc",
+}
+
+VIDEO_EXTENSIONS = {
+    ".mp4",
+    ".avi",
+}
+
+
+def list_files(
+    path: Union[Path, str],
+    extensions: set[str] = None,
+    recursive: bool = False,
+    sort: bool = True,
+) -> list[Path]:
+    """List files in a directory.
+
+    Args:
+        path (Path): Path to the directory.
+        extensions (set, optional): Extensions to filter. Defaults to None.
+        recursive (bool, optional): Whether to search recursively. Defaults to False.
+        sort (bool, optional): Whether to sort the files. Defaults to True.
+
+    Returns:
+        list: List of files.
+    """
+
+    if isinstance(path, str):
+        path = Path(path)
+
+    if not path.exists():
+        raise FileNotFoundError(f"Directory {path} does not exist.")
+
+    files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
+
+    if sort:
+        files = natsorted(files)
+
+    return files
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+    """
+    Load a Bert-VITS2 style filelist.
+    """
+
+    files = set()
+    results = []
+    count_duplicated, count_not_found = 0, 0
+
+    LANGUAGE_TO_LANGUAGES = {
+        "zh": ["zh", "en"],
+        "jp": ["jp", "en"],
+        "en": ["en"],
+    }
+
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f.readlines():
+            splits = line.strip().split("|", maxsplit=3)
+            if len(splits) != 4:
+                logger.warning(f"Invalid line: {line}")
+                continue
+
+            filename, speaker, language, text = splits
+            file = Path(filename)
+            language = language.strip().lower()
+
+            if language == "ja":
+                language = "jp"
+
+            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+            languages = LANGUAGE_TO_LANGUAGES[language]
+
+            if file in files:
+                logger.warning(f"Duplicated file: {file}")
+                count_duplicated += 1
+                continue
+
+            if not file.exists():
+                logger.warning(f"File not found: {file}")
+                count_not_found += 1
+                continue
+
+            results.append((file, speaker, languages, text))
+
+    if count_duplicated > 0:
+        logger.warning(f"Total duplicated files: {count_duplicated}")
+
+    if count_not_found > 0:
+        logger.warning(f"Total files not found: {count_not_found}")
+
+    return results

+ 1 - 1
tools/merge_asr_files.py

@@ -4,7 +4,7 @@ from pathlib import Path
 from pydub import AudioSegment
 from tqdm import tqdm
 
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from tools.file import AUDIO_EXTENSIONS, list_files
 
 
 def merge_and_delete_files(save_dir, original_files):

+ 59 - 0
tools/sensevoice/README.md

@@ -0,0 +1,59 @@
+# FunASR Command Line Interface
+
+This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch <= 2.3.1
+- ffmpeg, pydub, audio-separator[gpu].
+
+## Installation
+
+Install the required packages:
+
+```bash
+pip install -e .[stable]
+```
+
+Make sure you have `ffmpeg` installed and available in your `PATH`.
+
+## Usage
+
+### Basic Usage
+
+To run the tool with default settings:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
+```
+
+## Options
+
+|          Option           |                                  Description                                  |
+| :-----------------------: | :---------------------------------------------------------------------------: |
+|        --audio-dir        |                  Directory containing audio or video files.                   |
+|        --save-dir         |                   Directory to save processed audio files.                    |
+|         --device          |         Device to use for processing. Options: cuda (default) or cpu.         |
+|        --language         |                Language of the transcription. Default is auto.                |
+| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
+|          --punc           |                        Enable punctuation prediction.                         |
+|         --denoise         |                  Enable noise reduction (vocal separation).                   |
+
+## Example
+
+To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
+```
+
+## Additional Notes
+
+- The tool supports `both audio and video files`. Videos will be converted to audio automatically.
+- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
+- The script will automatically create necessary directories in the `--save-dir`.
+
+## Troubleshooting
+
+If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.

+ 0 - 0
tools/sensevoice/__init__.py


+ 573 - 0
tools/sensevoice/auto_model.py

@@ -0,0 +1,573 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import copy
+import json
+import logging
+import os.path
+import random
+import re
+import string
+import time
+
+import numpy as np
+import torch
+from funasr.download.download_model_from_hub import download_model
+from funasr.download.file import download_from_url
+from funasr.register import tables
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import export_utils, misc
+from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
+from funasr.utils.misc import deep_update
+from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
+from tqdm import tqdm
+
+from .vad_utils import merge_vad, slice_padding_audio_samples
+
+try:
+    from funasr.models.campplus.cluster_backend import ClusterBackend
+    from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
+except:
+    pass
+
+
+def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
+    """ """
+    data_list = []
+    key_list = []
+    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
+
+    chars = string.ascii_letters + string.digits
+    if isinstance(data_in, str):
+        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
+            data_in = download_from_url(data_in)
+
+    if isinstance(data_in, str) and os.path.exists(
+        data_in
+    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
+        _, file_extension = os.path.splitext(data_in)
+        file_extension = file_extension.lower()
+        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
+            with open(data_in, encoding="utf-8") as fin:
+                for line in fin:
+                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                    if data_in.endswith(
+                        ".jsonl"
+                    ):  # file.jsonl: json.dumps({"source": data})
+                        lines = json.loads(line.strip())
+                        data = lines["source"]
+                        key = data["key"] if "key" in data else key
+                    else:  # filelist, wav.scp, text.txt: id \t data or data
+                        lines = line.strip().split(maxsplit=1)
+                        data = lines[1] if len(lines) > 1 else lines[0]
+                        key = lines[0] if len(lines) > 1 else key
+
+                    data_list.append(data)
+                    key_list.append(key)
+        else:
+            if key is None:
+                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key = misc.extract_filename_without_extension(data_in)
+            data_list = [data_in]
+            key_list = [key]
+    elif isinstance(data_in, (list, tuple)):
+        if data_type is not None and isinstance(
+            data_type, (list, tuple)
+        ):  # mutiple inputs
+            data_list_tmp = []
+            for data_in_i, data_type_i in zip(data_in, data_type):
+                key_list, data_list_i = prepare_data_iterator(
+                    data_in=data_in_i, data_type=data_type_i
+                )
+                data_list_tmp.append(data_list_i)
+            data_list = []
+            for item in zip(*data_list_tmp):
+                data_list.append(item)
+        else:
+            # [audio sample point, fbank, text]
+            data_list = data_in
+            key_list = []
+            for data_i in data_in:
+                if isinstance(data_i, str) and os.path.exists(data_i):
+                    key = misc.extract_filename_without_extension(data_i)
+                else:
+                    if key is None:
+                        key = "rand_key_" + "".join(
+                            random.choice(chars) for _ in range(13)
+                        )
+                key_list.append(key)
+
+    else:  # raw text; audio sample point, fbank; bytes
+        if isinstance(data_in, bytes):  # audio bytes
+            data_in = load_bytes(data_in)
+        if key is None:
+            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+        data_list = [data_in]
+        key_list = [key]
+
+    return key_list, data_list
+
+
+class AutoModel:
+
+    def __init__(self, **kwargs):
+
+        try:
+            from funasr.utils.version_checker import check_for_update
+
+            print(
+                "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
+            )
+            check_for_update(disable=kwargs.get("disable_update", False))
+        except:
+            pass
+
+        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+        logging.basicConfig(level=log_level)
+
+        model, kwargs = self.build_model(**kwargs)
+
+        # if vad_model is not None, build vad model else None
+        vad_model = kwargs.get("vad_model", None)
+        vad_kwargs = (
+            {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
+        )
+        if vad_model is not None:
+            logging.info("Building VAD model.")
+            vad_kwargs["model"] = vad_model
+            vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
+            vad_kwargs["device"] = kwargs["device"]
+            vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+        # if punc_model is not None, build punc model else None
+        punc_model = kwargs.get("punc_model", None)
+        punc_kwargs = (
+            {}
+            if kwargs.get("punc_kwargs", {}) is None
+            else kwargs.get("punc_kwargs", {})
+        )
+        if punc_model is not None:
+            logging.info("Building punc model.")
+            punc_kwargs["model"] = punc_model
+            punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
+            punc_kwargs["device"] = kwargs["device"]
+            punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+        # if spk_model is not None, build spk model else None
+        spk_model = kwargs.get("spk_model", None)
+        spk_kwargs = (
+            {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
+        )
+        if spk_model is not None:
+            logging.info("Building SPK model.")
+            spk_kwargs["model"] = spk_model
+            spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
+            spk_kwargs["device"] = kwargs["device"]
+            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
+            self.cb_model = ClusterBackend().to(kwargs["device"])
+            spk_mode = kwargs.get("spk_mode", "punc_segment")
+            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
+                logging.error(
+                    "spk_mode should be one of default, vad_segment and punc_segment."
+                )
+            self.spk_mode = spk_mode
+
+        self.kwargs = kwargs
+        self.model = model
+        self.vad_model = vad_model
+        self.vad_kwargs = vad_kwargs
+        self.punc_model = punc_model
+        self.punc_kwargs = punc_kwargs
+        self.spk_model = spk_model
+        self.spk_kwargs = spk_kwargs
+        self.model_path = kwargs.get("model_path")
+
+    @staticmethod
+    def build_model(**kwargs):
+        assert "model" in kwargs
+        if "model_conf" not in kwargs:
+            logging.info(
+                "download models from model hub: {}".format(kwargs.get("hub", "ms"))
+            )
+            kwargs = download_model(**kwargs)
+
+        set_all_random_seed(kwargs.get("seed", 0))
+
+        device = kwargs.get("device", "cuda")
+        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+            device = "cpu"
+            kwargs["batch_size"] = 1
+        kwargs["device"] = device
+
+        torch.set_num_threads(kwargs.get("ncpu", 4))
+
+        # build tokenizer
+        tokenizer = kwargs.get("tokenizer", None)
+        if tokenizer is not None:
+            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+            tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
+            kwargs["token_list"] = (
+                tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+            )
+            kwargs["token_list"] = (
+                tokenizer.get_vocab()
+                if hasattr(tokenizer, "get_vocab")
+                else kwargs["token_list"]
+            )
+            vocab_size = (
+                len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+            )
+            if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+                vocab_size = tokenizer.get_vocab_size()
+        else:
+            vocab_size = -1
+        kwargs["tokenizer"] = tokenizer
+
+        # build frontend
+        frontend = kwargs.get("frontend", None)
+        kwargs["input_size"] = None
+        if frontend is not None:
+            frontend_class = tables.frontend_classes.get(frontend)
+            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
+            kwargs["input_size"] = (
+                frontend.output_size() if hasattr(frontend, "output_size") else None
+            )
+        kwargs["frontend"] = frontend
+        # build model
+        model_class = tables.model_classes.get(kwargs["model"])
+        assert model_class is not None, f'{kwargs["model"]} is not registered'
+        model_conf = {}
+        deep_update(model_conf, kwargs.get("model_conf", {}))
+        deep_update(model_conf, kwargs)
+        model = model_class(**model_conf, vocab_size=vocab_size)
+
+        # init_param
+        init_param = kwargs.get("init_param", None)
+        if init_param is not None:
+            if os.path.exists(init_param):
+                logging.info(f"Loading pretrained params from {init_param}")
+                load_pretrained_model(
+                    model=model,
+                    path=init_param,
+                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+                    oss_bucket=kwargs.get("oss_bucket", None),
+                    scope_map=kwargs.get("scope_map", []),
+                    excludes=kwargs.get("excludes", None),
+                )
+            else:
+                print(f"error, init_param does not exist!: {init_param}")
+
+        # fp16
+        if kwargs.get("fp16", False):
+            model.to(torch.float16)
+        elif kwargs.get("bf16", False):
+            model.to(torch.bfloat16)
+        model.to(device)
+
+        if not kwargs.get("disable_log", True):
+            tables.print()
+
+        return model, kwargs
+
+    def __call__(self, *args, **cfg):
+        kwargs = self.kwargs
+        deep_update(kwargs, cfg)
+        res = self.model(*args, kwargs)
+        return res
+
+    def generate(self, input, input_len=None, **cfg):
+        if self.vad_model is None:
+            return self.inference(input, input_len=input_len, **cfg)
+
+        else:
+            return self.inference_with_vad(input, input_len=input_len, **cfg)
+
+    def inference(
+        self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
+    ):
+        kwargs = self.kwargs if kwargs is None else kwargs
+        if "cache" in kwargs:
+            kwargs.pop("cache")
+        deep_update(kwargs, cfg)
+        model = self.model if model is None else model
+        model.eval()
+
+        batch_size = kwargs.get("batch_size", 1)
+        # if kwargs.get("device", "cpu") == "cpu":
+        #     batch_size = 1
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
+        )
+
+        speed_stats = {}
+        asr_result_list = []
+        num_samples = len(data_list)
+        disable_pbar = self.kwargs.get("disable_pbar", False)
+        pbar = (
+            tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+            if not disable_pbar
+            else None
+        )
+        time_speech_total = 0.0
+        time_escape_total = 0.0
+        for beg_idx in range(0, num_samples, batch_size):
+            end_idx = min(num_samples, beg_idx + batch_size)
+            data_batch = data_list[beg_idx:end_idx]
+            key_batch = key_list[beg_idx:end_idx]
+            batch = {"data_in": data_batch, "key": key_batch}
+
+            if (end_idx - beg_idx) == 1 and kwargs.get(
+                "data_type", None
+            ) == "fbank":  # fbank
+                batch["data_in"] = data_batch[0]
+                batch["data_lengths"] = input_len
+
+            time1 = time.perf_counter()
+            with torch.no_grad():
+                res = model.inference(**batch, **kwargs)
+                if isinstance(res, (list, tuple)):
+                    results = res[0] if len(res) > 0 else [{"text": ""}]
+                    meta_data = res[1] if len(res) > 1 else {}
+            time2 = time.perf_counter()
+
+            asr_result_list.extend(results)
+
+            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+            batch_data_time = meta_data.get("batch_data_time", -1)
+            time_escape = time2 - time1
+            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+            speed_stats["forward"] = f"{time_escape:0.3f}"
+            speed_stats["batch_size"] = f"{len(results)}"
+            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+            description = f"{speed_stats}, "
+            if pbar:
+                pbar.update(end_idx - beg_idx)
+                pbar.set_description(description)
+            time_speech_total += batch_data_time
+            time_escape_total += time_escape
+
+        if pbar:
+            # pbar.update(1)
+            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+        torch.cuda.empty_cache()
+        return asr_result_list
+
+    def vad(self, input, input_len=None, **cfg):
+        kwargs = self.kwargs
+        # step.1: compute the vad model
+        deep_update(self.vad_kwargs, cfg)
+        beg_vad = time.time()
+        res = self.inference(
+            input,
+            input_len=input_len,
+            model=self.vad_model,
+            kwargs=self.vad_kwargs,
+            **cfg,
+        )
+        end_vad = time.time()
+        #  FIX(gcf): concat the vad clips for sense vocie model for better aed
+        if cfg.get("merge_vad", False):
+            for i in range(len(res)):
+                res[i]["value"] = merge_vad(
+                    res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
+                )
+        elapsed = end_vad - beg_vad
+        return elapsed, res
+
+    def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
+
+        kwargs = self.kwargs
+
+        # step.2 compute asr model
+        model = self.model
+        deep_update(kwargs, cfg)
+        batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
+        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
+        kwargs["batch_size"] = batch_size
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=input_len, data_type=kwargs.get("data_type", None)
+        )
+        results_ret_list = []
+        time_speech_total_all_samples = 1e-6
+
+        beg_total = time.time()
+        pbar_total = (
+            tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
+            if not kwargs.get("disable_pbar", False)
+            else None
+        )
+
+        for i in range(len(vad_res)):
+            key = vad_res[i]["key"]
+            vadsegments = vad_res[i]["value"]
+            input_i = data_list[i]
+            fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
+            speech = load_audio_text_image_video(
+                input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
+            )
+            speech_lengths = len(speech)
+            n = len(vadsegments)
+            data_with_index = [(vadsegments[i], i) for i in range(n)]
+            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+            results_sorted = []
+
+            if not len(sorted_data):
+                results_ret_list.append({"key": key, "text": "", "timestamp": []})
+                logging.info("decoding, utt: {}, empty speech".format(key))
+                continue
+
+            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+                batch_size = max(
+                    batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
+                )
+
+            if kwargs["device"] == "cpu":
+                batch_size = 0
+
+            beg_idx = 0
+            beg_asr_total = time.time()
+            time_speech_total_per_sample = speech_lengths / 16000
+            time_speech_total_all_samples += time_speech_total_per_sample
+
+            # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
+
+            all_segments = []
+            max_len_in_batch = 0
+            end_idx = 1
+
+            for j, _ in enumerate(range(0, n)):
+                # pbar_sample.update(1)
+                sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
+                potential_batch_length = max(max_len_in_batch, sample_length) * (
+                    j + 1 - beg_idx
+                )
+                # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+                if (
+                    j < n - 1
+                    and sample_length < batch_size_threshold_ms
+                    and potential_batch_length < batch_size
+                ):
+                    max_len_in_batch = max(max_len_in_batch, sample_length)
+                    end_idx += 1
+                    continue
+
+                speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
+                    speech, speech_lengths, sorted_data[beg_idx:end_idx]
+                )
+                results = self.inference(
+                    speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
+                )
+
+                for _b in range(len(speech_j)):
+                    results[_b]["interval"] = intervals[_b]
+
+                if self.spk_model is not None:
+                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
+                    for _b in range(len(speech_j)):
+                        vad_segments = [
+                            [
+                                sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
+                                sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
+                                np.array(speech_j[_b]),
+                            ]
+                        ]
+                        segments = sv_chunk(vad_segments)
+                        all_segments.extend(segments)
+                        speech_b = [i[2] for i in segments]
+                        spk_res = self.inference(
+                            speech_b,
+                            input_len=None,
+                            model=self.spk_model,
+                            kwargs=kwargs,
+                            **cfg,
+                        )
+                        results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
+
+                beg_idx = end_idx
+                end_idx += 1
+                max_len_in_batch = sample_length
+                if len(results) < 1:
+                    continue
+                results_sorted.extend(results)
+
+            # end_asr_total = time.time()
+            # time_escape_total_per_sample = end_asr_total - beg_asr_total
+            # pbar_sample.update(1)
+            # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+            #                      f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
+            #                      f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+
+            restored_data = [0] * n
+            for j in range(n):
+                index = sorted_data[j][1]
+                cur = results_sorted[j]
+                pattern = r"<\|([^|]+)\|>"
+                emotion_string = re.findall(pattern, cur["text"])
+                cur["text"] = re.sub(pattern, "", cur["text"])
+                cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
+                if self.punc_model is not None and len(cur["text"].strip()) > 0:
+                    deep_update(self.punc_kwargs, cfg)
+                    punc_res = self.inference(
+                        cur["text"],
+                        model=self.punc_model,
+                        kwargs=self.punc_kwargs,
+                        **cfg,
+                    )
+                    cur["text"] = punc_res[0]["text"]
+
+                restored_data[index] = cur
+
+            end_asr_total = time.time()
+            time_escape_total_per_sample = end_asr_total - beg_asr_total
+            if pbar_total:
+                pbar_total.update(1)
+                pbar_total.set_description(
+                    f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+                    f"time_speech: {time_speech_total_per_sample: 0.3f}, "
+                    f"time_escape: {time_escape_total_per_sample:0.3f}"
+                )
+
+        # end_total = time.time()
+        # time_escape_total_all_samples = end_total - beg_total
+        # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
+        #                      f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
+        #                      f"time_escape_all: {time_escape_total_all_samples:0.3f}")
+        return restored_data
+
+    def export(self, input=None, **cfg):
+        """
+
+        :param input:
+        :param type:
+        :param quantize:
+        :param fallback_num:
+        :param calib_num:
+        :param opset_version:
+        :param cfg:
+        :return:
+        """
+
+        device = cfg.get("device", "cpu")
+        model = self.model.to(device=device)
+        kwargs = self.kwargs
+        deep_update(kwargs, cfg)
+        kwargs["device"] = device
+        del kwargs["model"]
+        model.eval()
+
+        type = kwargs.get("type", "onnx")
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=None, data_type=kwargs.get("data_type", None), key=None
+        )
+
+        with torch.no_grad():
+            export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
+
+        return export_dir

+ 332 - 0
tools/sensevoice/fun_asr.py

@@ -0,0 +1,332 @@
+import gc
+import os
+import re
+
+from audio_separator.separator import Separator
+
+os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
+os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
+import json
+import subprocess
+from pathlib import Path
+
+import click
+import torch
+from loguru import logger
+from pydub import AudioSegment
+from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
+from tools.sensevoice.auto_model import AutoModel
+
+
+def uvr5_cli(
+    audio_dir: Path,
+    output_folder: Path,
+    audio_files: list[Path] | None = None,
+    output_format: str = "flac",
+    model: str = "BS-Roformer-Viperx-1296.ckpt",
+):
+    # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
+    sepr = Separator(
+        model_file_dir=os.environ["UVR5_CACHE"],
+        output_dir=output_folder,
+        output_format=output_format,
+    )
+    dictmodel = {
+        "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
+        "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
+        "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
+        "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
+    }
+    roformer_model = dictmodel[model]
+    sepr.load_model(roformer_model)
+    if audio_files is None:
+        audio_files = list_files(
+            path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+        )
+    total_files = len(audio_files)
+
+    print(f"{total_files} audio files found")
+
+    res = []
+    for audio in tqdm(audio_files, desc="Denoising: "):
+        file_path = str(audio_dir / audio)
+        sep_out = sepr.separate(file_path)
+        if isinstance(sep_out, str):
+            res.append(sep_out)
+        elif isinstance(sep_out, list):
+            res.extend(sep_out)
+    del sepr
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return res, roformer_model
+
+
+def get_sample_rate(media_path: Path):
+    result = subprocess.run(
+        [
+            "ffprobe",
+            "-v",
+            "quiet",
+            "-print_format",
+            "json",
+            "-show_streams",
+            str(media_path),
+        ],
+        capture_output=True,
+        text=True,
+        check=True,
+    )
+    media_info = json.loads(result.stdout)
+    for stream in media_info.get("streams", []):
+        if stream.get("codec_type") == "audio":
+            return stream.get("sample_rate")
+    return "44100"  # Default sample rate if not found
+
+
+def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
+    sr = get_sample_rate(src_path)
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+    if src_path.resolve() == out_path.resolve():
+        output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
+    else:
+        output = str(out_path)
+    subprocess.run(
+        [
+            "ffmpeg",
+            "-loglevel",
+            "error",
+            "-i",
+            str(src_path),
+            "-acodec",
+            "pcm_s16le" if out_fmt == "wav" else "flac",
+            "-ar",
+            sr,
+            "-ac",
+            "1",
+            "-y",
+            output,
+        ],
+        check=True,
+    )
+    return out_path
+
+
+def convert_video_to_audio(video_path: Path, audio_dir: Path):
+    cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
+    vocals = [
+        p
+        for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
+        if p.suffix in AUDIO_EXTENSIONS
+    ]
+    if len(vocals) > 0:
+        return vocals[0]
+    audio_path = cur_dir / f"{video_path.stem}.wav"
+    convert_to_mono(video_path, audio_path)
+    return audio_path
+
+
+@click.command()
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+    "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option(
+    "--max_single_segment_time",
+    default=20000,
+    type=int,
+    help="Maximum of Output single audio duration(ms)",
+)
+@click.option("--fsmn-vad/--silero-vad", default=False)
+@click.option("--punc/--no-punc", default=False)
+@click.option("--denoise/--no-denoise", default=False)
+@click.option("--save_emo/--no_save_emo", default=False)
+def main(
+    audio_dir: str,
+    save_dir: str,
+    device: str,
+    language: str,
+    max_single_segment_time: int,
+    fsmn_vad: bool,
+    punc: bool,
+    denoise: bool,
+    save_emo: bool,
+):
+
+    audios_path = Path(audio_dir)
+    save_path = Path(save_dir)
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    video_files = list_files(
+        path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
+    )
+    v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
+
+    if denoise:
+        VOCAL = "_(Vocals)"
+        original_files = [
+            p
+            for p in audios_path.glob("**/*")
+            if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
+        ]
+
+        _, cur_model = uvr5_cli(
+            audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
+        )
+        need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
+        need_remove.extend(original_files)
+        for _ in need_remove:
+            _.unlink()
+        vocal_files = [
+            p
+            for p in audios_path.glob("**/*")
+            if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
+        ]
+        for f in vocal_files:
+            fn, ext = f.stem, f.suffix
+
+            v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
+            if v_pos != -1:
+                new_fn = fn[: v_pos + len(VOCAL)]
+                new_f = f.with_name(new_fn + ext)
+                f = f.rename(new_f)
+                convert_to_mono(f, f, "flac")
+                f.unlink()
+
+    audio_files = list_files(
+        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+    )
+
+    logger.info("Loading / Downloading Funasr model...")
+
+    model_dir = "iic/SenseVoiceSmall"
+
+    vad_model = "fsmn-vad" if fsmn_vad else None
+    vad_kwargs = {"max_single_segment_time": max_single_segment_time}
+    punc_model = "ct-punc" if punc else None
+
+    manager = AutoModel(
+        model=model_dir,
+        trust_remote_code=False,
+        vad_model=vad_model,
+        vad_kwargs=vad_kwargs,
+        punc_model=punc_model,
+        device=device,
+    )
+
+    if not fsmn_vad and vad_model is None:
+        vad_model = load_silero_vad()
+
+    logger.info("Model loaded.")
+
+    pattern = re.compile(r"_\d{3}\.")
+
+    for file_path in tqdm(audio_files, desc="Processing audio file"):
+
+        if pattern.search(file_path.name):
+            # logger.info(f"Skipping {file_path} as it has already been processed.")
+            continue
+
+        file_stem = file_path.stem
+        file_suffix = file_path.suffix
+
+        rel_path = Path(file_path).relative_to(audio_dir)
+        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+        audio = AudioSegment.from_file(file_path)
+
+        cfg = dict(
+            cache={},
+            language=language,  # "zh", "en", "yue", "ja", "ko", "nospeech"
+            use_itn=False,
+            batch_size_s=60,
+        )
+
+        if fsmn_vad:
+            elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
+        else:
+            wav = read_audio(
+                str(file_path)
+            )  # backend (sox, soundfile, or ffmpeg) required!
+            audio_key = file_path.stem
+            audio_val = []
+            speech_timestamps = get_speech_timestamps(
+                wav,
+                vad_model,
+                max_speech_duration_s=max_single_segment_time // 1000,
+                return_seconds=True,
+            )
+
+            audio_val = [
+                [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
+                for timestamp in speech_timestamps
+            ]
+            vad_res = []
+            vad_res.append(dict(key=audio_key, value=audio_val))
+
+        res = manager.inference_with_vadres(
+            input=str(file_path), vad_res=vad_res, **cfg
+        )
+
+        for i, info in enumerate(res):
+            [start_ms, end_ms] = info["interval"]
+            text = info["text"]
+            emo = info["emo"]
+            sliced_audio = audio[start_ms:end_ms]
+            audio_save_path = (
+                save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
+            )
+            sliced_audio.export(audio_save_path, format=file_suffix[1:])
+            print(f"Exported {audio_save_path}: {text}")
+
+            transcript_save_path = (
+                save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
+            )
+            with open(
+                transcript_save_path,
+                "w",
+                encoding="utf-8",
+            ) as f:
+                f.write(text)
+
+            if save_emo:
+                emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
+                with open(
+                    emo_save_path,
+                    "w",
+                    encoding="utf-8",
+                ) as f:
+                    f.write(emo)
+
+        if audios_path.resolve() == save_path.resolve():
+            file_path.unlink()
+
+
+if __name__ == "__main__":
+    main()
+    exit(0)
+    from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+    # Load the audio file
+    audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
+    model_dir = "iic/SenseVoiceSmall"
+    m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
+    m.eval()
+
+    res = m.inference(
+        data_in=f"{kwargs['model_path']}/example/zh.mp3",
+        language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
+        use_itn=False,
+        ban_emo_unk=False,
+        **kwargs,
+    )
+
+    print(res)
+    text = rich_transcription_postprocess(res[0][0]["text"])
+    print(text)

+ 61 - 0
tools/sensevoice/vad_utils.py

@@ -0,0 +1,61 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+    speech_list = []
+    speech_lengths_list = []
+    for i, segment in enumerate(vad_segments):
+
+        bed_idx = int(segment[0][0] * 16)
+        end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
+        speech_i = speech[0, bed_idx:end_idx]
+        speech_lengths_i = end_idx - bed_idx
+        speech_list.append(speech_i)
+        speech_lengths_list.append(speech_lengths_i)
+    feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+    speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+    return feats_pad, speech_lengths_pad
+
+
+def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
+    speech_list = []
+    speech_lengths_list = []
+    intervals = []
+    for i, segment in enumerate(vad_segments):
+        bed_idx = int(segment[0][0] * 16)
+        end_idx = min(int(segment[0][1] * 16), speech_lengths)
+        speech_i = speech[bed_idx:end_idx]
+        speech_lengths_i = end_idx - bed_idx
+        speech_list.append(speech_i)
+        speech_lengths_list.append(speech_lengths_i)
+        intervals.append([bed_idx // 16, end_idx // 16])
+
+    return speech_list, speech_lengths_list, intervals
+
+
+def merge_vad(vad_result, max_length=15000, min_length=0):
+    new_result = []
+    if len(vad_result) <= 1:
+        return vad_result
+    time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
+    time_step = sorted(list(set(time_step)))
+    if len(time_step) == 0:
+        return []
+    bg = 0
+    for i in range(len(time_step) - 1):
+        time = time_step[i]
+        if time_step[i + 1] - bg < max_length:
+            continue
+        if time - bg > min_length:
+            new_result.append([bg, time])
+        # if time - bg < max_length * 1.5:
+        #     new_result.append([bg, time])
+        # else:
+        #     split_num = int(time - bg) // max_length + 1
+        #     spl_l = int(time - bg) // split_num
+        #     for j in range(split_num):
+        #         new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
+        bg = time
+    new_result.append([bg, time_step[-1]])
+    return new_result

+ 1 - 1
tools/smart_pad.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 import torchaudio
 from tqdm import tqdm
 
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from tools.file import AUDIO_EXTENSIONS, list_files
 
 threshold = 10 ** (-50 / 20.0)
 

+ 1 - 1
tools/vqgan/create_train_split.py

@@ -7,7 +7,7 @@ from loguru import logger
 from pydub import AudioSegment
 from tqdm import tqdm
 
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 
 @click.command()

+ 1 - 1
tools/vqgan/extract_vq.py

@@ -17,7 +17,7 @@ from lightning import LightningModule
 from loguru import logger
 from omegaconf import OmegaConf
 
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 # register eval resolver
 OmegaConf.register_new_resolver("eval", eval)

+ 20 - 35
tools/whisper_asr.py

@@ -32,7 +32,7 @@ from loguru import logger
 from pydub import AudioSegment
 from tqdm import tqdm
 
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from tools.file import AUDIO_EXTENSIONS, list_files
 
 
 @click.command()
@@ -83,8 +83,6 @@ def main(
         path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
     )
 
-    numbered_suffix_pattern = re.compile(r"-\d{3}$")
-
     for file_path in tqdm(audio_files, desc="Processing audio file"):
         file_stem = file_path.stem
         file_suffix = file_path.suffix
@@ -92,18 +90,6 @@ def main(
         rel_path = Path(file_path).relative_to(audio_dir)
         (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
 
-        # Skip files that already have a .lab file or a -{3-digit number} suffix
-        numbered_suffix = numbered_suffix_pattern.search(file_stem)
-        lab_file = file_path.with_suffix(".lab")
-
-        if numbered_suffix and lab_file.exists():
-            continue
-
-        if not numbered_suffix and lab_file.with_stem(lab_file.stem + "-001").exists():
-            if file_path.exists():
-                file_path.unlink()
-            continue
-
         audio = AudioSegment.from_file(file_path)
 
         segments, info = model.transcribe(
@@ -119,6 +105,7 @@ def main(
         )
         print("Total len(ms): ", len(audio))
 
+        whole_text = None
         for segment in segments:
             id, start, end, text = (
                 segment.id,
@@ -127,26 +114,24 @@ def main(
                 segment.text,
             )
             print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
-            start_ms = int(start * 1000)
-            end_ms = int(end * 1000) + 200  # add 0.2s avoid truncating
-            segment_audio = audio[start_ms:end_ms]
-            audio_save_path = (
-                save_path / rel_path.parent / f"{file_stem}-{id:03d}{file_suffix}"
-            )
-            segment_audio.export(audio_save_path, format=file_suffix[1:])
-            print(f"Exported {audio_save_path}")
-
-            transcript_save_path = (
-                save_path / rel_path.parent / f"{file_stem}-{id:03d}.lab"
-            )
-            with open(
-                transcript_save_path,
-                "w",
-                encoding="utf-8",
-            ) as f:
-                f.write(segment.text)
-
-        file_path.unlink()
+            if not whole_text:
+                whole_text = text
+            else:
+                whole_text += ", " + text
+
+        whole_text += "."
+
+        audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
+        audio.export(audio_save_path, format=file_suffix[1:])
+        print(f"Exported {audio_save_path}")
+
+        transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
+        with open(
+            transcript_save_path,
+            "w",
+            encoding="utf-8",
+        ) as f:
+            f.write(whole_text)
 
 
 if __name__ == "__main__":