@@ -37,6 +37,7 @@ class ModelArgs:
     # Additional decoding heads
     codebook_size: int = 160
     num_codebooks: int = 4
+    num_in_codebooks: Optional[int] = None
     codebook_padding_idx: int = 0

     # Use flash attention
@@ -55,6 +56,8 @@ class ModelArgs:
         hidden_dim = 4 * self.dim
         n_hidden = int(2 * hidden_dim / 3)
         self.intermediate_size = find_multiple(n_hidden, 256)
+        if self.num_in_codebooks is None:
+            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head


@@ -91,7 +94,8 @@ class Transformer(nn.Module):
         self.config = config

         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
+            config.vocab_size + config.codebook_size * config.num_in_codebooks,
+            config.dim,
         )
         self.layers = nn.ModuleList(
             TransformerBlock(config) for _ in range(config.n_layer)
@@ -148,11 +152,11 @@ class Transformer(nn.Module):

     def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
-        if self.config.num_codebooks == 0:
+        if self.config.num_in_codebooks == 0:
             return self.embeddings(x[:, 0])

         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_codebooks):
+        for i in range(self.config.num_in_codebooks):
             emb = self.embeddings(
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )