Print GPU device ID with MPI

这个提交包含在:
craig-warren
2023-12-10 00:10:41 +00:00
父节点 8235cb37b3
当前提交 82e80ab12c

查看文件

@@ -130,7 +130,9 @@ class MPIContext(Context):
model_config = config.ModelConfig() model_config = config.ModelConfig()
# Set GPU deviceID according to worker rank # Set GPU deviceID according to worker rank
if config.sim_config.general["solver"] == "cuda": if config.sim_config.general["solver"] == "cuda":
model_config.device = {"dev": config.sim_config.devices["devs"][self.rank - 1], "snapsgpu2cpu": False} model_config.device = {"dev": config.sim_config.devices["devs"][self.rank - 1],
"deviceID": self.rank - 1,
"snapsgpu2cpu": False}
config.model_configs = model_config config.model_configs = model_config
G = create_G() G = create_G()