|
|
@@ -510,6 +510,11 @@ if __name__ == "__main__":
|
|
|
args = parse_args()
|
|
|
args.precision = torch.half if args.half else torch.bfloat16
|
|
|
|
|
|
+ # Fall back to CPU when CUDA is not available
|
|
|
+ if not torch.cuda.is_available():
|
|
|
+ logger.info("CUDA is not available, running on CPU.")
|
|
|
+ args.device = "cpu"
|
|
|
+
|
|
|
logger.info("Loading Llama model...")
|
|
|
llama_queue = launch_thread_safe_queue(
|
|
|
checkpoint_path=args.llama_checkpoint_path,
|