firefly.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
from typing import Optional

import torch
from torch import nn

from .convnext import ConvNeXtEncoder
from .hifigan import HiFiGANGenerator
  5. class FireflyBase(nn.Module):
  6. def __init__(self, ckpt_path: str = None):
  7. super().__init__()
  8. self.backbone = ConvNeXtEncoder(
  9. input_channels=160,
  10. depths=[3, 3, 9, 3],
  11. dims=[128, 256, 384, 512],
  12. drop_path_rate=0.2,
  13. kernel_sizes=[7],
  14. )
  15. self.head = HiFiGANGenerator(
  16. hop_length=512,
  17. upsample_rates=[8, 8, 2, 2, 2],
  18. upsample_kernel_sizes=[16, 16, 4, 4, 4],
  19. resblock_kernel_sizes=[3, 7, 11],
  20. resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  21. num_mels=512,
  22. upsample_initial_channel=512,
  23. use_template=True,
  24. pre_conv_kernel_size=13,
  25. post_conv_kernel_size=13,
  26. )
  27. if ckpt_path is None:
  28. return
  29. state_dict = torch.load(ckpt_path, map_location="cpu")
  30. if "state_dict" in state_dict:
  31. state_dict = state_dict["state_dict"]
  32. if any("generator." in k for k in state_dict):
  33. state_dict = {
  34. k.replace("generator.", ""): v
  35. for k, v in state_dict.items()
  36. if "generator." in k
  37. }
  38. self.load_state_dict(state_dict, strict=True)
  39. def encode(self, x: torch.Tensor) -> torch.Tensor:
  40. x = self.backbone(x)
  41. return x
  42. def decode(self, x: torch.Tensor) -> torch.Tensor:
  43. x = self.head(x)
  44. if x.ndim == 2:
  45. x = x[:, None, :]
  46. return x
  47. def forward(self, x: torch.Tensor) -> torch.Tensor:
  48. x = self.encode(x)
  49. x = self.decode(x)
  50. return x