Begin creating MPI context

这个提交包含在:
nmannall
2024-03-04 17:44:12 +00:00
父节点 b706580e99
当前提交 b4d5e11dc9
共有 2 个文件被更改,包括 33 次插入10 次删除

查看文件

@@ -38,15 +38,17 @@ logger = logging.getLogger(__name__)
class Context:
"""Standard context - models are run one after another and each model
can exploit parallelisation using either OpenMP (CPU), CUDA (GPU), or
OpenCL (CPU/GPU).
"""Standard context for building and running models.
Models are run one after another and each model can exploit
parallelisation using either OpenMP (CPU), CUDA (GPU), or OpenCL
(CPU/GPU).
"""
def __init__(self):
self.model_range = range(config.sim_config.model_start, config.sim_config.model_end)
self.tsimend = 0
self.tsimstart = 0
self.sim_start_time = 0
self.sim_end_time = 0
def run(self):
"""Run the simulation in the correct context.
@@ -55,7 +57,7 @@ class Context:
results: dict that can contain useful results/data from simulation.
"""
self.tsimstart = timer()
self.sim_start_time = timer()
self.print_logo_copyright()
print_host_info(config.sim_config.hostinfo)
if config.sim_config.general["solver"] == "cuda":
@@ -94,7 +96,7 @@ class Context:
gc.collect()
self.tsimend = timer()
self.sim_end_time = timer()
self.print_sim_time_taken()
return {}
@@ -108,11 +110,27 @@ class Context:
"""Prints the total simulation time based on context."""
s = (
f"\n=== Simulation completed in "
f"{humanize.precisedelta(datetime.timedelta(seconds=self.tsimend - self.tsimstart), format='%0.4f')}"
f"{humanize.precisedelta(datetime.timedelta(seconds=self.sim_end_time - self.sim_start_time), format='%0.4f')}"
)
logger.basic(f"{s} {'=' * (get_terminal_width() - 1 - len(s))}\n")
class MPIContext(Context):
def __init__(self):
super().__init__()
from mpi4py import MPI
self.comm = MPI.COMM_WORLD
self.rank = self.comm.rank
def run(self):
if self.rank == 0:
super().run()
else:
grid = create_G()
solver = create_solver(grid)
class TaskfarmContext(Context):
"""Mixed mode MPI/OpenMP/CUDA context - MPI task farm is used to distribute
models, and each model parallelised using either OpenMP (CPU),
@@ -180,7 +198,9 @@ class TaskfarmContext(Context):
print_opencl_info(config.sim_config.devices["devs"])
s = f"\n--- Input file: {config.sim_config.input_file_path}"
logger.basic(Fore.GREEN + f"{s} {'-' * (get_terminal_width() - 1 - len(s))}\n" + Style.RESET_ALL)
logger.basic(
Fore.GREEN + f"{s} {'-' * (get_terminal_width() - 1 - len(s))}\n" + Style.RESET_ALL
)
sys.stdout.flush()

查看文件

@@ -20,7 +20,7 @@ import argparse
import gprMax.config as config
from .contexts import Context, TaskfarmContext
from .contexts import Context, MPIContext, TaskfarmContext
from .utilities.logging import logging_config
# Arguments (used for API) and their default values (used for API and CLI)
@@ -260,6 +260,9 @@ def run_main(args):
# MPI taskfarm running with (OpenMP/CUDA/OpenCL)
if config.sim_config.args.taskfarm:
context = TaskfarmContext()
# MPI running to divide model between ranks
elif config.sim_config.args.mpi:
context = MPIContext()
# Standard running (OpenMP/CUDA/OpenCL)
else:
context = Context()