unet1d.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Refer to https://github.com/huawei-noah/Speech-Backbones/blob/main/Grad-TTS/model/diffusion.py
  2. import math
  3. import torch
  4. from einops import rearrange
  5. from torch import nn
  6. class Block(nn.Module):
  7. def __init__(self, dim, dim_out, groups=8):
  8. super().__init__()
  9. self.block = nn.Sequential(
  10. nn.Conv2d(dim, dim_out, 3, padding=1),
  11. nn.GroupNorm(groups, dim_out),
  12. nn.Mish(),
  13. )
  14. def forward(self, x, mask):
  15. output = self.block(x * mask)
  16. return output * mask
  17. class ResnetBlock(nn.Module):
  18. def __init__(self, dim, dim_out, time_emb_dim, groups=8):
  19. super().__init__()
  20. self.mlp = nn.Sequential(nn.Mish(), nn.Linear(time_emb_dim, dim_out))
  21. self.block1 = Block(dim, dim_out, groups=groups)
  22. self.block2 = Block(dim_out, dim_out, groups=groups)
  23. if dim != dim_out:
  24. self.res_conv = nn.Conv2d(dim, dim_out, 1)
  25. else:
  26. self.res_conv = nn.Identity()
  27. def forward(self, x, mask, time_emb):
  28. h = self.block1(x, mask)
  29. h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
  30. h = self.block2(h, mask)
  31. output = h + self.res_conv(x * mask)
  32. return output
  33. class LinearAttention(nn.Module):
  34. def __init__(self, dim, heads=4, dim_head=32, init_values=1e-5):
  35. super().__init__()
  36. self.heads = heads
  37. hidden_dim = dim_head * heads
  38. self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
  39. self.to_out = nn.Conv2d(hidden_dim, dim, 1)
  40. self.gamma = nn.Parameter(torch.ones(dim) * init_values)
  41. def forward(self, x):
  42. b, c, h, w = x.shape
  43. qkv = self.to_qkv(x)
  44. q, k, v = rearrange(
  45. qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
  46. )
  47. k = k.softmax(dim=-1)
  48. context = torch.einsum("bhdn,bhen->bhde", k, v)
  49. out = torch.einsum("bhde,bhdn->bhen", context, q)
  50. out = rearrange(
  51. out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
  52. )
  53. return self.to_out(out) * self.gamma.view(1, -1, 1, 1) + x
  54. class SinusoidalPosEmb(nn.Module):
  55. def __init__(self, dim):
  56. super().__init__()
  57. self.dim = dim
  58. def forward(self, x, scale=1000):
  59. device = x.device
  60. half_dim = self.dim // 2
  61. emb = math.log(10000) / (half_dim - 1)
  62. emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
  63. emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
  64. emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
  65. return emb
  66. class Unet1DDenoiser(nn.Module):
  67. def __init__(
  68. self,
  69. dim,
  70. dim_mults=(1, 2, 4),
  71. groups=8,
  72. pe_scale=1000,
  73. ):
  74. super().__init__()
  75. self.dim = dim
  76. self.dim_mults = dim_mults
  77. self.groups = groups
  78. self.pe_scale = pe_scale
  79. self.time_pos_emb = SinusoidalPosEmb(dim)
  80. self.mlp = nn.Sequential(
  81. nn.Linear(dim, dim * 4), nn.Mish(), nn.Linear(dim * 4, dim)
  82. )
  83. self.downsample_rate = 2 ** (len(dim_mults) - 1)
  84. dims = [2, *map(lambda m: dim * m, dim_mults)]
  85. in_out = list(zip(dims[:-1], dims[1:]))
  86. self.downs = nn.ModuleList([])
  87. self.ups = nn.ModuleList([])
  88. num_resolutions = len(in_out)
  89. for ind, (dim_in, dim_out) in enumerate(in_out):
  90. is_last = ind >= (num_resolutions - 1)
  91. self.downs.append(
  92. nn.ModuleList(
  93. [
  94. ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
  95. ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
  96. LinearAttention(dim_out),
  97. nn.Conv2d(dim_out, dim_out, 3, 2, 1)
  98. if not is_last
  99. else nn.Identity(),
  100. ]
  101. )
  102. )
  103. mid_dim = dims[-1]
  104. self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
  105. self.mid_attn = LinearAttention(mid_dim)
  106. self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
  107. for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
  108. self.ups.append(
  109. nn.ModuleList(
  110. [
  111. ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
  112. ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
  113. LinearAttention(dim_in),
  114. nn.ConvTranspose2d(dim_in, dim_in, 4, 2, 1),
  115. ]
  116. )
  117. )
  118. self.final_block = Block(dim, dim)
  119. self.final_conv = nn.Conv2d(dim, 1, 1)
  120. def forward(self, x, t, mask, condition):
  121. t = self.time_pos_emb(t, scale=self.pe_scale)
  122. t = self.mlp(t)
  123. x = torch.stack([condition, x], 1)
  124. mask = mask.unsqueeze(1)
  125. original_len = x.shape[3]
  126. if x.shape[3] % self.downsample_rate != 0:
  127. x = nn.functional.pad(
  128. x, (0, self.downsample_rate - x.shape[3] % self.downsample_rate)
  129. )
  130. mask = nn.functional.pad(
  131. mask, (0, self.downsample_rate - mask.shape[3] % self.downsample_rate)
  132. )
  133. hiddens = []
  134. masks = [mask]
  135. for resnet1, resnet2, attn, downsample in self.downs:
  136. mask_down = masks[-1]
  137. x = resnet1(x, mask_down, t)
  138. x = resnet2(x, mask_down, t)
  139. x = attn(x)
  140. hiddens.append(x)
  141. x = downsample(x * mask_down)
  142. masks.append(mask_down[:, :, :, ::2])
  143. masks = masks[:-1]
  144. mask_mid = masks[-1]
  145. x = self.mid_block1(x, mask_mid, t)
  146. x = self.mid_attn(x)
  147. x = self.mid_block2(x, mask_mid, t)
  148. for resnet1, resnet2, attn, upsample in self.ups:
  149. mask_up = masks.pop()
  150. x = torch.cat((x, hiddens.pop()), dim=1)
  151. x = resnet1(x, mask_up, t)
  152. x = resnet2(x, mask_up, t)
  153. x = attn(x)
  154. x = upsample(x * mask_up)
  155. x = self.final_block(x, mask)
  156. output = self.final_conv(x * mask)
  157. output = (output * mask).squeeze(1)
  158. return output[:, :, :original_len]
  159. if __name__ == "__main__":
  160. model = Unet1DDenoiser(128)
  161. mel = torch.randn(1, 128, 99)
  162. mask = torch.ones(1, 1, 99)
  163. print(model(mel, mask, torch.tensor([10], dtype=torch.long), mel).shape)