Add docstrings to all MPIGrid and FDTDGrid methods

这个提交包含在:
nmannall
2024-08-16 14:47:45 +01:00
父节点 89873a5963
当前提交 c288e8eb7c
共有 2 个文件被更改,包括 377 次插入48 次删除

查看文件

@@ -138,6 +138,8 @@ class FDTDGrid:
self.dl[2] = value
def build(self) -> None:
"""Build the grid."""
# Set default CFS parameter for PMLs if not user provided
if not self.pmls["cfs"]:
self.pmls["cfs"] = [CFS()]
@@ -158,6 +160,8 @@ class FDTDGrid:
self._build_materials()
def _build_pmls(self) -> None:
"""Construct and calculate material properties of the PMLs."""
pbar = tqdm(
total=sum(1 for value in self.pmls["thickness"].values() if value > 0),
desc=f"Building PML boundaries [{self.name}]",
@@ -170,7 +174,8 @@ class FDTDGrid:
pml = self._construct_pml(pml_id, thickness)
averageer, averagemr = self._calculate_average_pml_material_properties(pml)
logger.debug(
f"PML {pml.ID}: Average permittivity = {averageer}, Average permeability = {averagemr}"
f"PML {pml.ID}: Average permittivity = {averageer}, Average permeability ="
f" {averagemr}"
)
pml.calculate_update_coeffs(averageer, averagemr)
self.pmls["slabs"].append(pml)
@@ -180,14 +185,15 @@ class FDTDGrid:
PmlType = TypeVar("PmlType", bound=PML)
def _construct_pml(self, pml_ID: str, thickness: int, pml_type: type[PmlType] = PML) -> PmlType:
"""Builds instances of the PML and calculates the initial parameters and
coefficients including setting profile (based on underlying material
er and mr from solid array).
"""Build PML instance of the specified ID, thickness and type.
Constructs a PML of the specified type and thickness. Properties
of the PML are set based on the provided identifier.
Args:
G: FDTDGrid class describing a grid in a model.
pml_ID: string identifier of PML slab.
thickness: int with thickness of PML slab in cells.
pml_ID: Identifier of PML slab.
thickness: Thickness of PML slab in cells.
pml_type: PML class to construct.
"""
if pml_ID == "x0":
pml = pml_type(
@@ -267,6 +273,15 @@ class FDTDGrid:
return pml
def _calculate_average_pml_material_properties(self, pml: PML) -> Tuple[float, float]:
"""Calculate average material properties for the provided PML.
Args:
pml: PML to calculate the properties of.
Returns:
averageer, averagemr: Average permittivity and permeability
in the PML slab.
"""
# Arrays to hold values of permittivity and permeability (avoids accessing
# Material class in Cython.)
ers = np.zeros(len(self.materials))
@@ -294,8 +309,11 @@ class FDTDGrid:
return pml_average_er_mr(n1, n2, config.get_model_config().ompthreads, solid, ers, mrs)
def _build_components(self) -> None:
# Build the model, i.e. set the material properties (ID) for every edge
# of every Yee cell
"""Build electric and magnetic components of the grid.
Set the material properties (stored in the ID array) for every
edge of every Yee cell.
"""
pbar = tqdm(
total=2,
desc=f"Building Yee cells [{self.name}]",
@@ -310,6 +328,7 @@ class FDTDGrid:
pbar.close()
def _tm_grid_update(self) -> None:
"""Add PEC boundaries to invariant if in 2D mode."""
if config.get_model_config().mode == "2D TMx":
self.tmx()
elif config.get_model_config().mode == "2D TMy":
@@ -318,14 +337,21 @@ class FDTDGrid:
self.tmz()
def _create_voltage_source_materials(self):
"""Create materials for voltage sources.
Process any voltage sources (that have resistance) to create a
new material at the source location.
"""
# Process any voltage sources (that have resistance) to create a new
# material at the source location
for voltagesource in self.voltagesources:
voltagesource.create_material(self)
def _build_materials(self) -> None:
# Process complete list of materials - calculate update coefficients,
# store in arrays, and build text list of materials/properties
"""Calculate properties of materials in the grid.
Log a summary of the material properties.
"""
materialsdata = process_materials(self)
# materialstable = SingleTable(materialsdata)
materialstable = AsciiTable(materialsdata)
@@ -338,6 +364,17 @@ class FDTDGrid:
def _update_positions(
self, items: Iterable[Union[Source, Rx]], step_size: List[int], step_number: int
) -> None:
"""Update the grid positions of the provided items.
Args:
items: Sources and receivers to update.
step_size: Number of grid cells to move the items each step.
step_number: Number of steps to move the items by.
Raises:
ValueError: Raised if any of the items would be stepped
outside of the grid.
"""
if step_size[0] != 0 or step_size[1] != 0 or step_size[2] != 0:
for item in items:
if step_number == 0:
@@ -355,6 +392,20 @@ class FDTDGrid:
item.zcoord = item.zcoordorigin + step_number * step_size[2]
def update_simple_source_positions(self, step_size: List[int], step: int = 0) -> None:
"""Update the positions of sources in the grid.
Move hertzian dipole and magnetic dipole sources. Transmission
line sources and voltage sources will not be moved.
Args:
step_size: Number of grid cells to move the sources each
step.
step: Number of steps to move the sources by.
Raises:
ValueError: Raised if any of the sources would be stepped
outside of the grid.
"""
try:
self._update_positions(
itertools.chain(self.hertziandipoles, self.magneticdipoles), step_size, step
@@ -364,13 +415,35 @@ class FDTDGrid:
raise ValueError from e
def update_receiver_positions(self, step_size: List[int], step: int = 0) -> None:
"""Update the positions of receivers in the grid.
Args:
step_size: Number of grid cells to move the receivers each
step.
step: Number of steps to move the receivers by.
Raises:
ValueError: Raised if any of the receivers would be stepped
outside of the grid.
"""
try:
self._update_positions(self.rxs, step_size, step)
except ValueError as e:
logger.exception("Receiver(s) will be stepped to a position outside the domain.")
raise ValueError from e
def within_bounds(self, p):
IntPoint = Tuple[int, int, int]
FloatPoint = Tuple[float, float, float]
def within_bounds(self, p: IntPoint):
"""Check a point is within the grid.
Args:
p: Point to check.
Raises:
ValueError: Raised if the point is outside the grid.
"""
if p[0] < 0 or p[0] > self.nx:
raise ValueError("x")
if p[1] < 0 or p[1] > self.ny:
@@ -378,39 +451,71 @@ class FDTDGrid:
if p[2] < 0 or p[2] > self.nz:
raise ValueError("z")
def discretise_point(self, p):
def discretise_point(self, p: FloatPoint) -> IntPoint:
"""Calculate the nearest grid cell to the given point.
Args:
p: Point to discretise.
Returns:
x, y, z: Discretised point.
"""
x = round_value(float(p[0]) / self.dx)
y = round_value(float(p[1]) / self.dy)
z = round_value(float(p[2]) / self.dz)
return (x, y, z)
def round_to_grid(self, p):
def round_to_grid(self, p: FloatPoint) -> FloatPoint:
"""Round the provided point to the nearest grid cell.
Args:
p: Point to round.
Returns:
p_r: Rounded point.
"""
p = self.discretise_point(p)
p_r = (p[0] * self.dx, p[1] * self.dy, p[2] * self.dz)
return p_r
def within_pml(self, p):
if (
def within_pml(self, p: IntPoint) -> bool:
"""Check if the provided point is within a PML.
Args:
p: Point to check.
Returns:
within_pml: True if the point is within a PML.
"""
return (
p[0] < self.pmls["thickness"]["x0"]
or p[0] > self.nx - self.pmls["thickness"]["xmax"]
or p[1] < self.pmls["thickness"]["y0"]
or p[1] > self.ny - self.pmls["thickness"]["ymax"]
or p[2] < self.pmls["thickness"]["z0"]
or p[2] > self.nz - self.pmls["thickness"]["zmax"]
):
return True
else:
return False
)
def get_waveform_by_id(self, waveform_id: str) -> Waveform:
"""Get waveform with the specified ID.
Args:
waveform_id: ID of the waveform.
Returns:
waveform: Requested waveform
"""
return next(waveform for waveform in self.waveforms if waveform.ID == waveform_id)
def initialise_geometry_arrays(self):
"""Initialise an array for volumetric material IDs (solid);
boolean arrays for specifying whether materials can have dielectric
"""Initialise arrays to store geometry properties.
Initialise an array for volumetric material IDs (solid); boolean
arrays for specifying whether materials can have dielectric
smoothing (rigid); and an array for cell edge IDs (ID).
Solid and ID arrays are initialised to free_space (one);
rigid arrays to allow dielectric smoothing (zero).
Solid and ID arrays are initialised to free_space (one); rigid
arrays to allow dielectric smoothing (zero).
"""
self.solid = np.ones((self.nx, self.ny, self.nz), dtype=np.uint32)
self.rigidE = np.zeros((12, self.nx, self.ny, self.nz), dtype=np.int8)
@@ -566,8 +671,10 @@ class FDTDGrid:
return mem_use
def mem_est_fractals(self):
"""Estimates the amount of memory (RAM) required to build any objects
which use the FractalVolume/FractalSurface classes.
"""Calculate the memory required to build fractal objects.
Estimates the amount of memory (RAM) required to build any
objects which use the FractalVolume/FractalSurface classes.
Returns:
mem_use: int of memory (bytes).
@@ -693,20 +800,23 @@ class FDTDGrid:
return Iz
def dispersion_analysis(self, iterations: int):
# Check to see if numerical dispersion might be a problem
"""Check to see if numerical dispersion might be a problem.
Raises:
ValueError: Raised if a problem is encountered.
"""
results = self._dispersion_analysis(iterations)
if results["error"]:
logger.warning(
f"Numerical dispersion analysis [{self.name}] "
f"not carried out as {results['error']}"
f"Numerical dispersion analysis [{self.name}] not carried out as {results['error']}"
)
elif results["N"] < config.get_model_config().numdispersion["mingridsampling"]:
logger.exception(
f"\nNon-physical wave propagation in [{self.name}] "
f"detected. Material '{results['material'].ID}' "
f"has wavelength sampled by {results['N']} cells, "
f"less than required minimum for physical wave "
f"propagation. Maximum significant frequency "
"less than required minimum for physical wave "
"propagation. Maximum significant frequency "
f"estimated as {results['maxfreq']:g}Hz"
)
raise ValueError
@@ -717,29 +827,31 @@ class FDTDGrid:
):
logger.warning(
f"[{self.name}] has potentially significant "
f"numerical dispersion. Estimated largest physical "
"numerical dispersion. Estimated largest physical "
f"phase-velocity error is {results['deltavp']:.2f}% "
f"in material '{results['material'].ID}' whose "
f"wavelength sampled by {results['N']} cells. "
f"Maximum significant frequency estimated as "
"Maximum significant frequency estimated as "
f"{results['maxfreq']:g}Hz\n"
)
elif results["deltavp"]:
logger.info(
f"Numerical dispersion analysis [{self.name}]: "
f"estimated largest physical phase-velocity error is "
"estimated largest physical phase-velocity error is "
f"{results['deltavp']:.2f}% in material '{results['material'].ID}' "
f"whose wavelength sampled by {results['N']} cells. "
f"Maximum significant frequency estimated as "
"Maximum significant frequency estimated as "
f"{results['maxfreq']:g}Hz\n"
)
def _dispersion_analysis(self, iterations: int):
"""Analysis of numerical dispersion (Taflove et al, 2005, p112) -
worse case of maximum frequency and minimum wavelength
def _dispersion_analysis(self, iterations: int) -> dict[str, Any]:
"""Run dispersion analysis.
Analysis of numerical dispersion (Taflove et al, 2005, p112) -
worse case of maximum frequency and minimum wavelength.
Args:
G: FDTDGrid class describing a grid in a model.
iterations: Number of iterations the model will run for.
Returns:
results: dict of results from dispersion analysis.
@@ -768,8 +880,9 @@ class FDTDGrid:
# Time to analyse waveform - 4*pulse_width as using entire
# time window can result in demanding FFT
waveform.calculate_coefficients()
iterations = round_value(4 * waveform.chi / self.dt)
iterations = min(iterations, iterations)
# TODO: Check max_iterations should be calculated (original code didn't go on to use it)
max_iterations = round_value(4 * waveform.chi / self.dt)
iterations = min(iterations, max_iterations)
waveformvalues = np.zeros(iterations)
for iteration in range(iterations):
waveformvalues[iteration] = waveform.calculate_value(

查看文件

@@ -24,7 +24,7 @@ from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np
import numpy.typing as npt
from mpi4py import MPI
from numpy import ndarray
from numpy import empty, ndarray
from gprMax import config
from gprMax.cython.pml_build import pml_sum_er_mr
@@ -115,11 +115,31 @@ class MPIGrid(FDTDGrid):
self.size[Dim.Z] = value
def is_coordinator(self) -> bool:
"""Test if the current rank is the coordinator.
Returns:
is_coordinator: True if `self.rank` equals
`self.COORDINATOR_RANK`.
"""
return self.rank == self.COORDINATOR_RANK
def get_grid_coord_from_coordinate(self, coord: npt.NDArray) -> npt.NDArray[np.intc]:
def get_grid_coord_from_coordinate(self, coord: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]:
"""Get the MPI grid coordinate for a global grid coordinate.
Args:
coord: Global grid coordinate.
Returns:
grid_coord: Coordinate of the MPI rank containing the global
grid coordinate.
"""
step_size = self.global_size // self.mpi_tasks
overflow = self.global_size % self.mpi_tasks
# The first n MPI ranks where n is the overflow, will have size
# step_size + 1. Additionally, step_size may be zero in some
# dimensions (e.g. in the 2D case) so we need to avoid division
# by zero.
return np.where(
(step_size + 1) * overflow >= coord,
coord // (step_size + 1),
@@ -127,12 +147,36 @@ class MPIGrid(FDTDGrid):
)
def get_rank_from_coordinate(self, coord: npt.NDArray) -> int:
"""Get the MPI rank for a global grid coordinate.
A coordinate only exists on a single rank (halos are ignored).
Args:
coord: Global grid coordinate.
Returns:
rank: MPI rank containing the global grid coordinate.
"""
grid_coord = self.get_grid_coord_from_coordinate(coord)
return self.comm.Get_cart_rank(grid_coord.tolist())
def get_ranks_between_coordinates(
self, start_coord: npt.NDArray, stop_coord: npt.NDArray
) -> List[int]:
"""Get the MPI ranks for between two global grid coordinates.
`stop_coord` must not be less than `start_coord` in any
dimension, however it can be equal. The returned ranks will
contain coordinates inclusive of both `start_coord` and
`stop_coord`.
Args:
start_coord: Starting global grid coordinate.
stop_coord: End global grid coordinate.
Returns:
ranks: List of MPI ranks
"""
start = self.get_grid_coord_from_coordinate(start_coord)
stop = self.get_grid_coord_from_coordinate(stop_coord) + 1
coord_to_rank = lambda c: self.comm.Get_cart_rank((start + c).tolist())
@@ -141,12 +185,45 @@ class MPIGrid(FDTDGrid):
def global_to_local_coordinate(
self, global_coord: npt.NDArray[np.intc]
) -> npt.NDArray[np.intc]:
"""Convert a global grid coordinate to a local grid coordinate.
The returned coordinate will be relative to the current MPI
rank's local grid. It may be negative, or greater than the size
of the local grid if the point lies outside the local grid.
Args:
global_coord: Global grid coordinate.
Returns:
local_coord: Local grid coordinate
"""
return global_coord - self.lower_extent
def local_to_global_coordinate(self, local_coord: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]:
"""Convert a local grid coordinate to a global grid coordinate.
Args:
local_coord: Local grid coordinate
Returns:
global_coord: Global grid coordinate
"""
return local_coord + self.lower_extent
def scatter_coord_objects(self, objects: List[CoordType]) -> List[CoordType]:
"""Scatter coord objects to the correct MPI rank.
Coord objects (sources and receivers) are scattered to the MPI
rank based on their location in the grid. The receiving MPI rank
converts the object locations to its own local grid.
Args:
objects: Coord objects to be scattered.
Returns:
scattered_objects: List of Coord objects belonging to the
current MPI rank.
"""
if self.is_coordinator():
objects_by_rank: List[List[CoordType]] = [[] for _ in range(self.comm.size)]
for o in objects:
@@ -162,6 +239,20 @@ class MPIGrid(FDTDGrid):
return objects
def gather_coord_objects(self, objects: List[CoordType]) -> List[CoordType]:
"""Scatter coord objects to the correct MPI rank.
The sending MPI rank converts the object locations to the global
grid. The coord objects (sources and receivers) are all sent to
the coordinatoor rank.
Args:
objects: Coord objects to be gathered.
Returns:
gathered_objects: List of gathered coord objects if the
current rank is the coordinator. Otherwise, the original
list of objects is returned.
"""
for o in objects:
o.coord = self.local_to_global_coordinate(o.coord)
gathered_objects: Optional[List[List[CoordType]]] = self.comm.gather(
@@ -174,6 +265,13 @@ class MPIGrid(FDTDGrid):
return objects
def scatter_snapshots(self):
"""Scatter snapshots to the correct MPI rank.
Each snapshot is sent by the coordinator to the MPI ranks
containing the snapshot. A new communicator is created for each
snapshot, and each rank bounds the snapshot to within its own
local grid.
"""
if self.is_coordinator():
snapshots_by_rank: List[List[Optional[Snapshot]]] = [[] for _ in range(self.comm.size)]
for s in self.snapshots:
@@ -184,6 +282,10 @@ class MPIGrid(FDTDGrid):
if rank in ranks:
snapshots_by_rank[rank].append(s)
else:
# All ranks need the same number of 'snapshots'
# (which may be None) to ensure snapshot
# communicators are setup correctly and to avoid
# deadlock.
snapshots_by_rank[rank].append(None)
else:
snapshots_by_rank = None
@@ -218,6 +320,20 @@ class MPIGrid(FDTDGrid):
self.snapshots = [s for s in snapshots if s is not None]
def scatter_3d_array(self, array: npt.NDArray) -> npt.NDArray:
"""Scatter a 3D array to each MPI rank
Use to distribute a 3D array across MPI ranks. Each rank will
receive its own segment of the array including a negative halo,
but NOT a positive halo.
Args:
array: Array to be scattered
Returns:
scattered_array: Local extent of the array for the current
MPI rank.
"""
# TODO: Use Scatter instead of Bcast
self.comm.Bcast(array, root=self.COORDINATOR_RANK)
return array[
@@ -227,6 +343,21 @@ class MPIGrid(FDTDGrid):
].copy(order="C")
def scatter_4d_array(self, array: npt.NDArray) -> npt.NDArray:
"""Scatter a 4D array to each MPI rank
Use to distribute a 4D array across MPI ranks. The first
dimension is ignored when partitioning the array. Each rank will
receive its own segment of the array including a negative halo,
but NOT a positive halo.
Args:
array: Array to be scattered
Returns:
scattered_array: Local extent of the array for the current
MPI rank.
"""
# TODO: Use Scatter instead of Bcast
self.comm.Bcast(array, root=self.COORDINATOR_RANK)
return array[
@@ -237,6 +368,21 @@ class MPIGrid(FDTDGrid):
].copy(order="C")
def scatter_4d_array_with_positive_halo(self, array: npt.NDArray) -> npt.NDArray:
"""Scatter a 4D array to each MPI rank
Use to distribute a 4D array across MPI ranks. The first
dimension is ignored when partitioning the array. Each rank will
receive its own segment of the array including both a negative
and positive halo.
Args:
array: Array to be scattered
Returns:
scattered_array: Local extent of the array for the current
MPI rank.
"""
# TODO: Use Scatter instead of Bcast
self.comm.Bcast(array, root=self.COORDINATOR_RANK)
return array[
@@ -246,7 +392,12 @@ class MPIGrid(FDTDGrid):
self.lower_extent[Dim.Z] : self.upper_extent[Dim.Z] + 1,
].copy(order="C")
def scatter_grid(self):
def distribute_grid(self):
"""Distribute grid properties and objects to all MPI ranks.
Global properties/objects are broadcast to all ranks whereas
local properties/objects are scattered to the relevant ranks.
"""
self.materials = self.comm.bcast(self.materials, root=self.COORDINATOR_RANK)
self.rxs = self.scatter_coord_objects(self.rxs)
self.voltagesources = self.scatter_coord_objects(self.voltagesources)
@@ -280,6 +431,8 @@ class MPIGrid(FDTDGrid):
self.rigidH = self.scatter_4d_array(self.rigidH)
def gather_grid_objects(self):
"""Gather sources and receivers."""
self.rxs = self.gather_coord_objects(self.rxs)
self.voltagesources = self.gather_coord_objects(self.voltagesources)
self.magneticdipoles = self.gather_coord_objects(self.magneticdipoles)
@@ -287,6 +440,7 @@ class MPIGrid(FDTDGrid):
self.transmissionlines = self.gather_coord_objects(self.transmissionlines)
def initialise_geometry_arrays(self, use_local_size=False):
# TODO: Remove this when scatter geometry arrays rather than broadcast
if use_local_size:
super().initialise_geometry_arrays()
else:
@@ -296,6 +450,16 @@ class MPIGrid(FDTDGrid):
self.ID = np.ones((6, *(self.global_size + 1)), dtype=np.uint32)
def _halo_swap(self, array: ndarray, dim: Dim, dir: Dir):
"""Perform a halo swap in the specifed dimension and direction.
If no neighbour exists for the current rank in the specifed
dimension and direction, the halo swap is skipped.
Args:
array: Array to perform the halo swap with.
dim: Dimension of halo to swap.
dir: Direction of halo to swap.
"""
neighbour = self.neighbours[dim][dir]
if neighbour != -1:
self.comm.Sendrecv(
@@ -309,6 +473,16 @@ class MPIGrid(FDTDGrid):
)
def _halo_swap_by_dimension(self, array: ndarray, dim: Dim):
"""Perform halo swaps in the specifed dimension.
Perform a halo swaps in the positive and negative direction for
the specified dimension. The order of the swaps is determined by
the current rank's MPI grid coordinate to prevent deadlock.
Args:
array: Array to perform the halo swaps with.
dim: Dimension of halos to swap.
"""
if self.coords[dim] % 2 == 0:
self._halo_swap(array, dim, Dir.NEG)
self._halo_swap(array, dim, Dir.POS)
@@ -317,21 +491,36 @@ class MPIGrid(FDTDGrid):
self._halo_swap(array, dim, Dir.NEG)
def _halo_swap_array(self, array: ndarray):
"""Perform halo swaps for the specified array.
Args:
array: Array to perform the halo swaps with.
"""
self._halo_swap_by_dimension(array, Dim.X)
self._halo_swap_by_dimension(array, Dim.Y)
self._halo_swap_by_dimension(array, Dim.Z)
def halo_swap_electric(self):
"""Perform halo swaps for electric field arrays."""
self._halo_swap_array(self.Ex)
self._halo_swap_array(self.Ey)
self._halo_swap_array(self.Ez)
def halo_swap_magnetic(self):
"""Perform halo swaps for magnetic field arrays."""
self._halo_swap_array(self.Hx)
self._halo_swap_array(self.Hy)
self._halo_swap_array(self.Hz)
def _construct_pml(self, pml_ID: str, thickness: int) -> MPIPML:
"""Build instance of MPIPML and set the MPI communicator.
Args:
pml_ID: Identifier of PML slab.
thickness: Thickness of PML slab in cells.
"""
pml = super()._construct_pml(pml_ID, thickness, MPIPML)
if pml.ID[0] == "x":
pml.comm = self.x_comm
@@ -344,6 +533,15 @@ class MPIGrid(FDTDGrid):
return pml
def _calculate_average_pml_material_properties(self, pml: MPIPML) -> Tuple[float, float]:
"""Calculate average material properties for the provided PML.
Args:
pml: PML to calculate the properties of.
Returns:
averageer, averagemr: Average permittivity and permeability
in the PML slab.
"""
# Arrays to hold values of permittivity and permeability (avoids
# accessing Material class in Cython.)
ers = np.zeros(len(self.materials))
@@ -387,15 +585,19 @@ class MPIGrid(FDTDGrid):
return averageer, averagemr
def build(self):
"""Set local properties and objects, then build the grid."""
if any(self.global_size + 1 < self.mpi_tasks):
logger.error(
f"Too many MPI tasks requested ({self.mpi_tasks}) for grid of size {self.global_size + 1}. Make sure the number of MPI tasks in each dimension is less than the size of the grid."
f"Too many MPI tasks requested ({self.mpi_tasks}) for grid of size"
f" {self.global_size + 1}. Make sure the number of MPI tasks in each dimension is"
" less than the size of the grid."
)
raise ValueError
self.calculate_local_extents()
self.set_halo_map()
self.scatter_grid()
self.distribute_grid()
# TODO: Check PML is not thicker than the grid size
@@ -414,9 +616,20 @@ class MPIGrid(FDTDGrid):
super().build()
def has_neighbour(self, dim: Dim, dir: Dir) -> bool:
"""Test if the current rank has a specified neighbour.
Args:
dim: Dimension of neighbour.
dir: Direction of neighbour.
Returns:
has_neighbour: True if the current rank has a neighbour in
the specified dimension and direction.
"""
return self.neighbours[dim][dir] != -1
def set_halo_map(self):
"""Create MPI DataTypes for field array halo exchanges."""
size = (self.size + 1).tolist()
for dim in Dim:
@@ -443,6 +656,8 @@ class MPIGrid(FDTDGrid):
self.recv_halo_map[dim][Dir.POS].Commit()
def calculate_local_extents(self):
"""Calculate size and extents of the local grid"""
self.size = self.global_size // self.mpi_tasks
overflow = self.global_size % self.mpi_tasks
@@ -465,5 +680,6 @@ class MPIGrid(FDTDGrid):
self.upper_extent = self.lower_extent + self.size
logger.debug(
f"Grid size: {self.size}, Lower extent: {self.lower_extent}, Upper extent: {self.upper_extent}"
f"Local grid size: {self.size}, Lower extent: {self.lower_extent}, Upper extent:"
f" {self.upper_extent}"
)