Update VoltageSource UserObject for parallel build

这个提交包含在:
nmannall
2025-01-20 15:14:33 +00:00
父节点 f597246556
当前提交 619707a9d1
共有 4 个文件被更改,包括 80 次插入90 次删除

查看文件

@@ -448,7 +448,7 @@ class FDTDGrid:
logger.exception("Receiver(s) will be stepped to a position outside the domain.")
raise ValueError from e
def within_bounds(self, p: Tuple[int, int, int]) -> bool:
def within_bounds(self, p: npt.NDArray[np.int32]) -> bool:
"""Check a point is within the grid.
Args:
@@ -496,7 +496,7 @@ class FDTDGrid:
p_r = (p[0] * self.dx, p[1] * self.dy, p[2] * self.dz)
return p_r
def within_pml(self, p: Tuple[int, int, int]) -> bool:
def within_pml(self, p: npt.NDArray[np.int32]) -> bool:
"""Check if the provided point is within a PML.
Args:

查看文件

@@ -783,7 +783,7 @@ class MPIGrid(FDTDGrid):
f" {self.lower_extent}, Upper extent: {self.upper_extent}"
)
def within_bounds(self, p: Tuple[int, int, int]) -> bool:
def within_bounds(self, p: npt.NDArray[np.int32]) -> bool:
"""Check a point is within the grid.
Args:
@@ -803,11 +803,11 @@ class MPIGrid(FDTDGrid):
if p[2] < 0 or p[2] > self.gz:
raise ValueError("z")
local_point = self.global_to_local_coordinate(np.array(p, dtype=np.int32))
local_point = self.global_to_local_coordinate(p)
return all(local_point >= self.negative_halo_offset) and all(local_point <= self.size)
def within_pml(self, p: Tuple[int, int, int]) -> bool:
def within_pml(self, p: npt.NDArray[np.int32]) -> bool:
"""Check if the provided point is within a PML.
Args:
@@ -816,13 +816,12 @@ class MPIGrid(FDTDGrid):
Returns:
within_pml: True if the point is within a PML.
"""
local_point = self.global_to_local_coordinate(np.array(p))
p = (local_point[0], local_point[1], local_point[2])
local_point = self.global_to_local_coordinate(p)
# within_pml check will only be valid if the point is also
# within the local grid
return (
super().within_pml(p)
super().within_pml(local_point)
and all(local_point >= self.negative_halo_offset)
and all(local_point <= self.size)
)

查看文件

@@ -49,7 +49,9 @@ class UserInput(Generic[GridType]):
def __init__(self, grid: GridType):
self.grid = grid
def point_within_bounds(self, p, cmd_str, name, ignore_error=False) -> bool:
def point_within_bounds(
self, p: npt.NDArray[np.int32], cmd_str: str, name: str = "", ignore_error=False
) -> bool:
try:
return self.grid.within_bounds(p)
except ValueError as err:
@@ -99,15 +101,15 @@ class MainGridUserInput(UserInput[GridType]):
self.point_within_bounds(p, cmd_str, name)
return p
def check_src_rx_point(self, p, cmd_str, name=""):
p = self.check_point(p, cmd_str, name)
def check_src_rx_point(self, p: npt.NDArray[np.int32], cmd_str: str, name: str = "") -> bool:
within_grid = self.point_within_bounds(p, cmd_str, name)
if self.grid.within_pml(p):
logger.warning(
f"'{cmd_str}' sources and receivers should not normally be positioned within the PML."
)
return p
return within_grid
def check_box_points(self, p1, p2, cmd_str):
p1 = self.check_point(p1, cmd_str, name="lower")

查看文件

@@ -20,7 +20,7 @@ import inspect
import logging
from os import PathLike
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
@@ -311,75 +311,67 @@ class VoltageSource(RotatableMixin, GridUserObject):
def hash(self):
return "#voltage_source"
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
p1: Tuple[float, float, float],
polarisation: str,
resistance: float,
waveform_id: str,
start: Optional[float] = None,
stop: Optional[float] = None,
):
super().__init__(
p1=p1, polarisation=polarisation, resistance=resistance, waveform_id=waveform_id
)
self.point = p1
self.polarisation = polarisation.lower()
self.resistance = resistance
self.waveform_id = waveform_id
self.start = start
self.stop = stop
def _do_rotate(self, grid: FDTDGrid):
"""Performs rotation."""
rot_pol_pts, self.kwargs["polarisation"] = rotate_polarisation(
self.kwargs["p1"], self.kwargs["polarisation"], self.axis, self.angle, grid
rot_pol_pts, self.polarisation = rotate_polarisation(
self.point, self.polarisation, self.axis, self.angle, grid
)
rot_pts = rotate_2point_object(rot_pol_pts, self.axis, self.angle, self.origin)
self.kwargs["p1"] = tuple(rot_pts[0, :])
self.point = tuple(rot_pts[0, :])
def build(self, grid: FDTDGrid):
try:
p1 = self.kwargs["p1"]
polarisation = self.kwargs["polarisation"].lower()
resistance = self.kwargs["resistance"]
waveform_id = self.kwargs["waveform_id"]
except KeyError:
logger.exception(self.params_str() + (" requires at least six parameters."))
raise
if self.do_rotate:
self._do_rotate(grid)
# Check polarity & position parameters
if polarisation not in ("x", "y", "z"):
logger.exception(self.params_str() + (" polarisation must be x, y, or z."))
raise ValueError
if "2D TMx" in config.get_model_config().mode and polarisation in [
"y",
"z",
]:
logger.exception(self.params_str() + (" polarisation must be x in 2D TMx mode."))
raise ValueError
elif "2D TMy" in config.get_model_config().mode and polarisation in [
"x",
"z",
]:
logger.exception(self.params_str() + (" polarisation must be y in 2D TMy mode."))
raise ValueError
elif "2D TMz" in config.get_model_config().mode and polarisation in [
"x",
"y",
]:
logger.exception(self.params_str() + (" polarisation must be z in 2D TMz mode."))
raise ValueError
uip = self._create_uip(grid)
xcoord, ycoord, zcoord = uip.check_src_rx_point(p1, self.params_str())
p2 = uip.round_to_grid_static_point(p1)
discretised_point = uip.discretise_point(self.point)
if not uip.check_src_rx_point(discretised_point, self.params_str()):
return
if resistance < 0:
logger.exception(
self.params_str() + (" requires a source resistance of zero or greater.")
# Check polarity & position parameters
if self.polarisation not in ("x", "y", "z"):
raise ValueError(f"{self.params_str()} polarisation must be x, y, or z.")
if "2D TMx" in config.get_model_config().mode and self.polarisation in ["y", "z"]:
raise ValueError(f"{self.params_str()} polarisation must be x in 2D TMx mode.")
elif "2D TMy" in config.get_model_config().mode and self.polarisation in ["x", "z"]:
raise ValueError(f"{self.params_str()} polarisation must be y in 2D TMy mode.")
elif "2D TMz" in config.get_model_config().mode and self.polarisation in ["x", "y"]:
raise ValueError(f"{self.params_str()} polarisation must be z in 2D TMz mode.")
if self.resistance < 0:
raise ValueError(
f"{self.params_str()} requires a source resistance of zero or greater."
)
raise ValueError
# Check if there is a waveformID in the waveforms list
if not any(x.ID == waveform_id for x in grid.waveforms):
logger.exception(
self.params_str() + (" there is no waveform with the identifier {waveform_id}.")
if not any(x.ID == self.waveform_id for x in grid.waveforms):
raise ValueError(
f"{self.params_str()} there is no waveform with the identifier {self.waveform_id}."
)
raise ValueError
v = VoltageSourceUser()
v.polarisation = polarisation
v.xcoord = xcoord
v.ycoord = ycoord
v.zcoord = zcoord
v.polarisation = self.polarisation
v.coord = discretised_point
v.ID = (
v.__class__.__name__
+ "("
@@ -390,39 +382,36 @@ class VoltageSource(RotatableMixin, GridUserObject):
+ str(v.zcoord)
+ ")"
)
v.resistance = resistance
v.waveform = grid.get_waveform_by_id(waveform_id)
v.resistance = self.resistance
v.waveform = grid.get_waveform_by_id(self.waveform_id)
try:
start = self.kwargs["start"]
stop = self.kwargs["stop"]
# Check source start & source remove time parameters
if start < 0:
logger.exception(
self.params_str()
+ (" delay of the initiation of the source should not be less than zero.")
)
raise ValueError
if stop < 0:
logger.exception(
self.params_str() + (" time to remove the source should not be less than zero.")
)
raise ValueError
if stop - start <= 0:
logger.exception(
self.params_str() + (" duration of the source should not be zero or less.")
)
raise ValueError
v.start = start
v.stop = min(stop, grid.timewindow)
startstop = f" start time {v.start:g} secs, finish time {v.stop:g} secs "
except KeyError:
if self.start is None or self.stop is None:
v.start = 0
v.stop = grid.timewindow
startstop = " "
else:
# Check source start & source remove time parameters
if self.start < 0:
raise ValueError(
f"{self.params_str()} delay of the initiation of the source should not be less"
" than zero."
)
if self.stop < 0:
raise ValueError(
f"{self.params_str()} time to remove the source should not be less than zero."
)
if self.stop - self.start <= 0:
raise ValueError(
f"{self.params_str()} duration of the source should not be zero or less."
)
v.start = self.start
v.stop = min(self.stop, grid.timewindow)
startstop = f" start time {v.start:g} secs, finish time {v.stop:g} secs "
v.calculate_waveform_values(grid.iterations, grid.dt)
p2 = uip.discretised_to_continuous(discretised_point)
logger.info(
f"{self.grid_name(grid)}Voltage source with polarity "
f"{v.polarisation} at {p2[0]:g}m, {p2[1]:g}m, {p2[2]:g}m, "