- from pathlib import Path
- import click
- import hydra
- import numpy as np
- import soundfile as sf
- import torch
- import torchaudio
- from hydra import compose, initialize
- from hydra.utils import instantiate
- from loguru import logger
- from omegaconf import OmegaConf
- from tools.file import AUDIO_EXTENSIONS
# Register an "eval" resolver so Hydra/OmegaConf configs can compute values
# with ${eval:...} expressions.
# NOTE(review): this executes arbitrary Python from config strings — only
# load configs from trusted sources.
OmegaConf.register_new_resolver("eval", eval)
def load_model(config_name, checkpoint_path, device="cuda"):
    """Build the model described by a Hydra config and load its weights.

    Args:
        config_name: Name of the config under ``fish_speech/configs``.
        checkpoint_path: Path to a ``.pth``/checkpoint file.
        device: Device the weights are mapped to and the model moved onto.

    Returns:
        The instantiated model in eval mode on ``device``.
    """
    # Reset Hydra's global state so repeated calls can re-initialize.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        config = compose(config_name=config_name)
    model = instantiate(config)

    weights = torch.load(
        checkpoint_path, map_location=device, mmap=True, weights_only=True
    )

    # Lightning-style checkpoints nest the weights under "state_dict".
    if "state_dict" in weights:
        weights = weights["state_dict"]

    # Full-training checkpoints prefix generator weights with "generator.";
    # keep only those entries and strip the prefix.
    if any("generator" in key for key in weights):
        weights = {
            key.replace("generator.", ""): tensor
            for key, tensor in weights.items()
            if "generator." in key
        }

    load_info = model.load_state_dict(weights, strict=False, assign=True)
    model.eval()
    model.to(device)

    logger.info(f"Loaded model: {load_info}")
    return model
@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    """Round-trip audio through the VQ codec, or decode precomputed indices.

    An audio input is encoded to VQ indices (also saved as ``<output>.npy``)
    and decoded back to a waveform; a ``.npy`` input is treated as
    precomputed indices and decoded directly. The waveform is written to
    ``output_path``.
    """
    model = load_model(config_name, checkpoint_path, device=device)
    sample_rate = model.spec_transform.sample_rate

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

        waveform, src_rate = torchaudio.load(str(input_path))
        # Downmix multi-channel audio to mono before resampling.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)
        waveform = torchaudio.functional.resample(waveform, src_rate, sample_rate)

        audios = waveform[None].to(device)
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
        )

        # Encode to VQ indices and persist them next to the output audio.
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=device, dtype=torch.long
        )
        indices = model.encode(audios, audio_lengths)[0][0]
        logger.info(f"Generated indices of shape {indices.shape}")
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = torch.from_numpy(np.load(input_path)).to(device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Decode the (batched) indices back to a waveform.
    feature_lengths = torch.tensor([indices.shape[1]], device=device)
    fake_audios, _ = model.decode(
        indices=indices[None], feature_lengths=feature_lengths
    )

    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
    )

    sf.write(output_path, fake_audios[0, 0].float().cpu().numpy(), sample_rate)
    logger.info(f"Saved audio to {output_path}")
if __name__ == "__main__":
    # Script entry point: click parses the CLI options and invokes main().
    main()