Lengyue 2 years ago
Commit 9b916b1c38

+ 2 - 2
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -58,12 +58,12 @@ model:
     hidden_channels: 192
     hidden_channels_ffn: 768
     n_heads: 2
-    n_layers: 4
+    n_layers: 6
     kernel_size: 1
     dropout: 0.1
     use_vae: false
     gin_channels: 512
-    speaker_cond_layer: 0
+    speaker_cond_layer: 2

   vq_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
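Note: with the old check in transformer.py below (i == self.speaker_cond_layer - 1), a speaker_cond_layer of 0 compared i against -1 and never matched, so the speaker embedding was never injected. The new value of 2, paired with the corrected comparison, applies speaker conditioning before the third (0-indexed layer 2) of the now six transformer layers.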

+ 3 - 3
fish_speech/models/vqgan/modules/transformer.py

@@ -61,7 +61,7 @@ class RelativePositionTransformer(nn.Module):
             self.norm_layers_2.append(LayerNorm(hidden_channels))

         if gin_channels != 0:
-            self.cond = nn.Linear(gin_channels, hidden_channels)
+            self.cond = nn.Conv1d(gin_channels, hidden_channels, 1)

     def forward(
         self,
@@ -74,9 +74,9 @@ class RelativePositionTransformer(nn.Module):
         for i in range(self.n_layers):
             # TODO consider using other conditioning
             # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12
-            if i == self.speaker_cond_layer - 1 and g is not None:
+            if i == self.speaker_cond_layer and g is not None:
                 # ! g = torch.detach(g)
-                x = x + self.cond(g.mT).mT
+                x = x + self.cond(g)
                 x = x * x_mask

             y = self.attn_layers[i](x, x, attn_mask)
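
The nn.Linear to nn.Conv1d(gin_channels, hidden_channels, 1) swap removes the .mT transpose round-trip: a kernel-size-1 convolution is the same per-frame linear projection, applied directly to the channels-first (batch, channels, time) layout the module already uses. A minimal sketch of the equivalence, assuming PyTorch (the shapes are illustrative, not taken from the repo):

import torch
import torch.nn as nn

gin_channels, hidden_channels, t = 512, 192, 100

linear = nn.Linear(gin_channels, hidden_channels)
conv = nn.Conv1d(gin_channels, hidden_channels, 1)

# Copy the linear weights into the 1x1 conv so the outputs are comparable:
# (out, in) -> (out, in, kernel_size=1).
with torch.no_grad():
    conv.weight.copy_(linear.weight.unsqueeze(-1))
    conv.bias.copy_(linear.bias)

g = torch.randn(2, gin_channels, t)  # speaker embedding, channels-first

out_old = linear(g.mT).mT  # old path: transpose to (B, T, C) and back
out_new = conv(g)          # new path: stay in (B, C, T)

assert torch.allclose(out_old, out_new, atol=1e-6)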