diff --git a/gprMax/grid/fdtd_grid.py b/gprMax/grid/fdtd_grid.py index 090edc05..bb712eb1 100644 --- a/gprMax/grid/fdtd_grid.py +++ b/gprMax/grid/fdtd_grid.py @@ -138,6 +138,8 @@ class FDTDGrid: self.dl[2] = value def build(self) -> None: + """Build the grid.""" + # Set default CFS parameter for PMLs if not user provided if not self.pmls["cfs"]: self.pmls["cfs"] = [CFS()] @@ -158,6 +160,8 @@ class FDTDGrid: self._build_materials() def _build_pmls(self) -> None: + """Construct and calculate material properties of the PMLs.""" + pbar = tqdm( total=sum(1 for value in self.pmls["thickness"].values() if value > 0), desc=f"Building PML boundaries [{self.name}]", @@ -170,7 +174,8 @@ class FDTDGrid: pml = self._construct_pml(pml_id, thickness) averageer, averagemr = self._calculate_average_pml_material_properties(pml) logger.debug( - f"PML {pml.ID}: Average permittivity = {averageer}, Average permeability = {averagemr}" + f"PML {pml.ID}: Average permittivity = {averageer}, Average permeability =" + f" {averagemr}" ) pml.calculate_update_coeffs(averageer, averagemr) self.pmls["slabs"].append(pml) @@ -180,14 +185,15 @@ class FDTDGrid: PmlType = TypeVar("PmlType", bound=PML) def _construct_pml(self, pml_ID: str, thickness: int, pml_type: type[PmlType] = PML) -> PmlType: - """Builds instances of the PML and calculates the initial parameters and - coefficients including setting profile (based on underlying material - er and mr from solid array). + """Build PML instance of the specified ID, thickness and type. + + Constructs a PML of the specified type and thickness. Properties + of the PML are set based on the provided identifier. Args: - G: FDTDGrid class describing a grid in a model. - pml_ID: string identifier of PML slab. - thickness: int with thickness of PML slab in cells. + pml_ID: Identifier of PML slab. + thickness: Thickness of PML slab in cells. + pml_type: PML class to construct. """ if pml_ID == "x0": pml = pml_type( @@ -267,6 +273,15 @@ class FDTDGrid: return pml def _calculate_average_pml_material_properties(self, pml: PML) -> Tuple[float, float]: + """Calculate average material properties for the provided PML. + + Args: + pml: PML to calculate the properties of. + + Returns: + averageer, averagemr: Average permittivity and permeability + in the PML slab. + """ # Arrays to hold values of permittivity and permeability (avoids accessing # Material class in Cython.) ers = np.zeros(len(self.materials)) @@ -294,8 +309,11 @@ class FDTDGrid: return pml_average_er_mr(n1, n2, config.get_model_config().ompthreads, solid, ers, mrs) def _build_components(self) -> None: - # Build the model, i.e. set the material properties (ID) for every edge - # of every Yee cell + """Build electric and magnetic components of the grid. + + Set the material properties (stored in the ID array) for every + edge of every Yee cell. + """ pbar = tqdm( total=2, desc=f"Building Yee cells [{self.name}]", @@ -310,6 +328,7 @@ class FDTDGrid: pbar.close() def _tm_grid_update(self) -> None: + """Add PEC boundaries to invariant if in 2D mode.""" if config.get_model_config().mode == "2D TMx": self.tmx() elif config.get_model_config().mode == "2D TMy": @@ -318,14 +337,21 @@ class FDTDGrid: self.tmz() def _create_voltage_source_materials(self): + """Create materials for voltage sources. + + Process any voltage sources (that have resistance) to create a + new material at the source location. + """ # Process any voltage sources (that have resistance) to create a new # material at the source location for voltagesource in self.voltagesources: voltagesource.create_material(self) def _build_materials(self) -> None: - # Process complete list of materials - calculate update coefficients, - # store in arrays, and build text list of materials/properties + """Calculate properties of materials in the grid. + + Log a summary of the material properties. + """ materialsdata = process_materials(self) # materialstable = SingleTable(materialsdata) materialstable = AsciiTable(materialsdata) @@ -338,6 +364,17 @@ class FDTDGrid: def _update_positions( self, items: Iterable[Union[Source, Rx]], step_size: List[int], step_number: int ) -> None: + """Update the grid positions of the provided items. + + Args: + items: Sources and receivers to update. + step_size: Number of grid cells to move the items each step. + step_number: Number of steps to move the items by. + + Raises: + ValueError: Raised if any of the items would be stepped + outside of the grid. + """ if step_size[0] != 0 or step_size[1] != 0 or step_size[2] != 0: for item in items: if step_number == 0: @@ -355,6 +392,20 @@ class FDTDGrid: item.zcoord = item.zcoordorigin + step_number * step_size[2] def update_simple_source_positions(self, step_size: List[int], step: int = 0) -> None: + """Update the positions of sources in the grid. + + Move hertzian dipole and magnetic dipole sources. Transmission + line sources and voltage sources will not be moved. + + Args: + step_size: Number of grid cells to move the sources each + step. + step: Number of steps to move the sources by. + + Raises: + ValueError: Raised if any of the sources would be stepped + outside of the grid. + """ try: self._update_positions( itertools.chain(self.hertziandipoles, self.magneticdipoles), step_size, step @@ -364,13 +415,35 @@ class FDTDGrid: raise ValueError from e def update_receiver_positions(self, step_size: List[int], step: int = 0) -> None: + """Update the positions of receivers in the grid. + + Args: + step_size: Number of grid cells to move the receivers each + step. + step: Number of steps to move the receivers by. + + Raises: + ValueError: Raised if any of the receivers would be stepped + outside of the grid. + """ try: self._update_positions(self.rxs, step_size, step) except ValueError as e: logger.exception("Receiver(s) will be stepped to a position outside the domain.") raise ValueError from e - def within_bounds(self, p): + IntPoint = Tuple[int, int, int] + FloatPoint = Tuple[float, float, float] + + def within_bounds(self, p: IntPoint): + """Check a point is within the grid. + + Args: + p: Point to check. + + Raises: + ValueError: Raised if the point is outside the grid. + """ if p[0] < 0 or p[0] > self.nx: raise ValueError("x") if p[1] < 0 or p[1] > self.ny: @@ -378,39 +451,71 @@ class FDTDGrid: if p[2] < 0 or p[2] > self.nz: raise ValueError("z") - def discretise_point(self, p): + def discretise_point(self, p: FloatPoint) -> IntPoint: + """Calculate the nearest grid cell to the given point. + + Args: + p: Point to discretise. + + Returns: + x, y, z: Discretised point. + """ x = round_value(float(p[0]) / self.dx) y = round_value(float(p[1]) / self.dy) z = round_value(float(p[2]) / self.dz) return (x, y, z) - def round_to_grid(self, p): + def round_to_grid(self, p: FloatPoint) -> FloatPoint: + """Round the provided point to the nearest grid cell. + + Args: + p: Point to round. + + Returns: + p_r: Rounded point. + """ p = self.discretise_point(p) p_r = (p[0] * self.dx, p[1] * self.dy, p[2] * self.dz) return p_r - def within_pml(self, p): - if ( + def within_pml(self, p: IntPoint) -> bool: + """Check if the provided point is within a PML. + + Args: + p: Point to check. + + Returns: + within_pml: True if the point is within a PML. + """ + return ( p[0] < self.pmls["thickness"]["x0"] or p[0] > self.nx - self.pmls["thickness"]["xmax"] or p[1] < self.pmls["thickness"]["y0"] or p[1] > self.ny - self.pmls["thickness"]["ymax"] or p[2] < self.pmls["thickness"]["z0"] or p[2] > self.nz - self.pmls["thickness"]["zmax"] - ): - return True - else: - return False + ) def get_waveform_by_id(self, waveform_id: str) -> Waveform: + """Get waveform with the specified ID. + + Args: + waveform_id: ID of the waveform. + + Returns: + waveform: Requested waveform + """ return next(waveform for waveform in self.waveforms if waveform.ID == waveform_id) def initialise_geometry_arrays(self): - """Initialise an array for volumetric material IDs (solid); - boolean arrays for specifying whether materials can have dielectric - smoothing (rigid); and an array for cell edge IDs (ID). - Solid and ID arrays are initialised to free_space (one); - rigid arrays to allow dielectric smoothing (zero). + """Initialise arrays to store geometry properties. + + Initialise an array for volumetric material IDs (solid); boolean + arrays for specifying whether materials can have dielectric + smoothing (rigid); and an array for cell edge IDs (ID). + + Solid and ID arrays are initialised to free_space (one); rigid + arrays to allow dielectric smoothing (zero). """ self.solid = np.ones((self.nx, self.ny, self.nz), dtype=np.uint32) self.rigidE = np.zeros((12, self.nx, self.ny, self.nz), dtype=np.int8) @@ -566,8 +671,10 @@ class FDTDGrid: return mem_use def mem_est_fractals(self): - """Estimates the amount of memory (RAM) required to build any objects - which use the FractalVolume/FractalSurface classes. + """Calculate the memory required to build fractal objects. + + Estimates the amount of memory (RAM) required to build any + objects which use the FractalVolume/FractalSurface classes. Returns: mem_use: int of memory (bytes). @@ -693,20 +800,23 @@ class FDTDGrid: return Iz def dispersion_analysis(self, iterations: int): - # Check to see if numerical dispersion might be a problem + """Check to see if numerical dispersion might be a problem. + + Raises: + ValueError: Raised if a problem is encountered. + """ results = self._dispersion_analysis(iterations) if results["error"]: logger.warning( - f"Numerical dispersion analysis [{self.name}] " - f"not carried out as {results['error']}" + f"Numerical dispersion analysis [{self.name}] not carried out as {results['error']}" ) elif results["N"] < config.get_model_config().numdispersion["mingridsampling"]: logger.exception( f"\nNon-physical wave propagation in [{self.name}] " f"detected. Material '{results['material'].ID}' " f"has wavelength sampled by {results['N']} cells, " - f"less than required minimum for physical wave " - f"propagation. Maximum significant frequency " + "less than required minimum for physical wave " + "propagation. Maximum significant frequency " f"estimated as {results['maxfreq']:g}Hz" ) raise ValueError @@ -717,29 +827,31 @@ class FDTDGrid: ): logger.warning( f"[{self.name}] has potentially significant " - f"numerical dispersion. Estimated largest physical " + "numerical dispersion. Estimated largest physical " f"phase-velocity error is {results['deltavp']:.2f}% " f"in material '{results['material'].ID}' whose " f"wavelength sampled by {results['N']} cells. " - f"Maximum significant frequency estimated as " + "Maximum significant frequency estimated as " f"{results['maxfreq']:g}Hz\n" ) elif results["deltavp"]: logger.info( f"Numerical dispersion analysis [{self.name}]: " - f"estimated largest physical phase-velocity error is " + "estimated largest physical phase-velocity error is " f"{results['deltavp']:.2f}% in material '{results['material'].ID}' " f"whose wavelength sampled by {results['N']} cells. " - f"Maximum significant frequency estimated as " + "Maximum significant frequency estimated as " f"{results['maxfreq']:g}Hz\n" ) - def _dispersion_analysis(self, iterations: int): - """Analysis of numerical dispersion (Taflove et al, 2005, p112) - - worse case of maximum frequency and minimum wavelength + def _dispersion_analysis(self, iterations: int) -> dict[str, Any]: + """Run dispersion analysis. + + Analysis of numerical dispersion (Taflove et al, 2005, p112) - + worse case of maximum frequency and minimum wavelength. Args: - G: FDTDGrid class describing a grid in a model. + iterations: Number of iterations the model will run for. Returns: results: dict of results from dispersion analysis. @@ -768,8 +880,9 @@ class FDTDGrid: # Time to analyse waveform - 4*pulse_width as using entire # time window can result in demanding FFT waveform.calculate_coefficients() - iterations = round_value(4 * waveform.chi / self.dt) - iterations = min(iterations, iterations) + # TODO: Check max_iterations should be calculated (original code didn't go on to use it) + max_iterations = round_value(4 * waveform.chi / self.dt) + iterations = min(iterations, max_iterations) waveformvalues = np.zeros(iterations) for iteration in range(iterations): waveformvalues[iteration] = waveform.calculate_value( diff --git a/gprMax/grid/mpi_grid.py b/gprMax/grid/mpi_grid.py index 95604684..f9f58f47 100644 --- a/gprMax/grid/mpi_grid.py +++ b/gprMax/grid/mpi_grid.py @@ -24,7 +24,7 @@ from typing import List, Optional, Tuple, TypeVar, Union import numpy as np import numpy.typing as npt from mpi4py import MPI -from numpy import ndarray +from numpy import empty, ndarray from gprMax import config from gprMax.cython.pml_build import pml_sum_er_mr @@ -115,11 +115,31 @@ class MPIGrid(FDTDGrid): self.size[Dim.Z] = value def is_coordinator(self) -> bool: + """Test if the current rank is the coordinator. + + Returns: + is_coordinator: True if `self.rank` equals + `self.COORDINATOR_RANK`. + """ return self.rank == self.COORDINATOR_RANK - def get_grid_coord_from_coordinate(self, coord: npt.NDArray) -> npt.NDArray[np.intc]: + def get_grid_coord_from_coordinate(self, coord: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]: + """Get the MPI grid coordinate for a global grid coordinate. + + Args: + coord: Global grid coordinate. + + Returns: + grid_coord: Coordinate of the MPI rank containing the global + grid coordinate. + """ step_size = self.global_size // self.mpi_tasks overflow = self.global_size % self.mpi_tasks + + # The first n MPI ranks where n is the overflow, will have size + # step_size + 1. Additionally, step_size may be zero in some + # dimensions (e.g. in the 2D case) so we need to avoid division + # by zero. return np.where( (step_size + 1) * overflow >= coord, coord // (step_size + 1), @@ -127,12 +147,36 @@ class MPIGrid(FDTDGrid): ) def get_rank_from_coordinate(self, coord: npt.NDArray) -> int: + """Get the MPI rank for a global grid coordinate. + + A coordinate only exists on a single rank (halos are ignored). + + Args: + coord: Global grid coordinate. + + Returns: + rank: MPI rank containing the global grid coordinate. + """ grid_coord = self.get_grid_coord_from_coordinate(coord) return self.comm.Get_cart_rank(grid_coord.tolist()) def get_ranks_between_coordinates( self, start_coord: npt.NDArray, stop_coord: npt.NDArray ) -> List[int]: + """Get the MPI ranks for between two global grid coordinates. + + `stop_coord` must not be less than `start_coord` in any + dimension, however it can be equal. The returned ranks will + contain coordinates inclusive of both `start_coord` and + `stop_coord`. + + Args: + start_coord: Starting global grid coordinate. + stop_coord: End global grid coordinate. + + Returns: + ranks: List of MPI ranks + """ start = self.get_grid_coord_from_coordinate(start_coord) stop = self.get_grid_coord_from_coordinate(stop_coord) + 1 coord_to_rank = lambda c: self.comm.Get_cart_rank((start + c).tolist()) @@ -141,12 +185,45 @@ class MPIGrid(FDTDGrid): def global_to_local_coordinate( self, global_coord: npt.NDArray[np.intc] ) -> npt.NDArray[np.intc]: + """Convert a global grid coordinate to a local grid coordinate. + + The returned coordinate will be relative to the current MPI + rank's local grid. It may be negative, or greater than the size + of the local grid if the point lies outside the local grid. + + Args: + global_coord: Global grid coordinate. + + Returns: + local_coord: Local grid coordinate + """ return global_coord - self.lower_extent def local_to_global_coordinate(self, local_coord: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]: + """Convert a local grid coordinate to a global grid coordinate. + + Args: + local_coord: Local grid coordinate + + Returns: + global_coord: Global grid coordinate + """ return local_coord + self.lower_extent def scatter_coord_objects(self, objects: List[CoordType]) -> List[CoordType]: + """Scatter coord objects to the correct MPI rank. + + Coord objects (sources and receivers) are scattered to the MPI + rank based on their location in the grid. The receiving MPI rank + converts the object locations to its own local grid. + + Args: + objects: Coord objects to be scattered. + + Returns: + scattered_objects: List of Coord objects belonging to the + current MPI rank. + """ if self.is_coordinator(): objects_by_rank: List[List[CoordType]] = [[] for _ in range(self.comm.size)] for o in objects: @@ -162,6 +239,20 @@ class MPIGrid(FDTDGrid): return objects def gather_coord_objects(self, objects: List[CoordType]) -> List[CoordType]: + """Scatter coord objects to the correct MPI rank. + + The sending MPI rank converts the object locations to the global + grid. The coord objects (sources and receivers) are all sent to + the coordinatoor rank. + + Args: + objects: Coord objects to be gathered. + + Returns: + gathered_objects: List of gathered coord objects if the + current rank is the coordinator. Otherwise, the original + list of objects is returned. + """ for o in objects: o.coord = self.local_to_global_coordinate(o.coord) gathered_objects: Optional[List[List[CoordType]]] = self.comm.gather( @@ -174,6 +265,13 @@ class MPIGrid(FDTDGrid): return objects def scatter_snapshots(self): + """Scatter snapshots to the correct MPI rank. + + Each snapshot is sent by the coordinator to the MPI ranks + containing the snapshot. A new communicator is created for each + snapshot, and each rank bounds the snapshot to within its own + local grid. + """ if self.is_coordinator(): snapshots_by_rank: List[List[Optional[Snapshot]]] = [[] for _ in range(self.comm.size)] for s in self.snapshots: @@ -184,6 +282,10 @@ class MPIGrid(FDTDGrid): if rank in ranks: snapshots_by_rank[rank].append(s) else: + # All ranks need the same number of 'snapshots' + # (which may be None) to ensure snapshot + # communicators are setup correctly and to avoid + # deadlock. snapshots_by_rank[rank].append(None) else: snapshots_by_rank = None @@ -218,6 +320,20 @@ class MPIGrid(FDTDGrid): self.snapshots = [s for s in snapshots if s is not None] def scatter_3d_array(self, array: npt.NDArray) -> npt.NDArray: + """Scatter a 3D array to each MPI rank + + Use to distribute a 3D array across MPI ranks. Each rank will + receive its own segment of the array including a negative halo, + but NOT a positive halo. + + Args: + array: Array to be scattered + + Returns: + scattered_array: Local extent of the array for the current + MPI rank. + """ + # TODO: Use Scatter instead of Bcast self.comm.Bcast(array, root=self.COORDINATOR_RANK) return array[ @@ -227,6 +343,21 @@ class MPIGrid(FDTDGrid): ].copy(order="C") def scatter_4d_array(self, array: npt.NDArray) -> npt.NDArray: + """Scatter a 4D array to each MPI rank + + Use to distribute a 4D array across MPI ranks. The first + dimension is ignored when partitioning the array. Each rank will + receive its own segment of the array including a negative halo, + but NOT a positive halo. + + Args: + array: Array to be scattered + + Returns: + scattered_array: Local extent of the array for the current + MPI rank. + """ + # TODO: Use Scatter instead of Bcast self.comm.Bcast(array, root=self.COORDINATOR_RANK) return array[ @@ -237,6 +368,21 @@ class MPIGrid(FDTDGrid): ].copy(order="C") def scatter_4d_array_with_positive_halo(self, array: npt.NDArray) -> npt.NDArray: + """Scatter a 4D array to each MPI rank + + Use to distribute a 4D array across MPI ranks. The first + dimension is ignored when partitioning the array. Each rank will + receive its own segment of the array including both a negative + and positive halo. + + Args: + array: Array to be scattered + + Returns: + scattered_array: Local extent of the array for the current + MPI rank. + """ + # TODO: Use Scatter instead of Bcast self.comm.Bcast(array, root=self.COORDINATOR_RANK) return array[ @@ -246,7 +392,12 @@ class MPIGrid(FDTDGrid): self.lower_extent[Dim.Z] : self.upper_extent[Dim.Z] + 1, ].copy(order="C") - def scatter_grid(self): + def distribute_grid(self): + """Distribute grid properties and objects to all MPI ranks. + + Global properties/objects are broadcast to all ranks whereas + local properties/objects are scattered to the relevant ranks. + """ self.materials = self.comm.bcast(self.materials, root=self.COORDINATOR_RANK) self.rxs = self.scatter_coord_objects(self.rxs) self.voltagesources = self.scatter_coord_objects(self.voltagesources) @@ -280,6 +431,8 @@ class MPIGrid(FDTDGrid): self.rigidH = self.scatter_4d_array(self.rigidH) def gather_grid_objects(self): + """Gather sources and receivers.""" + self.rxs = self.gather_coord_objects(self.rxs) self.voltagesources = self.gather_coord_objects(self.voltagesources) self.magneticdipoles = self.gather_coord_objects(self.magneticdipoles) @@ -287,6 +440,7 @@ class MPIGrid(FDTDGrid): self.transmissionlines = self.gather_coord_objects(self.transmissionlines) def initialise_geometry_arrays(self, use_local_size=False): + # TODO: Remove this when scatter geometry arrays rather than broadcast if use_local_size: super().initialise_geometry_arrays() else: @@ -296,6 +450,16 @@ class MPIGrid(FDTDGrid): self.ID = np.ones((6, *(self.global_size + 1)), dtype=np.uint32) def _halo_swap(self, array: ndarray, dim: Dim, dir: Dir): + """Perform a halo swap in the specifed dimension and direction. + + If no neighbour exists for the current rank in the specifed + dimension and direction, the halo swap is skipped. + + Args: + array: Array to perform the halo swap with. + dim: Dimension of halo to swap. + dir: Direction of halo to swap. + """ neighbour = self.neighbours[dim][dir] if neighbour != -1: self.comm.Sendrecv( @@ -309,6 +473,16 @@ class MPIGrid(FDTDGrid): ) def _halo_swap_by_dimension(self, array: ndarray, dim: Dim): + """Perform halo swaps in the specifed dimension. + + Perform a halo swaps in the positive and negative direction for + the specified dimension. The order of the swaps is determined by + the current rank's MPI grid coordinate to prevent deadlock. + + Args: + array: Array to perform the halo swaps with. + dim: Dimension of halos to swap. + """ if self.coords[dim] % 2 == 0: self._halo_swap(array, dim, Dir.NEG) self._halo_swap(array, dim, Dir.POS) @@ -317,21 +491,36 @@ class MPIGrid(FDTDGrid): self._halo_swap(array, dim, Dir.NEG) def _halo_swap_array(self, array: ndarray): + """Perform halo swaps for the specified array. + + Args: + array: Array to perform the halo swaps with. + """ self._halo_swap_by_dimension(array, Dim.X) self._halo_swap_by_dimension(array, Dim.Y) self._halo_swap_by_dimension(array, Dim.Z) def halo_swap_electric(self): + """Perform halo swaps for electric field arrays.""" + self._halo_swap_array(self.Ex) self._halo_swap_array(self.Ey) self._halo_swap_array(self.Ez) def halo_swap_magnetic(self): + """Perform halo swaps for magnetic field arrays.""" + self._halo_swap_array(self.Hx) self._halo_swap_array(self.Hy) self._halo_swap_array(self.Hz) def _construct_pml(self, pml_ID: str, thickness: int) -> MPIPML: + """Build instance of MPIPML and set the MPI communicator. + + Args: + pml_ID: Identifier of PML slab. + thickness: Thickness of PML slab in cells. + """ pml = super()._construct_pml(pml_ID, thickness, MPIPML) if pml.ID[0] == "x": pml.comm = self.x_comm @@ -344,6 +533,15 @@ class MPIGrid(FDTDGrid): return pml def _calculate_average_pml_material_properties(self, pml: MPIPML) -> Tuple[float, float]: + """Calculate average material properties for the provided PML. + + Args: + pml: PML to calculate the properties of. + + Returns: + averageer, averagemr: Average permittivity and permeability + in the PML slab. + """ # Arrays to hold values of permittivity and permeability (avoids # accessing Material class in Cython.) ers = np.zeros(len(self.materials)) @@ -387,15 +585,19 @@ class MPIGrid(FDTDGrid): return averageer, averagemr def build(self): + """Set local properties and objects, then build the grid.""" + if any(self.global_size + 1 < self.mpi_tasks): logger.error( - f"Too many MPI tasks requested ({self.mpi_tasks}) for grid of size {self.global_size + 1}. Make sure the number of MPI tasks in each dimension is less than the size of the grid." + f"Too many MPI tasks requested ({self.mpi_tasks}) for grid of size" + f" {self.global_size + 1}. Make sure the number of MPI tasks in each dimension is" + " less than the size of the grid." ) raise ValueError self.calculate_local_extents() self.set_halo_map() - self.scatter_grid() + self.distribute_grid() # TODO: Check PML is not thicker than the grid size @@ -414,9 +616,20 @@ class MPIGrid(FDTDGrid): super().build() def has_neighbour(self, dim: Dim, dir: Dir) -> bool: + """Test if the current rank has a specified neighbour. + + Args: + dim: Dimension of neighbour. + dir: Direction of neighbour. + Returns: + has_neighbour: True if the current rank has a neighbour in + the specified dimension and direction. + """ return self.neighbours[dim][dir] != -1 def set_halo_map(self): + """Create MPI DataTypes for field array halo exchanges.""" + size = (self.size + 1).tolist() for dim in Dim: @@ -443,6 +656,8 @@ class MPIGrid(FDTDGrid): self.recv_halo_map[dim][Dir.POS].Commit() def calculate_local_extents(self): + """Calculate size and extents of the local grid""" + self.size = self.global_size // self.mpi_tasks overflow = self.global_size % self.mpi_tasks @@ -465,5 +680,6 @@ class MPIGrid(FDTDGrid): self.upper_extent = self.lower_extent + self.size logger.debug( - f"Grid size: {self.size}, Lower extent: {self.lower_extent}, Upper extent: {self.upper_extent}" + f"Local grid size: {self.size}, Lower extent: {self.lower_extent}, Upper extent:" + f" {self.upper_extent}" )