Lengyue 1 год назад
Родитель
Сommit
8a6d0d7ef7
1 измененных файлов с 3 добавлено и 0 удалено
  1. 3 0
      tools/llama/generate.py

+ 3 - 0
tools/llama/generate.py

@@ -609,11 +609,13 @@ def launch_thread_safe_queue(
     compile=False,
 ):
     input_queue = queue.Queue()
+    init_event = threading.Event()
 
     def worker():
         model, decode_one_token = load_model(
             config_name, checkpoint_path, device, precision, max_length, compile=compile
         )
+        init_event.set()
 
         while True:
             item = input_queue.get()
@@ -637,6 +639,7 @@ def launch_thread_safe_queue(
             event.set()
 
     threading.Thread(target=worker, daemon=True).start()
+    init_event.wait()
 
     return input_queue