# gpu_stats.py

  1. # cond import
  2. try:
  3. from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit
  4. def print_gpu_utilization():
  5. nvmlInit()
  6. handle = nvmlDeviceGetHandleByIndex(0)
  7. info = nvmlDeviceGetMemoryInfo(handle)
  8. print(f"GPU memory occupied: {info.used // 1024**2} MB.")
  9. def print_summary(result):
  10. print(f"Time: {result.metrics['train_runtime']:.2f}")
  11. print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
  12. print_gpu_utilization()
  13. except ImportError:
  14. print("pynvml not found. GPU stats will not be printed.")
  15. def print_summary(result):
  16. print(f"Time: {result.metrics['train_runtime']:.2f}")
  17. print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
  18. def print_gpu_utilization():
  19. pass