file.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import os
  2. from glob import glob
  3. from pathlib import Path
  4. from typing import Union
  5. from loguru import logger
  6. from natsort import natsorted
  7. AUDIO_EXTENSIONS = {
  8. ".mp3",
  9. ".wav",
  10. ".flac",
  11. ".ogg",
  12. ".m4a",
  13. ".wma",
  14. ".aac",
  15. ".aiff",
  16. ".aif",
  17. ".aifc",
  18. }
  19. def list_files(
  20. path: Union[Path, str],
  21. extensions: set[str] = None,
  22. recursive: bool = False,
  23. sort: bool = True,
  24. ) -> list[Path]:
  25. """List files in a directory.
  26. Args:
  27. path (Path): Path to the directory.
  28. extensions (set, optional): Extensions to filter. Defaults to None.
  29. recursive (bool, optional): Whether to search recursively. Defaults to False.
  30. sort (bool, optional): Whether to sort the files. Defaults to True.
  31. Returns:
  32. list: List of files.
  33. """
  34. if isinstance(path, str):
  35. path = Path(path)
  36. if not path.exists():
  37. raise FileNotFoundError(f"Directory {path} does not exist.")
  38. # files = [Path(f) for f in glob(str(path / "**/*"), recursive=recursive)]
  39. files = [file for ext in extensions for file in directory.glob(f"**/*{ext}")]
  40. # if extensions is not None:
  41. # files = [f for f in files if f.suffix in extensions]
  42. if sort:
  43. files = natsorted(files)
  44. return files
  45. def get_latest_checkpoint(path: Path | str) -> Path | None:
  46. # Find the latest checkpoint
  47. ckpt_dir = Path(path)
  48. if ckpt_dir.exists() is False:
  49. return None
  50. ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
  51. if len(ckpts) == 0:
  52. return None
  53. return ckpts[-1]
  54. def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
  55. """
  56. Load a Bert-VITS2 style filelist.
  57. """
  58. files = set()
  59. results = []
  60. count_duplicated, count_not_found = 0, 0
  61. LANGUAGE_TO_LANGUAGES = {
  62. "zh": ["zh", "en"],
  63. "jp": ["jp", "en"],
  64. "en": ["en"],
  65. }
  66. with open(path, "r", encoding="utf-8") as f:
  67. for line in f.readlines():
  68. splits = line.strip().split("|", maxsplit=3)
  69. if len(splits) != 4:
  70. logger.warning(f"Invalid line: {line}")
  71. continue
  72. filename, speaker, language, text = splits
  73. file = Path(filename)
  74. language = language.strip().lower()
  75. if language == "ja":
  76. language = "jp"
  77. assert language in ["zh", "jp", "en"], f"Invalid language {language}"
  78. languages = LANGUAGE_TO_LANGUAGES[language]
  79. if file in files:
  80. logger.warning(f"Duplicated file: {file}")
  81. count_duplicated += 1
  82. continue
  83. if not file.exists():
  84. logger.warning(f"File not found: {file}")
  85. count_not_found += 1
  86. continue
  87. results.append((file, speaker, languages, text))
  88. if count_duplicated > 0:
  89. logger.warning(f"Total duplicated files: {count_duplicated}")
  90. if count_not_found > 0:
  91. logger.warning(f"Total files not found: {count_not_found}")
  92. return results