|
@@ -102,8 +102,8 @@ class FishConvNet(nn.Module):
|
|
|
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
|
return self
|
|
return self
|
|
|
|
|
|
|
|
- def remove_weight_norm(self):
|
|
|
|
|
- self.conv = remove_parametrizations(self.conv)
|
|
|
|
|
|
|
+ def remove_parametrizations(self, name="weight"):
|
|
|
|
|
+ self.conv = remove_parametrizations(self.conv, name)
|
|
|
return self
|
|
return self
|
|
|
|
|
|
|
|
|
|
|
|
@@ -128,8 +128,8 @@ class FishTransConvNet(nn.Module):
|
|
|
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
|
return self
|
|
return self
|
|
|
|
|
|
|
|
- def remove_weight_norm(self):
|
|
|
|
|
- self.conv = remove_parametrizations(self.conv)
|
|
|
|
|
|
|
+ def remove_parametrizations(self, name="weight"):
|
|
|
|
|
+ self.conv = remove_parametrizations(self.conv, name)
|
|
|
return self
|
|
return self
|
|
|
|
|
|
|
|
|
|
|
|
@@ -178,9 +178,9 @@ class ResBlock1(torch.nn.Module):
|
|
|
|
|
|
|
|
def remove_parametrizations(self):
|
|
def remove_parametrizations(self):
|
|
|
for conv in self.convs1:
|
|
for conv in self.convs1:
|
|
|
- remove_parametrizations(conv, tensor_name="weight")
|
|
|
|
|
|
|
+ conv.remove_parametrizations()
|
|
|
for conv in self.convs2:
|
|
for conv in self.convs2:
|
|
|
- remove_parametrizations(conv, tensor_name="weight")
|
|
|
|
|
|
|
+ conv.remove_parametrizations()
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelBlock(nn.Module):
|
|
class ParallelBlock(nn.Module):
|
|
@@ -288,11 +288,11 @@ class HiFiGANGenerator(nn.Module):
|
|
|
|
|
|
|
|
def remove_parametrizations(self):
|
|
def remove_parametrizations(self):
|
|
|
for up in self.ups:
|
|
for up in self.ups:
|
|
|
- remove_parametrizations(up, tensor_name="weight")
|
|
|
|
|
|
|
+ up.remove_parametrizations()
|
|
|
for block in self.resblocks:
|
|
for block in self.resblocks:
|
|
|
block.remove_parametrizations()
|
|
block.remove_parametrizations()
|
|
|
- remove_parametrizations(self.conv_pre, tensor_name="weight")
|
|
|
|
|
- remove_parametrizations(self.conv_post, tensor_name="weight")
|
|
|
|
|
|
|
+ self.conv_pre.remove_parametrizations()
|
|
|
|
|
+ self.conv_post.remove_parametrizations()
|
|
|
|
|
|
|
|
|
|
|
|
|
# DropPath copied from timm library
|
|
# DropPath copied from timm library
|