|
|
@@ -20,11 +20,11 @@ OmegaConf.register_new_resolver("eval", eval)
|
|
|
@torch.autocast(device_type="cuda", enabled=True)
|
|
|
def main():
|
|
|
with initialize(version_base="1.3", config_path="../fish_speech/configs"):
|
|
|
- cfg = compose(config_name="vqgan")
|
|
|
+ cfg = compose(config_name="vqgan_single_2x")
|
|
|
|
|
|
model: LightningModule = instantiate(cfg.model)
|
|
|
state_dict = torch.load(
|
|
|
- "checkpoints/vqgan/step_000380000.ckpt",
|
|
|
+ "results/vqgan_single_2x/checkpoints/step_000160000.ckpt",
|
|
|
map_location=model.device,
|
|
|
)["state_dict"]
|
|
|
model.load_state_dict(state_dict, strict=True)
|
|
|
@@ -33,7 +33,11 @@ def main():
|
|
|
logger.info("Restored model from checkpoint")
|
|
|
|
|
|
# Load audio
|
|
|
- audio = librosa.load("test.wav", sr=model.sampling_rate, mono=True)[0]
|
|
|
+ audio = librosa.load(
|
|
|
+ "data/StarRail/Chinese/停云/chapter2_1_tingyun_142.wav",
|
|
|
+ sr=model.sampling_rate,
|
|
|
+ mono=True,
|
|
|
+ )[0]
|
|
|
audios = torch.from_numpy(audio).to(model.device)[None, None, :]
|
|
|
logger.info(
|
|
|
f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
|
|
|
@@ -69,29 +73,49 @@ def main():
|
|
|
print(indices.shape)
|
|
|
|
|
|
# Restore
|
|
|
- indices = np.load("codes_0.npy")
|
|
|
- indices = torch.from_numpy(indices).to(model.device).long()
|
|
|
- indices = indices.unsqueeze(1).unsqueeze(-1)
|
|
|
- mel_lengths = indices.shape[2] * (
|
|
|
- model.downsample.total_strides if model.downsample is not None else 1
|
|
|
+ # indices = np.load("codes_0.npy")
|
|
|
+ # indices = torch.from_numpy(indices).to(model.device).long()
|
|
|
+ # indices = indices.unsqueeze(1).unsqueeze(-1)
|
|
|
+ # mel_lengths = indices.shape[2] * (
|
|
|
+ # model.downsample.total_strides if model.downsample is not None else 1
|
|
|
+ # )
|
|
|
+ # mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
|
|
|
+ # mel_masks = torch.ones(
|
|
|
+ # (1, 1, mel_lengths), device=model.device, dtype=torch.float32
|
|
|
+ # )
|
|
|
+
|
|
|
+ # print(mel_lengths)
|
|
|
+
|
|
|
+ # Reference speaker
|
|
|
+ ref_audio = librosa.load(
|
|
|
+ "data/StarRail/Chinese/符玄/chapter2_8_fuxuan_104.wav",
|
|
|
+ sr=model.sampling_rate,
|
|
|
+ mono=True,
|
|
|
+ )[0]
|
|
|
+ ref_audios = torch.from_numpy(ref_audio).to(model.device)[None, None, :]
|
|
|
+ ref_audio_lengths = torch.tensor(
|
|
|
+ [ref_audios.shape[2]], device=model.device, dtype=torch.long
|
|
|
)
|
|
|
- mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
|
|
|
- mel_masks = torch.ones(
|
|
|
- (1, 1, mel_lengths), device=model.device, dtype=torch.float32
|
|
|
- )
|
|
|
-
|
|
|
- print(mel_lengths)
|
|
|
+ ref_mels = model.mel_transform(ref_audios, sample_rate=model.sampling_rate)
|
|
|
+ ref_mel_lengths = ref_audio_lengths // model.hop_length
|
|
|
+ ref_mel_masks = torch.unsqueeze(
|
|
|
+ sequence_mask(ref_mel_lengths, ref_mels.shape[2]), 1
|
|
|
+ ).to(gt_mels.dtype)
|
|
|
+ speaker_features = model.speaker_encoder(ref_mels, ref_mel_masks)
|
|
|
+ # speaker_features = model.speaker_encoder(gt_mels, mel_masks)
|
|
|
|
|
|
+ print("indices", indices.shape)
|
|
|
text_features = model.vq_encoder.decode(indices)
|
|
|
+
|
|
|
logger.info(
|
|
|
f"VQ Encoded, indices: {indices.shape} equivalent to "
|
|
|
- + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
|
|
|
+ + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[1]):.2f} Hz"
|
|
|
)
|
|
|
|
|
|
text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")
|
|
|
|
|
|
# Sample mels
|
|
|
- decoded_mels = model.decoder(text_features, mel_masks)
|
|
|
+ decoded_mels = model.decoder(text_features, mel_masks, g=speaker_features)
|
|
|
fake_audios = model.generator(decoded_mels)
|
|
|
|
|
|
# Save audio
|