Kaynağa Gözat

Automatically download models (#219)

* Automatically download models

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

* Ensure mirror enabled

* no_proxy before mirror

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 yıl önce
ebeveyn
işleme
ff60db4293
2 değiştirilmiş dosya ile 64 ekleme ve 3 silme
  1. 3 3
      start.bat
  2. 61 0
      tools/download_models.py

+ 3 - 3
start.bat

@@ -6,12 +6,12 @@ set PYTHONPATH=%~dp0
 set PYTHON_CMD=%cd%\fishenv\env\python
 set API_FLAG_PATH=%~dp0API_FLAGS.txt
 
-setlocal enabledelayedexpansion
-
-set no_proxy="localhost, 127.0.0.1, 0.0.0.0"
 :: 设置Hugging Face镜像源
+set no_proxy="localhost, 127.0.0.1, 0.0.0.0"
 set HF_ENDPOINT=https://hf-mirror.com
+%PYTHON_CMD% .\tools\download_models.py
 
+setlocal enabledelayedexpansion
 
 set "API_FLAGS="
 set "flags="

+ 61 - 0
tools/download_models.py

@@ -0,0 +1,61 @@
+import os
+
+from huggingface_hub import hf_hub_download
+
+# 要检查和下载的文件列表
+files = [
+    "firefly-gan-base-generator.ckpt",
+    "README.md",
+    "special_tokens_map.json",
+    "text2semantic-sft-large-v1.1-4k.pth",
+    "text2semantic-sft-medium-v1.1-4k.pth",
+    "tokenizer_config.json",
+    "tokenizer.json",
+    "vits_decoder_v1.1.ckpt",
+    "vq-gan-group-fsq-2x1024.pth",
+]
+
+# Hugging Face 仓库信息
+repo_id = "fishaudio/fish-speech-1"
+cache_dir = "./checkpoints"
+
+
+os.makedirs(cache_dir, exist_ok=True)
+
+# 检查每个文件是否存在,如果不存在则从 Hugging Face 仓库下载
+for file in files:
+    file_path = os.path.join(cache_dir, file)
+    if not os.path.exists(file_path):
+        print(f"{file} 不存在,从 Hugging Face 仓库下载...")
+        hf_hub_download(
+            repo_id=repo_id,
+            filename=file,
+            cache_dir=cache_dir,
+            local_dir_use_symlinks=False,
+        )
+    else:
+        print(f"{file} 已存在,跳过下载。")
+
+
+files = [
+    "medium.pt",
+    "small.pt",
+]
+
+# Hugging Face 仓库信息
+repo_id = "SpicyqSama007/fish-speech-packed"
+cache_dir = ".cache/whisper"
+os.makedirs(cache_dir, exist_ok=True)
+
+for file in files:
+    file_path = os.path.join(cache_dir, file)
+    if not os.path.exists(file_path):
+        print(f"{file} 不存在,从 Hugging Face 仓库下载...")
+        hf_hub_download(
+            repo_id=repo_id,
+            filename=file,
+            cache_dir=cache_dir,
+            local_dir_use_symlinks=False,
+        )
+    else:
+        print(f"{file} 已存在,跳过下载。")