
fix a dtype mismatch when using MPS (#790)

* fix a dtype mismatch when using MPS

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
heiway 1 year ago
Parent
Commit d76b9174d1
1 changed file with 1 addition and 1 deletion

  1. fish_speech/models/text2semantic/llama.py  +1 -1

fish_speech/models/text2semantic/llama.py  +1 -1

@@ -249,7 +249,7 @@ class BaseTransformer(nn.Module):
     def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
         embeds = []
         semantic_token_ids_tensor = torch.tensor(
-            self.semantic_token_ids, device=inp.device
+            self.semantic_token_ids, device=inp.device, dtype=inp.dtype
         )
 
         for i in range(self.config.num_codebooks):
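
For context, the one-line change matters because `torch.tensor()` over a list of Python ints defaults to `torch.int64`, which can differ from the dtype of the incoming token tensor `inp`; the commit reports that this mismatch surfaces as an error on the MPS backend. The sketch below illustrates the pattern under stated assumptions: the `build_semantic_mask` helper, the `torch.isin` usage, and the int32 input are illustrative only and are not the surrounding fish_speech code.

```python
import torch


def build_semantic_mask(inp: torch.Tensor, semantic_token_ids: list[int]) -> torch.Tensor:
    # Illustrative sketch, not the repository's actual code.
    # Without an explicit dtype, torch.tensor([...Python ints...]) is int64,
    # which may not match inp.dtype; the patch pins it to inp.dtype instead.
    ids = torch.tensor(semantic_token_ids, device=inp.device, dtype=inp.dtype)
    # Mark positions in `inp` whose ids belong to the semantic-token set.
    return torch.isin(inp, ids)


if __name__ == "__main__":
    tokens = torch.tensor([[1, 5, 9, 3]], dtype=torch.int32)  # hypothetical int32 input ids
    print(build_semantic_mask(tokens, [3, 5]))
    # tensor([[False,  True, False,  True]])
```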