| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- import os
- from glob import glob
- from pathlib import Path
- from typing import Union
- from loguru import logger
# Audio file extensions, written in lower case and including the leading dot.
AUDIO_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a",
                    ".wma", ".aac", ".aiff", ".aif", ".aifc"}
- def list_files(
- path: Union[Path, str],
- extensions: set[str] = None,
- recursive: bool = False,
- sort: bool = True,
- ) -> list[Path]:
- """List files in a directory.
- Args:
- path (Path): Path to the directory.
- extensions (set, optional): Extensions to filter. Defaults to None.
- recursive (bool, optional): Whether to search recursively. Defaults to False.
- sort (bool, optional): Whether to sort the files. Defaults to True.
- Returns:
- list: List of files.
- """
- if isinstance(path, str):
- path = Path(path)
- if not path.exists():
- raise FileNotFoundError(f"Directory {path} does not exist.")
- files = [Path(f) for f in glob(str(path / "**/*"), recursive=recursive)]
- if extensions is not None:
- files = [f for f in files if f.suffix in extensions]
- if sort:
- files = sorted(files)
- return files
- def get_latest_checkpoint(path: Path | str) -> Path | None:
- # Find the latest checkpoint
- ckpt_dir = Path(path)
- if ckpt_dir.exists() is False:
- return None
- ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
- if len(ckpts) == 0:
- return None
- return ckpts[-1]
- def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
- """
- Load a Bert-VITS2 style filelist.
- """
- files = set()
- results = []
- count_duplicated, count_not_found = 0, 0
- LANGUAGE_TO_LANGUAGES = {
- "zh": ["zh", "en"],
- "jp": ["jp", "en"],
- "en": ["en"],
- }
- with open(path, "r", encoding="utf-8") as f:
- for line in f.readlines():
- filename, speaker, language, text = line.strip().split("|")
- file = Path(filename)
- language = language.strip().lower()
- if language == "ja":
- language = "jp"
- assert language in ["zh", "jp", "en"], f"Invalid language {language}"
- languages = LANGUAGE_TO_LANGUAGES[language]
- if file in files:
- logger.warning(f"Duplicated file: {file}")
- count_duplicated += 1
- continue
- if not file.exists():
- logger.warning(f"File not found: {file}")
- count_not_found += 1
- continue
- results.append((file, speaker, languages, text))
- if count_duplicated > 0:
- logger.warning(f"Total duplicated files: {count_duplicated}")
- if count_not_found > 0:
- logger.warning(f"Total files not found: {count_not_found}")
- return results
|