Jelajahi Sumber

Fix broken remove_parameterization (#620)

med1844 1 tahun lalu
induk
melakukan
2cef13f716
1 mengubah file dengan 9 tambahan dan 9 penghapusan
  1. 9 9
      fish_speech/models/vqgan/modules/firefly.py

+ 9 - 9
fish_speech/models/vqgan/modules/firefly.py

@@ -102,8 +102,8 @@ class FishConvNet(nn.Module):
         self.conv = weight_norm(self.conv, name=name, dim=dim)
         return self
 
-    def remove_weight_norm(self):
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
         return self
 
 
@@ -128,8 +128,8 @@ class FishTransConvNet(nn.Module):
         self.conv = weight_norm(self.conv, name=name, dim=dim)
         return self
 
-    def remove_weight_norm(self):
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
         return self
 
 
@@ -178,9 +178,9 @@ class ResBlock1(torch.nn.Module):
 
     def remove_parametrizations(self):
         for conv in self.convs1:
-            remove_parametrizations(conv, tensor_name="weight")
+            conv.remove_parametrizations()
         for conv in self.convs2:
-            remove_parametrizations(conv, tensor_name="weight")
+            conv.remove_parametrizations()
 
 
 class ParallelBlock(nn.Module):
@@ -288,11 +288,11 @@ class HiFiGANGenerator(nn.Module):
 
     def remove_parametrizations(self):
         for up in self.ups:
-            remove_parametrizations(up, tensor_name="weight")
+            up.remove_parametrizations()
         for block in self.resblocks:
             block.remove_parametrizations()
-        remove_parametrizations(self.conv_pre, tensor_name="weight")
-        remove_parametrizations(self.conv_post, tensor_name="weight")
+        self.conv_pre.remove_parametrizations()
+        self.conv_post.remove_parametrizations()
 
 
 # DropPath copied from timm library