Add MPIUserObject

这个提交包含在:
nmannall
2025-01-31 17:22:04 +00:00
父节点 19eb17bf59
当前提交 828ed75429
共有 5 个文件被更改,包括 60 次插入46 次删除

查看文件

@@ -155,16 +155,6 @@ class MPIGrid(FDTDGrid):
if self.has_neighbour(Dim.Z, Dir.POS):
self.pmls["thickness"]["zmax"] = 0
def add_source(self, source: Source):
source.coord = self.global_to_local_coordinate(source.coord)
source.coordorigin = self.global_to_local_coordinate(source.coordorigin)
return super().add_source(source)
def add_receiver(self, receiver: Rx):
receiver.coord = self.global_to_local_coordinate(receiver.coord)
receiver.coordorigin = self.global_to_local_coordinate(receiver.coordorigin)
return super().add_receiver(receiver)
def is_coordinator(self) -> bool:
"""Test if the current rank is the coordinator.
@@ -524,14 +514,12 @@ class MPIGrid(FDTDGrid):
"""
self.scatter_snapshots()
if not self.is_coordinator():
# TODO: When scatter arrays properly, should initialise these to the local grid size
self.initialise_geometry_arrays()
self.ID = self.scatter_4d_array_with_positive_halo(self.ID)
self.solid = self.scatter_3d_array(self.solid)
self.rigidE = self.scatter_4d_array(self.rigidE)
self.rigidH = self.scatter_4d_array(self.rigidH)
# self._halo_swap_array(self.ID[0])
# self._halo_swap_array(self.ID[1])
# self._halo_swap_array(self.ID[2])
# self._halo_swap_array(self.ID[3])
# self._halo_swap_array(self.ID[4])
# self._halo_swap_array(self.ID[5])
def gather_grid_objects(self):
"""Gather sources and receivers."""
@@ -542,16 +530,6 @@ class MPIGrid(FDTDGrid):
self.hertziandipoles = self.gather_coord_objects(self.hertziandipoles)
self.transmissionlines = self.gather_coord_objects(self.transmissionlines)
def initialise_geometry_arrays(self, use_local_size=False):
# TODO: Remove this when scatter geometry arrays rather than broadcast
if use_local_size:
super().initialise_geometry_arrays()
else:
self.solid = np.ones(self.global_size, dtype=np.uint32)
self.rigidE = np.zeros((12, *self.global_size), dtype=np.int8)
self.rigidH = np.zeros((6, *self.global_size), dtype=np.int8)
self.ID = np.ones((6, *(self.global_size + 1)), dtype=np.uint32)
def _halo_swap(self, array: ndarray, dim: Dim, dir: Dir):
"""Perform a halo swap in the specifed dimension and direction.
@@ -786,11 +764,11 @@ class MPIGrid(FDTDGrid):
f" {self.lower_extent}, Upper extent: {self.upper_extent}"
)
def within_bounds(self, p: npt.NDArray[np.int32]) -> bool:
"""Check a point is within the grid.
def within_bounds(self, local_point: npt.NDArray[np.int32]) -> bool:
"""Check a local point is within the grid.
Args:
p: Point to check.
local_point: Point to check.
Returns:
within_bounds: True if the point is within the local grid
@@ -799,14 +777,18 @@ class MPIGrid(FDTDGrid):
Raises:
ValueError: Raised if the point is outside the global grid.
"""
if p[0] < 0 or p[0] > self.gx:
raise ValueError("x")
if p[1] < 0 or p[1] > self.gy:
raise ValueError("y")
if p[2] < 0 or p[2] > self.gz:
raise ValueError("z")
local_point = self.global_to_local_coordinate(p)
gx, gy, gz = self.local_to_global_coordinate(local_point)
print(local_point)
print(gx, gy, gz)
if gx < 0 or gx > self.gx:
raise ValueError("x")
if gy < 0 or gy > self.gy:
raise ValueError("y")
if gz < 0 or gz > self.gz:
raise ValueError("z")
return all(local_point >= self.negative_halo_offset) and all(local_point <= self.size)

查看文件

@@ -25,6 +25,7 @@ import numpy.typing as npt
from typing_extensions import TypeVar
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.grid.mpi_grid import MPIGrid
from gprMax.subgrids.grid import SubGridBaseGrid
from .utilities.utilities import round_int
@@ -40,7 +41,7 @@ logger = logging.getLogger(__name__)
encapulsated here.
"""
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
GridType = TypeVar("GridType", bound=FDTDGrid)
class UserInput(Generic[GridType]):
@@ -182,6 +183,32 @@ class MainGridUserInput(UserInput[GridType]):
return p1, p2, p3
class MPIUserInput(MainGridUserInput[MPIGrid]):
"""Handles (x, y, z) points supplied by the user for MPI grids.
This class autotranslates points from the global coordinate system
to the grid's local coordinate system.
"""
def discretise_point(self, point: Tuple[float, float, float]) -> npt.NDArray[np.int32]:
"""Get the nearest grid index to a continuous static point.
This function translates user points to the correct index for
building objects. Points will be mapped from the global
coordinate space to the local coordinate space of the grid.
There are no checks of the validity of the point such as bound
checking.
Args:
point: x, y, z coordinates of the point in space.
Returns:
discretised_point: x, y, z indices of the point on the grid.
"""
discretised_point = super().discretise_point(point)
return self.grid.global_to_local_coordinate(discretised_point)
class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
"""Handles (x, y, z) points supplied by the user in the subgrid.
This class autotranslates points from main grid to subgrid equivalent

查看文件

@@ -951,7 +951,9 @@ class Rx(RotatableMixin, GridUserObject):
r.coordorigin = coord
if self.id is None:
r.ID = f"{r.__class__.__name__}({str(r.xcoord)},{str(r.ycoord)},{str(r.zcoord)})"
uip = self._create_uip(grid)
x, y, z = uip.discretise_static_point(self.point)
r.ID = f"Rx({x},{y},{z})"
else:
r.ID = self.id
@@ -1033,7 +1035,7 @@ class RxArray(GridUserObject):
uip = self._create_uip(grid)
discretised_lower_point = uip.discretise_point(self.lower_point)
discretised_upper_point = uip.discretise_point(self.upper_point)
discretised_dl = uip.discretise_point(self.dl)
discretised_dl = uip.discretise_static_point(self.dl)
uip.check_src_rx_point(discretised_lower_point, self.params_str(), "lower")
uip.check_src_rx_point(discretised_upper_point, self.params_str(), "upper")

查看文件

@@ -123,7 +123,7 @@ class Domain(ModelUserObject):
def build(self, model: Model):
uip = self._create_uip(model.G)
discretised_domain_size = uip.discretise_point(self.domain_size)
discretised_domain_size = uip.discretise_static_point(self.domain_size)
model.set_size(discretised_domain_size)
@@ -519,7 +519,7 @@ class SrcSteps(ModelUserObject):
def build(self, model: Model):
uip = self._create_uip(model.G)
model.srcsteps = uip.discretise_point(self.step_size)
model.srcsteps = uip.discretise_static_point(self.step_size)
logger.info(
f"Simple sources will step {model.srcsteps[0] * model.dx:g}m, "
@@ -556,7 +556,7 @@ class RxSteps(ModelUserObject):
def build(self, model: Model):
uip = self._create_uip(model.G)
model.rxsteps = uip.discretise_point(self.step_size)
model.rxsteps = uip.discretise_static_point(self.step_size)
logger.info(
f"All receivers will step {model.rxsteps[0] * model.dx:g}m, "

查看文件

@@ -3,9 +3,10 @@ from typing import List, Union
from gprMax import config
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.grid.mpi_grid import MPIGrid
from gprMax.model import Model
from gprMax.subgrids.grid import SubGridBaseGrid
from gprMax.user_inputs import MainGridUserInput, SubgridUserInput
from gprMax.user_inputs import MainGridUserInput, MPIUserInput, SubgridUserInput
class UserObject(ABC):
@@ -58,7 +59,7 @@ class UserObject(ABC):
"""Readable string of parameters given to object."""
return f"{self.hash}: {str(self.kwargs)}"
def _create_uip(self, grid: FDTDGrid) -> Union[SubgridUserInput, MainGridUserInput]:
def _create_uip(self, grid: FDTDGrid) -> MainGridUserInput:
"""Returns a point checker class based on the grid supplied.
Args:
@@ -77,6 +78,8 @@ class UserObject(ABC):
and self.autotranslate
):
return SubgridUserInput(grid)
elif isinstance(grid, MPIGrid):
return MPIUserInput(grid)
else:
return MainGridUserInput(grid)