diff --git a/gprMax/grid/mpi_grid.py b/gprMax/grid/mpi_grid.py index 52de80a7..146aba93 100644 --- a/gprMax/grid/mpi_grid.py +++ b/gprMax/grid/mpi_grid.py @@ -24,16 +24,13 @@ from gprMax.grid.fdtd_grid import FDTDGrid class MPIGrid(FDTDGrid): - xmin: int - ymin: int - zmin: int - xmax: int - ymax: int - zmax: int - - comm: MPI.Intracomm - - def __init__(self, mpi_tasks_x: int, mpi_tasks_y: int, mpi_tasks_z: int, comm: Optional[MPI.Intracomm] = None): + def __init__( + self, + mpi_tasks_x: int, + mpi_tasks_y: int, + mpi_tasks_z: int, + comm: Optional[MPI.Intracomm] = None, + ): super().__init__() if comm is None: @@ -52,16 +49,25 @@ class MPIGrid(FDTDGrid): self.rank = self.comm.rank self.size = self.comm.size + self.xmin = 0 + self.ymin = 0 + self.zmin = 0 + self.xmax = 0 + self.ymax = 0 + self.zmax = 0 + def initialise_field_arrays(self): super().initialise_field_arrays() - self.local_grid_size_x = self.nx // self.mpi_tasks_x - self.local_grid_size_y = self.ny // self.mpi_tasks_y - self.local_grid_size_z = self.nz // self.mpi_tasks_z + local_grid_size_x = self.nx // self.mpi_tasks_x + local_grid_size_y = self.ny // self.mpi_tasks_y + local_grid_size_z = self.nz // self.mpi_tasks_z - self.xmin = (self.rank % self.nx) * self.local_grid_size_x - self.ymin = ((self.mpi_tasks_x * self.rank) % self.ny) * self.local_grid_size_y - self.zmin = ((self.mpi_tasks_y * self.mpi_tasks_x * self.rank) % self.nz) * self.local_grid_size_z - self.xmax = self.xmin + self.local_grid_size_x - self.ymax = self.ymin + self.local_grid_size_y - self.zmax = self.zmin + self.local_grid_size_z + self.xmin = (self.rank % self.nx) * local_grid_size_x + self.ymin = ((self.mpi_tasks_x * self.rank) % self.ny) * local_grid_size_y + self.zmin = ( + (self.mpi_tasks_y * self.mpi_tasks_x * self.rank) % self.nz + ) * local_grid_size_z + self.xmax = self.xmin + local_grid_size_x + self.ymax = self.ymin + local_grid_size_y + self.zmax = self.zmin + local_grid_size_z