|
|
@@ -249,7 +249,7 @@ class BaseTransformer(nn.Module):
|
|
|
def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
|
|
|
embeds = []
|
|
|
semantic_token_ids_tensor = torch.tensor(
|
|
|
- self.semantic_token_ids, device=inp.device
|
|
|
+ self.semantic_token_ids, device=inp.device, dtype=inp.dtype
|
|
|
)
|
|
|
|
|
|
for i in range(self.config.num_codebooks):
|