@@ -774,11 +774,13 @@ def launch_thread_safe_queue(
     device,
     precision,
     compile: bool = False,
+    num_workers: int = 1,
 ):
     input_queue = queue.Queue()
-    init_event = threading.Event()
+    init_events = [threading.Event() for _ in range(num_workers)]

-    def worker():
+    def worker(worker_id, init_event):
+        logger.info(f"Worker {worker_id} starting, loading model...")
         model, decode_one_token = init_model(
             checkpoint_path, device, precision, compile=compile
         )
@@ -788,6 +790,7 @@ def launch_thread_safe_queue(
             max_seq_len=model.config.max_seq_len,
             dtype=next(model.parameters()).dtype,
         )
+        logger.info(f"Worker {worker_id} initialized")
         init_event.set()

         while True:
@@ -817,9 +820,13 @@ def launch_thread_safe_queue(
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()

-    threading.Thread(target=worker, daemon=True).start()
-    init_event.wait()
+    for i in range(num_workers):
+        threading.Thread(target=worker, args=(i, init_events[i]), daemon=True).start()
+
+    for event in init_events:
+        event.wait()
+
+    logger.info(f"All {num_workers} workers initialized successfully")
     return input_queue