Skip to content

Commit 1f6efce

Browse files
authored
tune env command output (#3570)
Signed-off-by: Matrix Yao <[email protected]>
1 parent 9fa97f9 commit 1f6efce

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

src/accelerate/commands/env.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,20 @@ def env_command(args):
5353
pt_musa_available = is_musa_available()
5454
pt_npu_available = is_npu_available()
5555

56+
accelerator = "N/A"
57+
if pt_cuda_available:
58+
accelerator = "CUDA"
59+
elif pt_xpu_available:
60+
accelerator = "XPU"
61+
elif pt_mlu_available:
62+
accelerator = "MLU"
63+
elif pt_sdaa_available:
64+
accelerator = "SDAA"
65+
elif pt_musa_available:
66+
accelerator = "MUSA"
67+
elif pt_npu_available:
68+
accelerator = "NPU"
69+
5670
accelerate_config = "Not found"
5771
# Get the default from the config file.
5872
if args.config_file is not None or os.path.isfile(default_config_file):
@@ -73,23 +87,21 @@ def env_command(args):
7387
"`accelerate` bash location": bash_location,
7488
"Python version": platform.python_version(),
7589
"Numpy version": np.__version__,
76-
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
77-
"PyTorch XPU available": str(pt_xpu_available),
78-
"PyTorch NPU available": str(pt_npu_available),
79-
"PyTorch MLU available": str(pt_mlu_available),
80-
"PyTorch SDAA available": str(pt_sdaa_available),
81-
"PyTorch MUSA available": str(pt_musa_available),
90+
"PyTorch version": f"{pt_version}",
91+
"PyTorch accelerator": accelerator,
8292
"System RAM": f"{psutil.virtual_memory().total / 1024**3:.2f} GB",
8393
}
8494
if pt_cuda_available:
8595
info["GPU type"] = torch.cuda.get_device_name()
86-
if pt_mlu_available:
96+
elif pt_xpu_available:
97+
info["XPU type"] = torch.xpu.get_device_name()
98+
elif pt_mlu_available:
8799
info["MLU type"] = torch.mlu.get_device_name()
88-
if pt_sdaa_available:
100+
elif pt_sdaa_available:
89101
info["SDAA type"] = torch.sdaa.get_device_name()
90-
if pt_musa_available:
102+
elif pt_musa_available:
91103
info["MUSA type"] = torch.musa.get_device_name()
92-
if pt_npu_available:
104+
elif pt_npu_available:
93105
info["CANN version"] = torch.version.cann
94106

95107
print("\nCopy-and-paste the text below in your GitHub issue\n")

0 commit comments

Comments
 (0)