# flow.py
  1. import torch
  2. from torch import nn
  3. from fish_speech.models.vqgan.modules.modules import WN, Flip
  4. from fish_speech.models.vqgan.modules.normalization import LayerNorm
  5. from fish_speech.models.vqgan.modules.transformer import FFN, MultiHeadAttention
  6. class ResidualCouplingBlock(nn.Module):
  7. def __init__(
  8. self,
  9. channels,
  10. hidden_channels,
  11. kernel_size,
  12. dilation_rate,
  13. n_layers,
  14. n_flows=4,
  15. gin_channels=0,
  16. ):
  17. super().__init__()
  18. self.channels = channels
  19. self.hidden_channels = hidden_channels
  20. self.kernel_size = kernel_size
  21. self.dilation_rate = dilation_rate
  22. self.n_layers = n_layers
  23. self.n_flows = n_flows
  24. self.gin_channels = gin_channels
  25. self.flows = nn.ModuleList()
  26. for i in range(n_flows):
  27. self.flows.append(
  28. ResidualCouplingLayer(
  29. channels,
  30. hidden_channels,
  31. kernel_size,
  32. dilation_rate,
  33. n_layers,
  34. gin_channels=gin_channels,
  35. mean_only=True,
  36. )
  37. )
  38. self.flows.append(Flip())
  39. def forward(self, x, x_mask, g=None, reverse=False):
  40. if not reverse:
  41. for flow in self.flows:
  42. x, _ = flow(x, x_mask, g=g, reverse=reverse)
  43. else:
  44. for flow in reversed(self.flows):
  45. x = flow(x, x_mask, g=g, reverse=reverse)
  46. return x
  47. class ResidualCouplingLayer(nn.Module):
  48. def __init__(
  49. self,
  50. channels,
  51. hidden_channels,
  52. kernel_size,
  53. dilation_rate,
  54. n_layers,
  55. p_dropout=0,
  56. gin_channels=0,
  57. mean_only=False,
  58. ):
  59. assert channels % 2 == 0, "channels should be divisible by 2"
  60. super().__init__()
  61. self.channels = channels
  62. self.hidden_channels = hidden_channels
  63. self.kernel_size = kernel_size
  64. self.dilation_rate = dilation_rate
  65. self.n_layers = n_layers
  66. self.half_channels = channels // 2
  67. self.mean_only = mean_only
  68. self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
  69. self.enc = WN(
  70. hidden_channels,
  71. kernel_size,
  72. dilation_rate,
  73. n_layers,
  74. p_dropout=p_dropout,
  75. gin_channels=gin_channels,
  76. )
  77. self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
  78. self.post.weight.data.zero_()
  79. self.post.bias.data.zero_()
  80. def forward(self, x, x_mask, g=None, reverse=False):
  81. x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
  82. h = self.pre(x0) * x_mask
  83. h = self.enc(h, x_mask, g=g)
  84. stats = self.post(h) * x_mask
  85. if not self.mean_only:
  86. m, logs = torch.split(stats, [self.half_channels] * 2, 1)
  87. else:
  88. m = stats
  89. logs = torch.zeros_like(m)
  90. if not reverse:
  91. x1 = m + x1 * torch.exp(logs) * x_mask
  92. x = torch.cat([x0, x1], 1)
  93. logdet = torch.sum(logs, [1, 2])
  94. return x, logdet
  95. else:
  96. x1 = (x1 - m) * torch.exp(-logs) * x_mask
  97. x = torch.cat([x0, x1], 1)
  98. return x
  99. class TransformerCouplingBlock(nn.Module):
  100. def __init__(
  101. self,
  102. channels,
  103. hidden_channels,
  104. filter_channels,
  105. n_heads,
  106. n_layers,
  107. kernel_size,
  108. p_dropout,
  109. n_flows=4,
  110. gin_channels=0,
  111. ):
  112. super().__init__()
  113. self.channels = channels
  114. self.hidden_channels = hidden_channels
  115. self.kernel_size = kernel_size
  116. self.n_layers = n_layers
  117. self.n_flows = n_flows
  118. self.gin_channels = gin_channels
  119. self.flows = nn.ModuleList()
  120. for i in range(n_flows):
  121. self.flows.append(
  122. TransformerCouplingLayer(
  123. channels,
  124. hidden_channels,
  125. kernel_size,
  126. n_layers,
  127. n_heads,
  128. p_dropout,
  129. filter_channels,
  130. mean_only=True,
  131. gin_channels=self.gin_channels,
  132. )
  133. )
  134. self.flows.append(Flip())
  135. def forward(self, x, x_mask, g=None, reverse=False):
  136. if not reverse:
  137. for flow in self.flows:
  138. x, _ = flow(x, x_mask, g=g, reverse=reverse)
  139. else:
  140. for flow in reversed(self.flows):
  141. x = flow(x, x_mask, g=g, reverse=reverse)
  142. return x
  143. class TransformerCouplingLayer(nn.Module):
  144. def __init__(
  145. self,
  146. channels,
  147. hidden_channels,
  148. kernel_size,
  149. n_layers,
  150. n_heads,
  151. p_dropout=0,
  152. filter_channels=0,
  153. mean_only=False,
  154. gin_channels=0,
  155. ):
  156. super().__init__()
  157. assert channels % 2 == 0, "channels should be divisible by 2"
  158. self.channels = channels
  159. self.hidden_channels = hidden_channels
  160. self.kernel_size = kernel_size
  161. self.n_layers = n_layers
  162. self.half_channels = channels // 2
  163. self.mean_only = mean_only
  164. self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
  165. self.enc = Encoder(
  166. hidden_channels,
  167. filter_channels,
  168. n_heads,
  169. n_layers,
  170. kernel_size,
  171. p_dropout,
  172. gin_channels=gin_channels,
  173. )
  174. self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
  175. self.post.weight.data.zero_()
  176. self.post.bias.data.zero_()
  177. def forward(self, x, x_mask, g=None, reverse=False):
  178. x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
  179. h = self.pre(x0) * x_mask
  180. h = self.enc(h, x_mask, g=g)
  181. stats = self.post(h) * x_mask
  182. if not self.mean_only:
  183. m, logs = torch.split(stats, [self.half_channels] * 2, 1)
  184. else:
  185. m = stats
  186. logs = torch.zeros_like(m)
  187. if not reverse:
  188. x1 = m + x1 * torch.exp(logs) * x_mask
  189. x = torch.cat([x0, x1], 1)
  190. logdet = torch.sum(logs, [1, 2])
  191. return x, logdet
  192. else:
  193. x1 = (x1 - m) * torch.exp(-logs) * x_mask
  194. x = torch.cat([x0, x1], 1)
  195. return x
  196. class Encoder(nn.Module):
  197. def __init__(
  198. self,
  199. hidden_channels,
  200. filter_channels,
  201. n_heads,
  202. n_layers,
  203. kernel_size=1,
  204. p_dropout=0.0,
  205. window_size=4,
  206. gin_channels=512,
  207. cond_layer_idx=2,
  208. ):
  209. super().__init__()
  210. self.hidden_channels = hidden_channels
  211. self.filter_channels = filter_channels
  212. self.n_heads = n_heads
  213. self.n_layers = n_layers
  214. self.kernel_size = kernel_size
  215. self.p_dropout = p_dropout
  216. self.window_size = window_size
  217. self.spk_emb_linear = nn.Linear(gin_channels, self.hidden_channels)
  218. self.cond_layer_idx = cond_layer_idx
  219. assert (
  220. self.cond_layer_idx < self.n_layers
  221. ), "cond_layer_idx should be less than n_layers"
  222. self.drop = nn.Dropout(p_dropout)
  223. self.attn_layers = nn.ModuleList()
  224. self.norm_layers_1 = nn.ModuleList()
  225. self.ffn_layers = nn.ModuleList()
  226. self.norm_layers_2 = nn.ModuleList()
  227. for i in range(self.n_layers):
  228. self.attn_layers.append(
  229. MultiHeadAttention(
  230. hidden_channels,
  231. hidden_channels,
  232. n_heads,
  233. p_dropout=p_dropout,
  234. window_size=window_size,
  235. )
  236. )
  237. self.norm_layers_1.append(LayerNorm(hidden_channels))
  238. self.ffn_layers.append(
  239. FFN(
  240. hidden_channels,
  241. hidden_channels,
  242. filter_channels,
  243. kernel_size,
  244. p_dropout=p_dropout,
  245. )
  246. )
  247. self.norm_layers_2.append(LayerNorm(hidden_channels))
  248. def forward(self, x, x_mask, g=None):
  249. attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
  250. x = x * x_mask
  251. for i in range(self.n_layers):
  252. if i == self.cond_layer_idx and g is not None:
  253. g = self.spk_emb_linear(g.transpose(1, 2))
  254. g = g.transpose(1, 2)
  255. x = x + g
  256. x = x * x_mask
  257. y = self.attn_layers[i](x, x, attn_mask)
  258. y = self.drop(y)
  259. x = self.norm_layers_1[i](x + y)
  260. y = self.ffn_layers[i](x, x_mask)
  261. y = self.drop(y)
  262. x = self.norm_layers_2[i](x + y)
  263. x = x * x_mask
  264. return x