Ver Fonte

Fix torch compile

Lengyue há 2 anos atrás
pai
commit
a588f458f1
1 ficheiros alterados com 3 adições e 4 exclusões
  1. 3 4
      tools/llama/generate.py

+ 3 - 4
tools/llama/generate.py

@@ -123,10 +123,9 @@ def decode_one_token(
             )
 
     codebooks = torch.stack(codebooks, dim=0)
-    if codebooks[0] == 2:
-        codebooks[1] = 1
-    else:
-        codebooks[1] = codebooks[0] - 32311 + 2
+    codebooks[1] = torch.where(
+        codebooks[0] <= 32311, codebooks[0], codebooks[0] - 32311 + 2
+    )
 
     return codebooks