Combine point discretisation and checking in uip objects

这个提交包含在:
nmannall
2025-02-07 15:35:08 +00:00
父节点 14b1f7e3d7
当前提交 340f56c155
共有 2 个文件被更改,包括 70 次插入37 次删除

查看文件

@@ -147,40 +147,61 @@ class MainGridUserInput(UserInput[GridType]):
def __init__(self, grid):
super().__init__(grid)
def check_point(self, p, cmd_str, name=""):
def check_point(
self, point: Tuple[float, float, float], cmd_str: str, name: str = ""
) -> Tuple[bool, npt.NDArray[np.int32]]:
"""Discretises point and check its within the domain"""
p = self.discretise_point(p)
self.point_within_bounds(p, cmd_str, name)
return p
discretised_point = self.discretise_point(point)
within_bounds = self.point_within_bounds(discretised_point, cmd_str, name)
return within_bounds, discretised_point
def check_src_rx_point(self, p: npt.NDArray[np.int32], cmd_str: str, name: str = "") -> bool:
within_grid = self.point_within_bounds(p, cmd_str, name)
def check_src_rx_point(
self, point: Tuple[float, float, float], cmd_str: str, name: str = ""
) -> Tuple[bool, npt.NDArray[np.int32]]:
within_bounds, discretised_point = self.check_point(point, cmd_str, name)
if self.grid.within_pml(p):
if self.grid.within_pml(discretised_point):
logger.warning(
f"'{cmd_str}' sources and receivers should not normally be positioned within the PML."
)
return within_grid
return within_bounds, discretised_point
def check_box_points(self, p1, p2, cmd_str):
p1 = self.check_point(p1, cmd_str, name="lower")
p2 = self.check_point(p2, cmd_str, name="upper")
def _check_2d_points(
self, p1: Tuple[float, float, float], p2: Tuple[float, float, float], cmd_str: str
) -> Tuple[bool, npt.NDArray[np.int32], npt.NDArray[np.int32]]:
lower_within_grid, lower_point = self.check_point(p1, cmd_str, "lower")
upper_within_grid, upper_point = self.check_point(p2, cmd_str, "upper")
if np.greater(p1, p2).any():
logger.exception(
if np.greater(lower_point, upper_point).any():
raise ValueError(
f"'{cmd_str}' the lower coordinates should be less than the upper coordinates."
)
raise ValueError
return p1, p2
return lower_within_grid and upper_within_grid, lower_point, upper_point
def check_tri_points(self, p1, p2, p3, cmd_str):
p1 = self.check_point(p1, cmd_str, name="vertex_1")
p2 = self.check_point(p2, cmd_str, name="vertex_2")
p3 = self.check_point(p3, cmd_str, name="vertex_3")
def check_box_points(
self, p1: Tuple[float, float, float], p2: Tuple[float, float, float], cmd_str: str
) -> Tuple[bool, npt.NDArray[np.int32], npt.NDArray[np.int32]]:
return self._check_2d_points(p1, p2, cmd_str)
return p1, p2, p3
def check_tri_points(
self,
p1: Tuple[float, float, float],
p2: Tuple[float, float, float],
p3: Tuple[float, float, float],
cmd_str: str,
) -> Tuple[bool, npt.NDArray[np.int32], npt.NDArray[np.int32], npt.NDArray[np.int32]]:
p1_within_grid, p1_checked = self.check_point(p1, cmd_str, name="vertex_1")
p2_within_grid, p2_checked = self.check_point(p2, cmd_str, name="vertex_2")
p3_within_grid, p3_checked = self.check_point(p3, cmd_str, name="vertex_3")
return (
p1_within_grid and p2_within_grid and p3_within_grid,
p1_checked,
p2_checked,
p3_checked,
)
class MPIUserInput(MainGridUserInput[MPIGrid]):
@@ -208,6 +229,17 @@ class MPIUserInput(MainGridUserInput[MPIGrid]):
discretised_point = super().discretise_point(point)
return self.grid.global_to_local_coordinate(discretised_point)
def check_box_points(
self, p1: Tuple[float, float, float], p2: Tuple[float, float, float], cmd_str: str
) -> Tuple[bool, npt.NDArray[np.int32], npt.NDArray[np.int32]]:
_, lower_point, upper_point = super().check_box_points(p1, p2, cmd_str)
# Restrict points to the bounds of the local grid
lower_point = np.where(lower_point < 0, 0, lower_point)
upper_point = np.where(upper_point > self.grid.size, self.grid.size, upper_point)
return all(lower_point < upper_point), lower_point, upper_point
class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
"""Handles (x, y, z) points supplied by the user in the subgrid.

查看文件

@@ -433,9 +433,9 @@ class VoltageSource(RotatableMixin, GridUserObject):
# Check the position of the voltage source
uip = self._create_uip(grid)
discretised_point = uip.discretise_point(self.point)
point_within_grid, discretised_point = uip.check_src_rx_point(self.point, self.params_str())
if uip.check_src_rx_point(discretised_point, self.params_str()):
if point_within_grid:
self._validate_parameters(grid)
voltage_source = self._create_voltage_source(grid, discretised_point)
grid.add_source(voltage_source)
@@ -583,9 +583,9 @@ class HertzianDipole(RotatableMixin, GridUserObject):
# Check the position of the hertzian dipole
uip = self._create_uip(grid)
discretised_point = uip.discretise_point(self.point)
point_within_grid, discretised_point = uip.check_src_rx_point(self.point, self.params_str())
if uip.check_src_rx_point(discretised_point, self.params_str()):
if point_within_grid:
self._validate_parameters(grid)
hertzian_dipole = self._create_hertzian_dipole(grid, discretised_point)
grid.add_source(hertzian_dipole)
@@ -638,9 +638,9 @@ class MagneticDipole(RotatableMixin, GridUserObject):
# Check the position of the magnetic dipole
uip = self._create_uip(grid)
discretised_point = uip.discretise_point(self.point)
point_within_grid, discretised_point = uip.check_src_rx_point(self.point, self.params_str())
if uip.check_src_rx_point(discretised_point, self.params_str()):
if point_within_grid:
self._validate_parameters(grid)
magnetic_dipole = self._create_magnetic_dipole(grid, discretised_point)
grid.add_source(magnetic_dipole)
@@ -793,9 +793,9 @@ class TransmissionLine(RotatableMixin, GridUserObject):
# Check the position of the voltage source
uip = self._create_uip(grid)
discretised_point = uip.discretise_point(self.point)
point_within_grid, discretised_point = uip.check_src_rx_point(self.point, self.params_str())
if uip.check_src_rx_point(discretised_point, self.params_str()):
if point_within_grid:
self._validate_parameters(grid)
transmission_line = self._create_transmission_line(grid, discretised_point)
grid.add_source(transmission_line)
@@ -988,9 +988,9 @@ class Rx(RotatableMixin, GridUserObject):
# Check position of the receiver
uip = self._create_uip(grid)
discretised_point = uip.discretise_point(self.point)
point_within_grid, discretised_point = uip.check_src_rx_point(self.point, self.params_str())
if uip.check_src_rx_point(discretised_point, self.params_str()):
if point_within_grid:
receiver = self._create_receiver(grid, discretised_point)
grid.add_receiver(receiver)
@@ -1033,13 +1033,14 @@ class RxArray(GridUserObject):
def build(self, grid: FDTDGrid):
uip = self._create_uip(grid)
discretised_lower_point = uip.discretise_point(self.lower_point)
discretised_upper_point = uip.discretise_point(self.upper_point)
_, discretised_lower_point = uip.check_src_rx_point(
self.lower_point, self.params_str(), "lower"
)
_, discretised_upper_point = uip.check_src_rx_point(
self.lower_point, self.params_str(), "upper"
)
discretised_dl = uip.discretise_static_point(self.dl)
uip.check_src_rx_point(discretised_lower_point, self.params_str(), "lower")
uip.check_src_rx_point(discretised_upper_point, self.params_str(), "upper")
if any(discretised_lower_point > discretised_upper_point):
raise ValueError(
f"{self.params_str()} the lower coordinates should be less than the upper coordinates."
@@ -1068,7 +1069,7 @@ class RxArray(GridUserObject):
for x in range(xs, xf + grid.dx, dx):
for y in range(ys, yf + grid.dy, dy):
for z in range(zs, zf + grid.dz, dz):
receiver = Rx((x, y, x))
receiver = Rx((x, y, z))
receiver.build(grid)
@@ -1128,7 +1129,7 @@ class Snapshot(GridUserObject):
dl = uip.discretise_static_point(dl)
try:
p1, p2 = uip.check_box_points(p1, p2, self.params_str())
_, p1, p2 = uip.check_box_points(p1, p2, self.params_str())
except ValueError:
logger.exception(f"{self.params_str()} point is outside the domain.")
raise