import math

import torch
from torch import nn
from torch.nn import functional as F

from fish_speech.models.vqgan.modules.normalization import LayerNorm
from fish_speech.models.vqgan.utils import convert_pad_shape

# TODO add conditioning on language
# TODO check whether we need to stop gradient for speaker embedding
  9. class RelativePositionTransformer(nn.Module):
  10. def __init__(
  11. self,
  12. in_channels: int,
  13. hidden_channels: int,
  14. out_channels: int,
  15. hidden_channels_ffn: int,
  16. n_heads: int,
  17. n_layers: int,
  18. kernel_size=1,
  19. dropout=0.0,
  20. window_size=4,
  21. gin_channels=0,
  22. speaker_cond_layer=0,
  23. ):
  24. super().__init__()
  25. assert (
  26. out_channels == hidden_channels
  27. ), "out_channels must be equal to hidden_channels"
  28. self.n_layers = n_layers
  29. self.speaker_cond_layer = speaker_cond_layer
  30. self.drop = nn.Dropout(dropout)
  31. self.attn_layers = nn.ModuleList()
  32. self.norm_layers_1 = nn.ModuleList()
  33. self.ffn_layers = nn.ModuleList()
  34. self.norm_layers_2 = nn.ModuleList()
  35. for i in range(self.n_layers):
  36. self.attn_layers.append(
  37. MultiHeadAttention(
  38. hidden_channels if i != 0 else in_channels,
  39. hidden_channels,
  40. n_heads,
  41. p_dropout=dropout,
  42. window_size=window_size,
  43. )
  44. )
  45. self.norm_layers_1.append(LayerNorm(hidden_channels))
  46. self.ffn_layers.append(
  47. FFN(
  48. hidden_channels,
  49. hidden_channels,
  50. hidden_channels_ffn,
  51. kernel_size,
  52. p_dropout=dropout,
  53. )
  54. )
  55. self.norm_layers_2.append(LayerNorm(hidden_channels))
  56. if gin_channels != 0:
  57. self.cond = nn.Conv1d(gin_channels, hidden_channels, 1)
  58. def forward(
  59. self,
  60. x: torch.Tensor,
  61. x_mask: torch.Tensor,
  62. g: torch.Tensor = None,
  63. ):
  64. attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
  65. x = x * x_mask
  66. for i in range(self.n_layers):
  67. # TODO consider using other conditioning
  68. # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12
  69. if i == self.speaker_cond_layer and g is not None:
  70. # ! g = torch.detach(g)
  71. x = x + self.cond(g)
  72. x = x * x_mask
  73. y = self.attn_layers[i](x, x, attn_mask)
  74. y = self.drop(y)
  75. x = self.norm_layers_1[i](x + y)
  76. y = self.ffn_layers[i](x, x_mask)
  77. y = self.drop(y)
  78. x = self.norm_layers_2[i](x + y)
  79. x = x * x_mask
  80. return x
  81. class MultiHeadAttention(nn.Module):
  82. def __init__(
  83. self,
  84. channels,
  85. out_channels,
  86. n_heads,
  87. p_dropout=0.0,
  88. window_size=None,
  89. heads_share=True,
  90. block_length=None,
  91. proximal_bias=False,
  92. proximal_init=False,
  93. ):
  94. super().__init__()
  95. assert channels % n_heads == 0
  96. self.channels = channels
  97. self.out_channels = out_channels
  98. self.n_heads = n_heads
  99. self.p_dropout = p_dropout
  100. self.window_size = window_size
  101. self.heads_share = heads_share
  102. self.block_length = block_length
  103. self.proximal_bias = proximal_bias
  104. self.proximal_init = proximal_init
  105. self.attn = None
  106. self.k_channels = channels // n_heads
  107. self.conv_q = nn.Linear(channels, channels)
  108. self.conv_k = nn.Linear(channels, channels)
  109. self.conv_v = nn.Linear(channels, channels)
  110. self.conv_o = nn.Linear(channels, out_channels)
  111. self.drop = nn.Dropout(p_dropout)
  112. if window_size is not None:
  113. n_heads_rel = 1 if heads_share else n_heads
  114. rel_stddev = self.k_channels**-0.5
  115. self.emb_rel_k = nn.Parameter(
  116. torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
  117. * rel_stddev
  118. )
  119. self.emb_rel_v = nn.Parameter(
  120. torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
  121. * rel_stddev
  122. )
  123. nn.init.xavier_uniform_(self.conv_q.weight)
  124. nn.init.xavier_uniform_(self.conv_k.weight)
  125. nn.init.xavier_uniform_(self.conv_v.weight)
  126. if proximal_init:
  127. with torch.no_grad():
  128. self.conv_k.weight.copy_(self.conv_q.weight)
  129. self.conv_k.bias.copy_(self.conv_q.bias)
  130. def forward(self, x, c, attn_mask=None):
  131. q = self.conv_q(x.mT).mT
  132. k = self.conv_k(c.mT).mT
  133. v = self.conv_v(c.mT).mT
  134. x, self.attn = self.attention(q, k, v, mask=attn_mask)
  135. x = self.conv_o(x.mT).mT
  136. return x
  137. def attention(self, query, key, value, mask=None):
  138. # reshape [b, d, t] -> [b, n_h, t, d_k]
  139. b, d, t_s, t_t = (*key.size(), query.size(2))
  140. query = query.view(b, self.n_heads, self.k_channels, t_t).mT
  141. key = key.view(b, self.n_heads, self.k_channels, t_s).mT
  142. value = value.view(b, self.n_heads, self.k_channels, t_s).mT
  143. scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT)
  144. if self.window_size is not None:
  145. assert (
  146. t_s == t_t
  147. ), "Relative attention is only available for self-attention."
  148. key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
  149. rel_logits = self._matmul_with_relative_keys(
  150. query / math.sqrt(self.k_channels), key_relative_embeddings
  151. )
  152. scores_local = self._relative_position_to_absolute_position(rel_logits)
  153. scores = scores + scores_local
  154. if self.proximal_bias:
  155. assert t_s == t_t, "Proximal bias is only available for self-attention."
  156. scores = scores + self._attention_bias_proximal(t_s).to(
  157. device=scores.device, dtype=scores.dtype
  158. )
  159. if mask is not None:
  160. scores = scores.masked_fill(mask == 0, -1e4)
  161. if self.block_length is not None:
  162. assert (
  163. t_s == t_t
  164. ), "Local attention is only available for self-attention."
  165. block_mask = (
  166. torch.ones_like(scores)
  167. .triu(-self.block_length)
  168. .tril(self.block_length)
  169. )
  170. scores = scores.masked_fill(block_mask == 0, -1e4)
  171. p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
  172. p_attn = self.drop(p_attn)
  173. output = torch.matmul(p_attn, value)
  174. if self.window_size is not None:
  175. relative_weights = self._absolute_position_to_relative_position(p_attn)
  176. value_relative_embeddings = self._get_relative_embeddings(
  177. self.emb_rel_v, t_s
  178. )
  179. output = output + self._matmul_with_relative_values(
  180. relative_weights, value_relative_embeddings
  181. )
  182. output = output.mT.contiguous().view(
  183. b, d, t_t
  184. ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
  185. return output, p_attn
  186. def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor):
  187. """
  188. x: [b, h, l, m]
  189. y: [h or 1, m, d]
  190. ret: [b, h, l, d]
  191. """
  192. return torch.matmul(x, y.unsqueeze(0))
  193. def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor):
  194. """
  195. x: [b, h, l, d]
  196. y: [h or 1, m, d]
  197. ret: [b, h, l, m]
  198. """
  199. return torch.matmul(x, y.unsqueeze(0).mT)
  200. def _get_relative_embeddings(self, relative_embeddings, length):
  201. max_relative_position = 2 * self.window_size + 1
  202. # Pad first before slice to avoid using cond ops.
  203. pad_length = max(length - (self.window_size + 1), 0)
  204. slice_start_position = max((self.window_size + 1) - length, 0)
  205. slice_end_position = slice_start_position + 2 * length - 1
  206. if pad_length > 0:
  207. padded_relative_embeddings = F.pad(
  208. relative_embeddings,
  209. convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
  210. )
  211. else:
  212. padded_relative_embeddings = relative_embeddings
  213. used_relative_embeddings = padded_relative_embeddings[
  214. :, slice_start_position:slice_end_position
  215. ]
  216. return used_relative_embeddings
  217. def _relative_position_to_absolute_position(self, x):
  218. """
  219. x: [b, h, l, 2*l-1]
  220. ret: [b, h, l, l]
  221. """
  222. batch, heads, length, _ = x.size()
  223. # Concat columns of pad to shift from relative to absolute indexing.
  224. x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
  225. # Concat extra elements so to add up to shape (len+1, 2*len-1).
  226. x_flat = x.view([batch, heads, length * 2 * length])
  227. x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
  228. # Reshape and slice out the padded elements.
  229. x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
  230. :, :, :length, length - 1 :
  231. ]
  232. return x_final
  233. def _absolute_position_to_relative_position(self, x):
  234. """
  235. x: [b, h, l, l]
  236. ret: [b, h, l, 2*l-1]
  237. """
  238. batch, heads, length, _ = x.size()
  239. # pad along column
  240. x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
  241. x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
  242. # add 0's in the beginning that will skew the elements after reshape
  243. x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
  244. x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
  245. return x_final
  246. def _attention_bias_proximal(self, length):
  247. """Bias for self-attention to encourage attention to close positions.
  248. Args:
  249. length: an integer scalar.
  250. Returns:
  251. a Tensor with shape [1, 1, length, length]
  252. """
  253. r = torch.arange(length, dtype=torch.float32)
  254. diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
  255. return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
  256. class FFN(nn.Module):
  257. def __init__(
  258. self,
  259. in_channels,
  260. out_channels,
  261. filter_channels,
  262. kernel_size,
  263. p_dropout=0.0,
  264. causal=False,
  265. ):
  266. super().__init__()
  267. self.kernel_size = kernel_size
  268. self.padding = self._causal_padding if causal else self._same_padding
  269. self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
  270. self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
  271. self.drop = nn.Dropout(p_dropout)
  272. def forward(self, x, x_mask):
  273. x = self.conv_1(self.padding(x * x_mask))
  274. x = torch.relu(x)
  275. x = self.drop(x)
  276. x = self.conv_2(self.padding(x * x_mask))
  277. return x * x_mask
  278. def _causal_padding(self, x):
  279. if self.kernel_size == 1:
  280. return x
  281. pad_l = self.kernel_size - 1
  282. pad_r = 0
  283. padding = [[0, 0], [0, 0], [pad_l, pad_r]]
  284. x = F.pad(x, convert_pad_shape(padding))
  285. return x
  286. def _same_padding(self, x):
  287. if self.kernel_size == 1:
  288. return x
  289. pad_l = (self.kernel_size - 1) // 2
  290. pad_r = self.kernel_size // 2
  291. padding = [[0, 0], [0, 0], [pad_l, pad_r]]
  292. x = F.pad(x, convert_pad_shape(padding))
  293. return x