Updated method for detecting available GPUs.

这个提交包含在:
Craig Warren
2019-03-11 14:30:05 +00:00
父节点 774ee78ce3
当前提交 4c9999505e

查看文件

@@ -370,7 +370,7 @@ def detect_check_gpus(deviceIDs):
"""Get information about Nvidia GPU(s). """Get information about Nvidia GPU(s).
Args: Args:
deviceIDs (list): List of device IDs. deviceIDs (list): List of integers of device IDs.
Returns: Returns:
gpus (list): Detected GPU(s) object(s). gpus (list): Detected GPU(s) object(s).
@@ -387,8 +387,11 @@ def detect_check_gpus(deviceIDs):
raise GeneralError('No NVIDIA CUDA-Enabled GPUs detected (https://developer.nvidia.com/cuda-gpus)') raise GeneralError('No NVIDIA CUDA-Enabled GPUs detected (https://developer.nvidia.com/cuda-gpus)')
# Get list of available GPU device IDs # Get list of available GPU device IDs
deviceIDsavail = os.environ.get('CUDA_VISIBLE_DEVICES') if 'CUDA_VISIBLE_DEVICES' in os.environ:
deviceIDsavail = [int(s) for s in deviceIDsavail.split(',')] deviceIDsavail = os.environ.get('CUDA_VISIBLE_DEVICES')
deviceIDsavail = [int(s) for s in deviceIDsavail.split(',')]
else:
deviceIDsavail = range(drv.Device.count())
# Print information about all detected GPUs # Print information about all detected GPUs
gpus = [] gpus = []