你已经派生过 gprMax
镜像自地址
https://gitee.com/sunhf/gprMax.git
已同步 2025-08-07 15:10:13 +08:00
Add MPIUserObject
这个提交包含在:
@@ -155,16 +155,6 @@ class MPIGrid(FDTDGrid):
|
||||
if self.has_neighbour(Dim.Z, Dir.POS):
|
||||
self.pmls["thickness"]["zmax"] = 0
|
||||
|
||||
def add_source(self, source: Source):
|
||||
source.coord = self.global_to_local_coordinate(source.coord)
|
||||
source.coordorigin = self.global_to_local_coordinate(source.coordorigin)
|
||||
return super().add_source(source)
|
||||
|
||||
def add_receiver(self, receiver: Rx):
|
||||
receiver.coord = self.global_to_local_coordinate(receiver.coord)
|
||||
receiver.coordorigin = self.global_to_local_coordinate(receiver.coordorigin)
|
||||
return super().add_receiver(receiver)
|
||||
|
||||
def is_coordinator(self) -> bool:
|
||||
"""Test if the current rank is the coordinator.
|
||||
|
||||
@@ -524,14 +514,12 @@ class MPIGrid(FDTDGrid):
|
||||
"""
|
||||
self.scatter_snapshots()
|
||||
|
||||
if not self.is_coordinator():
|
||||
# TODO: When scatter arrays properly, should initialise these to the local grid size
|
||||
self.initialise_geometry_arrays()
|
||||
|
||||
self.ID = self.scatter_4d_array_with_positive_halo(self.ID)
|
||||
self.solid = self.scatter_3d_array(self.solid)
|
||||
self.rigidE = self.scatter_4d_array(self.rigidE)
|
||||
self.rigidH = self.scatter_4d_array(self.rigidH)
|
||||
# self._halo_swap_array(self.ID[0])
|
||||
# self._halo_swap_array(self.ID[1])
|
||||
# self._halo_swap_array(self.ID[2])
|
||||
# self._halo_swap_array(self.ID[3])
|
||||
# self._halo_swap_array(self.ID[4])
|
||||
# self._halo_swap_array(self.ID[5])
|
||||
|
||||
def gather_grid_objects(self):
|
||||
"""Gather sources and receivers."""
|
||||
@@ -542,16 +530,6 @@ class MPIGrid(FDTDGrid):
|
||||
self.hertziandipoles = self.gather_coord_objects(self.hertziandipoles)
|
||||
self.transmissionlines = self.gather_coord_objects(self.transmissionlines)
|
||||
|
||||
def initialise_geometry_arrays(self, use_local_size=False):
|
||||
# TODO: Remove this when scatter geometry arrays rather than broadcast
|
||||
if use_local_size:
|
||||
super().initialise_geometry_arrays()
|
||||
else:
|
||||
self.solid = np.ones(self.global_size, dtype=np.uint32)
|
||||
self.rigidE = np.zeros((12, *self.global_size), dtype=np.int8)
|
||||
self.rigidH = np.zeros((6, *self.global_size), dtype=np.int8)
|
||||
self.ID = np.ones((6, *(self.global_size + 1)), dtype=np.uint32)
|
||||
|
||||
def _halo_swap(self, array: ndarray, dim: Dim, dir: Dir):
|
||||
"""Perform a halo swap in the specifed dimension and direction.
|
||||
|
||||
@@ -786,11 +764,11 @@ class MPIGrid(FDTDGrid):
|
||||
f" {self.lower_extent}, Upper extent: {self.upper_extent}"
|
||||
)
|
||||
|
||||
def within_bounds(self, p: npt.NDArray[np.int32]) -> bool:
|
||||
"""Check a point is within the grid.
|
||||
def within_bounds(self, local_point: npt.NDArray[np.int32]) -> bool:
|
||||
"""Check a local point is within the grid.
|
||||
|
||||
Args:
|
||||
p: Point to check.
|
||||
local_point: Point to check.
|
||||
|
||||
Returns:
|
||||
within_bounds: True if the point is within the local grid
|
||||
@@ -799,14 +777,18 @@ class MPIGrid(FDTDGrid):
|
||||
Raises:
|
||||
ValueError: Raised if the point is outside the global grid.
|
||||
"""
|
||||
if p[0] < 0 or p[0] > self.gx:
|
||||
raise ValueError("x")
|
||||
if p[1] < 0 or p[1] > self.gy:
|
||||
raise ValueError("y")
|
||||
if p[2] < 0 or p[2] > self.gz:
|
||||
raise ValueError("z")
|
||||
|
||||
local_point = self.global_to_local_coordinate(p)
|
||||
gx, gy, gz = self.local_to_global_coordinate(local_point)
|
||||
|
||||
print(local_point)
|
||||
print(gx, gy, gz)
|
||||
|
||||
if gx < 0 or gx > self.gx:
|
||||
raise ValueError("x")
|
||||
if gy < 0 or gy > self.gy:
|
||||
raise ValueError("y")
|
||||
if gz < 0 or gz > self.gz:
|
||||
raise ValueError("z")
|
||||
|
||||
return all(local_point >= self.negative_halo_offset) and all(local_point <= self.size)
|
||||
|
||||
|
@@ -25,6 +25,7 @@ import numpy.typing as npt
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from gprMax.grid.fdtd_grid import FDTDGrid
|
||||
from gprMax.grid.mpi_grid import MPIGrid
|
||||
from gprMax.subgrids.grid import SubGridBaseGrid
|
||||
|
||||
from .utilities.utilities import round_int
|
||||
@@ -40,7 +41,7 @@ logger = logging.getLogger(__name__)
|
||||
encapulsated here.
|
||||
"""
|
||||
|
||||
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
|
||||
GridType = TypeVar("GridType", bound=FDTDGrid)
|
||||
|
||||
|
||||
class UserInput(Generic[GridType]):
|
||||
@@ -182,6 +183,32 @@ class MainGridUserInput(UserInput[GridType]):
|
||||
return p1, p2, p3
|
||||
|
||||
|
||||
class MPIUserInput(MainGridUserInput[MPIGrid]):
|
||||
"""Handles (x, y, z) points supplied by the user for MPI grids.
|
||||
|
||||
This class autotranslates points from the global coordinate system
|
||||
to the grid's local coordinate system.
|
||||
"""
|
||||
|
||||
def discretise_point(self, point: Tuple[float, float, float]) -> npt.NDArray[np.int32]:
|
||||
"""Get the nearest grid index to a continuous static point.
|
||||
|
||||
This function translates user points to the correct index for
|
||||
building objects. Points will be mapped from the global
|
||||
coordinate space to the local coordinate space of the grid.
|
||||
There are no checks of the validity of the point such as bound
|
||||
checking.
|
||||
|
||||
Args:
|
||||
point: x, y, z coordinates of the point in space.
|
||||
|
||||
Returns:
|
||||
discretised_point: x, y, z indices of the point on the grid.
|
||||
"""
|
||||
discretised_point = super().discretise_point(point)
|
||||
return self.grid.global_to_local_coordinate(discretised_point)
|
||||
|
||||
|
||||
class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
|
||||
"""Handles (x, y, z) points supplied by the user in the subgrid.
|
||||
This class autotranslates points from main grid to subgrid equivalent
|
||||
|
@@ -951,7 +951,9 @@ class Rx(RotatableMixin, GridUserObject):
|
||||
r.coordorigin = coord
|
||||
|
||||
if self.id is None:
|
||||
r.ID = f"{r.__class__.__name__}({str(r.xcoord)},{str(r.ycoord)},{str(r.zcoord)})"
|
||||
uip = self._create_uip(grid)
|
||||
x, y, z = uip.discretise_static_point(self.point)
|
||||
r.ID = f"Rx({x},{y},{z})"
|
||||
else:
|
||||
r.ID = self.id
|
||||
|
||||
@@ -1033,7 +1035,7 @@ class RxArray(GridUserObject):
|
||||
uip = self._create_uip(grid)
|
||||
discretised_lower_point = uip.discretise_point(self.lower_point)
|
||||
discretised_upper_point = uip.discretise_point(self.upper_point)
|
||||
discretised_dl = uip.discretise_point(self.dl)
|
||||
discretised_dl = uip.discretise_static_point(self.dl)
|
||||
|
||||
uip.check_src_rx_point(discretised_lower_point, self.params_str(), "lower")
|
||||
uip.check_src_rx_point(discretised_upper_point, self.params_str(), "upper")
|
||||
|
@@ -123,7 +123,7 @@ class Domain(ModelUserObject):
|
||||
def build(self, model: Model):
|
||||
uip = self._create_uip(model.G)
|
||||
|
||||
discretised_domain_size = uip.discretise_point(self.domain_size)
|
||||
discretised_domain_size = uip.discretise_static_point(self.domain_size)
|
||||
|
||||
model.set_size(discretised_domain_size)
|
||||
|
||||
@@ -519,7 +519,7 @@ class SrcSteps(ModelUserObject):
|
||||
|
||||
def build(self, model: Model):
|
||||
uip = self._create_uip(model.G)
|
||||
model.srcsteps = uip.discretise_point(self.step_size)
|
||||
model.srcsteps = uip.discretise_static_point(self.step_size)
|
||||
|
||||
logger.info(
|
||||
f"Simple sources will step {model.srcsteps[0] * model.dx:g}m, "
|
||||
@@ -556,7 +556,7 @@ class RxSteps(ModelUserObject):
|
||||
|
||||
def build(self, model: Model):
|
||||
uip = self._create_uip(model.G)
|
||||
model.rxsteps = uip.discretise_point(self.step_size)
|
||||
model.rxsteps = uip.discretise_static_point(self.step_size)
|
||||
|
||||
logger.info(
|
||||
f"All receivers will step {model.rxsteps[0] * model.dx:g}m, "
|
||||
|
@@ -3,9 +3,10 @@ from typing import List, Union
|
||||
|
||||
from gprMax import config
|
||||
from gprMax.grid.fdtd_grid import FDTDGrid
|
||||
from gprMax.grid.mpi_grid import MPIGrid
|
||||
from gprMax.model import Model
|
||||
from gprMax.subgrids.grid import SubGridBaseGrid
|
||||
from gprMax.user_inputs import MainGridUserInput, SubgridUserInput
|
||||
from gprMax.user_inputs import MainGridUserInput, MPIUserInput, SubgridUserInput
|
||||
|
||||
|
||||
class UserObject(ABC):
|
||||
@@ -58,7 +59,7 @@ class UserObject(ABC):
|
||||
"""Readable string of parameters given to object."""
|
||||
return f"{self.hash}: {str(self.kwargs)}"
|
||||
|
||||
def _create_uip(self, grid: FDTDGrid) -> Union[SubgridUserInput, MainGridUserInput]:
|
||||
def _create_uip(self, grid: FDTDGrid) -> MainGridUserInput:
|
||||
"""Returns a point checker class based on the grid supplied.
|
||||
|
||||
Args:
|
||||
@@ -77,6 +78,8 @@ class UserObject(ABC):
|
||||
and self.autotranslate
|
||||
):
|
||||
return SubgridUserInput(grid)
|
||||
elif isinstance(grid, MPIGrid):
|
||||
return MPIUserInput(grid)
|
||||
else:
|
||||
return MainGridUserInput(grid)
|
||||
|
||||
|
在新工单中引用
屏蔽一个用户