Przeglądaj źródła

Optimize drifiting

Lengyue 2 lat temu
rodzic
commit
4a0ec974d1
1 zmienionych plików z 2 dodań i 2 usunięć
  1. 2 2
      tools/llama/generate.py

+ 2 - 2
tools/llama/generate.py

@@ -540,9 +540,9 @@ def main(
             if i != 0 and i % 2 == 0:
             if i != 0 and i % 2 == 0:
                 i -= 1
                 i -= 1
 
 
-            # Rotate the list
+            # Rotate the list, always make sure first segment is included to avoid drift
             if i < len(global_encoded) - 2:
             if i < len(global_encoded) - 2:
-                partial_encoded = global_encoded[-i:]
+                partial_encoded = global_encoded[:2] + global_encoded[-i:]
             else:
             else:
                 partial_encoded = global_encoded
                 partial_encoded = global_encoded