| 123456789101112131415161718192021222324 |
# Conditional import: GPU memory reporting is only available when pynvml is
# installed. Without it, print_gpu_utilization() degrades to a no-op so that
# print_summary() still works.
try:
    from pynvml import (
        NVMLError,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
        nvmlShutdown,
    )
except ImportError:
    print("pynvml not found. GPU stats will not be printed.")

    def print_gpu_utilization(device_index=0):
        """No-op fallback used when pynvml is not installed.

        Accepts the same signature as the pynvml-backed version so callers
        need not care which one is active.
        """
else:

    def print_gpu_utilization(device_index=0):
        """Print the memory currently used on one GPU, as reported by NVML.

        Args:
            device_index: NVML index of the GPU to query (default 0, the
                device the original implementation always queried).

        The value is ``used`` bytes floor-divided by 1024**2 (i.e. MiB), even
        though the message says "MB". NVML failures (e.g. pynvml installed
        but no usable driver/GPU, or an out-of-range index) are reported
        instead of raised, because this is best-effort diagnostics output.
        """
        try:
            nvmlInit()
        except NVMLError as err:
            print(f"GPU stats unavailable: {err}")
            return
        try:
            handle = nvmlDeviceGetHandleByIndex(device_index)
            used_bytes = nvmlDeviceGetMemoryInfo(handle).used
        except NVMLError as err:
            print(f"GPU stats unavailable: {err}")
            return
        finally:
            # NVML reference-counts nvmlInit(); balance every successful
            # init with a shutdown so repeated calls do not leak.
            nvmlShutdown()
        print(f"GPU memory occupied: {used_bytes // 1024**2} MB.")


def print_summary(result):
    """Print training runtime and throughput, then current GPU memory use.

    Args:
        result: Object exposing a ``metrics`` mapping that contains the keys
            ``"train_runtime"`` and ``"train_samples_per_second"`` (presumably
            the return value of a Hugging Face ``Trainer.train()`` call --
            confirm against callers).

    The GPU line is printed only when pynvml is available; otherwise
    print_gpu_utilization() is a no-op.
    """
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
|