Kaynağa Gözat

Fix torch compile

Lengyue 2 yıl önce
ebeveyn
işleme
a588f458f1
1 değiştirilmiş dosya ile 3 ekleme ve 4 silme
  1. 3 4
      tools/llama/generate.py

+ 3 - 4
tools/llama/generate.py

@@ -123,10 +123,9 @@ def decode_one_token(
             )
 
     codebooks = torch.stack(codebooks, dim=0)
-    if codebooks[0] == 2:
-        codebooks[1] = 1
-    else:
-        codebooks[1] = codebooks[0] - 32311 + 2
+    codebooks[1] = torch.where(
+        codebooks[0] <= 32311, codebooks[0], codebooks[0] - 32311 + 2
+    )
 
     return codebooks