|
|
@@ -609,11 +609,13 @@ def launch_thread_safe_queue(
|
|
|
compile=False,
|
|
|
):
|
|
|
input_queue = queue.Queue()
|
|
|
+ init_event = threading.Event()
|
|
|
|
|
|
def worker():
|
|
|
model, decode_one_token = load_model(
|
|
|
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
|
|
)
|
|
|
+ init_event.set()
|
|
|
|
|
|
while True:
|
|
|
item = input_queue.get()
|
|
|
@@ -637,6 +639,7 @@ def launch_thread_safe_queue(
|
|
|
event.set()
|
|
|
|
|
|
threading.Thread(target=worker, daemon=True).start()
|
|
|
+ init_event.wait()
|
|
|
|
|
|
return input_queue
|
|
|
|