Explorar el Código

Fix quantizer norm

Lengyue hace 2 años
padre
commit
d28ba743d3
Se han modificado 1 ficheros con 2 adiciones y 0 borrados
  1. 2 0
      speech_lm/models/hubert_vq.py

+ 2 - 0
speech_lm/models/hubert_vq.py

@@ -46,6 +46,7 @@ class HubertVQ(nn.Module):
             param.requires_grad = True
 
         # Quantization
+        self.quantizer_ln = nn.LayerNorm(self.hubert.config.hidden_size)
         self.quantizer = VectorQuantization(
             codebook_size=codebook_size,
             dim=self.hubert.config.hidden_size,
@@ -164,6 +165,7 @@ class HubertVQ(nn.Module):
         )
 
         # Quantize
+        hidden_states = self.quantizer_ln(hidden_states)
         quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
         quantize = quantize.transpose(1, 2)