diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py index b92409f40..7dc220f6a 100644 --- a/python/bitblas/utils/target_detector.py +++ b/python/bitblas/utils/target_detector.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +import os import subprocess from typing import List from thefuzz import process @@ -26,16 +26,24 @@ def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): try: # Execute nvidia-smi command to get the GPU name output = subprocess.check_output( - ["nvidia-smi", f"--id={gpu_id}", "--query-gpu=gpu_name", "--format=csv,noheader"], + ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"], encoding="utf-8", ).strip() except subprocess.CalledProcessError as e: logger.info("nvidia-smi failed with error: %s", e) return None - # Return the name of the first GPU if multiple are present - return output.split("\n")[0] + gpus = output.split("\n") + + # for multiple gpus, CUDA_DEVICE_ORDER=PCI_BUS_ID must be set to match nvidia-smi or else wrong + # gpu is returned for gpu_id + if len(gpus) > 0 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID": + raise EnvironmentError("Multi-gpu environment must set `CUDA_DEVICE_ORDER=PCI_BUS_ID`.") + + if gpu_id >= len(gpus) or gpu_id < 0: + raise ValueError(f"Passed gpu_id:{gpu_id} but there are {len(gpus)} detected Nvidia gpus.") + return gpus[gpu_id] def find_best_match(tags, query): """