|
@@ -99,8 +99,11 @@ class VQDiffusion(L.LightningModule):
|
|
|
# Generator and discriminators
|
|
# Generator and discriminators
|
|
|
self.mel_transform = mel_transform
|
|
self.mel_transform = mel_transform
|
|
|
self.feature_mel_transform = feature_mel_transform
|
|
self.feature_mel_transform = feature_mel_transform
|
|
|
- self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
|
|
|
|
|
- self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)
|
|
|
|
|
|
|
+ self.noise_scheduler = DDIMScheduler(
|
|
|
|
|
+ num_train_timesteps=1000,
|
|
|
|
|
+ clip_sample=False,
|
|
|
|
|
+ beta_end=0.01,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# Modules
|
|
# Modules
|
|
|
self.vq_encoder = vq_encoder
|
|
self.vq_encoder = vq_encoder
|
|
@@ -193,14 +196,14 @@ class VQDiffusion(L.LightningModule):
|
|
|
# Sample a random timestep for each image
|
|
# Sample a random timestep for each image
|
|
|
timesteps = torch.randint(
|
|
timesteps = torch.randint(
|
|
|
0,
|
|
0,
|
|
|
- self.noise_scheduler_train.config.num_train_timesteps,
|
|
|
|
|
|
|
+ self.noise_scheduler.config.num_train_timesteps,
|
|
|
(normalized_gt_mels.shape[0],),
|
|
(normalized_gt_mels.shape[0],),
|
|
|
device=normalized_gt_mels.device,
|
|
device=normalized_gt_mels.device,
|
|
|
).long()
|
|
).long()
|
|
|
|
|
|
|
|
# Add noise to the clean images according to the noise magnitude at each timestep
|
|
# Add noise to the clean images according to the noise magnitude at each timestep
|
|
|
# (this is the forward diffusion process)
|
|
# (this is the forward diffusion process)
|
|
|
- noisy_images = self.noise_scheduler_train.add_noise(
|
|
|
|
|
|
|
+ noisy_images = self.noise_scheduler.add_noise(
|
|
|
normalized_gt_mels, noise, timesteps
|
|
normalized_gt_mels, noise, timesteps
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -279,9 +282,9 @@ class VQDiffusion(L.LightningModule):
|
|
|
|
|
|
|
|
# Begin sampling
|
|
# Begin sampling
|
|
|
sampled_mels = torch.randn_like(gt_mels)
|
|
sampled_mels = torch.randn_like(gt_mels)
|
|
|
- self.noise_scheduler_infer.set_timesteps(100)
|
|
|
|
|
|
|
+ self.noise_scheduler.set_timesteps(50)
|
|
|
|
|
|
|
|
- for t in tqdm(self.noise_scheduler_infer.timesteps):
|
|
|
|
|
|
|
+ for t in tqdm(self.noise_scheduler.timesteps):
|
|
|
timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)
|
|
timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)
|
|
|
|
|
|
|
|
# 1. predict noise model_output
|
|
# 1. predict noise model_output
|
|
@@ -290,7 +293,7 @@ class VQDiffusion(L.LightningModule):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 2. compute previous image: x_t -> x_t-1
|
|
# 2. compute previous image: x_t -> x_t-1
|
|
|
- sampled_mels = self.noise_scheduler_infer.step(
|
|
|
|
|
|
|
+ sampled_mels = self.noise_scheduler.step(
|
|
|
model_output, t, sampled_mels
|
|
model_output, t, sampled_mels
|
|
|
).prev_sample
|
|
).prev_sample
|
|
|
|
|
|