|
@@ -502,9 +502,10 @@ class FireflyBase(nn.Module):
|
|
|
if ckpt_path is not None:
|
|
if ckpt_path is not None:
|
|
|
self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
|
|
self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
|
|
|
elif pretrained:
|
|
elif pretrained:
|
|
|
- state_dict = torch.load(
|
|
|
|
|
- "checkpoints/firefly-gan-base-generator.ckpt",
|
|
|
|
|
|
|
+ state_dict = torch.hub.load_state_dict_from_url(
|
|
|
|
|
+ "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
|
|
|
map_location="cpu",
|
|
map_location="cpu",
|
|
|
|
|
+ model_dir="checkpoints",
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
if "state_dict" in state_dict:
|
|
if "state_dict" in state_dict:
|