Lengyue 2 سال پیش
والد
کامیت
852edb9474
3 فایلهای تغییر یافته به همراه 16 افزوده شده و 9 حذف شده
  1. 1 1
      fish_speech/configs/vq_diffusion.yaml
  2. 5 1
      fish_speech/datasets/vqgan.py
  3. 10 7
      fish_speech/models/vq_diffusion/lit_module.py

+ 1 - 1
fish_speech/configs/vq_diffusion.yaml

@@ -11,7 +11,7 @@ trainer:
   strategy: ddp_find_unused_parameters_true
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
-  precision: 16-mixed
+  precision: bf16-mixed
   max_steps: 300_000
   val_check_interval: 5000
 

+ 5 - 1
fish_speech/datasets/vqgan.py

@@ -26,7 +26,11 @@ class VQGANDataset(Dataset):
         filelist = Path(filelist)
         root = filelist.parent
 
-        self.files = [root / line.strip() for line in filelist.read_text().splitlines()]
+        self.files = [
+            root / line.strip()
+            for line in filelist.read_text().splitlines()
+            if line.strip()
+        ]
         self.sample_rate = sample_rate
         self.hop_length = hop_length
         self.slice_frames = slice_frames

+ 10 - 7
fish_speech/models/vq_diffusion/lit_module.py

@@ -99,8 +99,11 @@ class VQDiffusion(L.LightningModule):
         # Generator and discriminators
         self.mel_transform = mel_transform
         self.feature_mel_transform = feature_mel_transform
-        self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
-        self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)
+        self.noise_scheduler = DDIMScheduler(
+            num_train_timesteps=1000,
+            clip_sample=False,
+            beta_end=0.01,
+        )
 
         # Modules
         self.vq_encoder = vq_encoder
@@ -193,14 +196,14 @@ class VQDiffusion(L.LightningModule):
         # Sample a random timestep for each image
         timesteps = torch.randint(
             0,
-            self.noise_scheduler_train.config.num_train_timesteps,
+            self.noise_scheduler.config.num_train_timesteps,
             (normalized_gt_mels.shape[0],),
             device=normalized_gt_mels.device,
         ).long()
 
         # Add noise to the clean images according to the noise magnitude at each timestep
         # (this is the forward diffusion process)
-        noisy_images = self.noise_scheduler_train.add_noise(
+        noisy_images = self.noise_scheduler.add_noise(
             normalized_gt_mels, noise, timesteps
         )
 
@@ -279,9 +282,9 @@ class VQDiffusion(L.LightningModule):
 
         # Begin sampling
         sampled_mels = torch.randn_like(gt_mels)
-        self.noise_scheduler_infer.set_timesteps(100)
+        self.noise_scheduler.set_timesteps(50)
 
-        for t in tqdm(self.noise_scheduler_infer.timesteps):
+        for t in tqdm(self.noise_scheduler.timesteps):
             timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)
 
             # 1. predict noise model_output
@@ -290,7 +293,7 @@ class VQDiffusion(L.LightningModule):
             )
 
             # 2. compute previous image: x_t -> x_t-1
-            sampled_mels = self.noise_scheduler_infer.step(
+            sampled_mels = self.noise_scheduler.step(
                 model_output, t, sampled_mels
             ).prev_sample