Update VoltageSource UserObject for parallel build

这个提交包含在:
nmannall
2025-01-20 15:14:33 +00:00
父节点 f597246556
当前提交 619707a9d1
共有 4 个文件被更改,包括 80 次插入90 次删除

查看文件

@@ -448,7 +448,7 @@ class FDTDGrid:
logger.exception("Receiver(s) will be stepped to a position outside the domain.") logger.exception("Receiver(s) will be stepped to a position outside the domain.")
raise ValueError from e raise ValueError from e
def within_bounds(self, p: Tuple[int, int, int]) -> bool: def within_bounds(self, p: npt.NDArray[np.int32]) -> bool:
"""Check a point is within the grid. """Check a point is within the grid.
Args: Args:
@@ -496,7 +496,7 @@ class FDTDGrid:
p_r = (p[0] * self.dx, p[1] * self.dy, p[2] * self.dz) p_r = (p[0] * self.dx, p[1] * self.dy, p[2] * self.dz)
return p_r return p_r
def within_pml(self, p: Tuple[int, int, int]) -> bool: def within_pml(self, p: npt.NDArray[np.int32]) -> bool:
"""Check if the provided point is within a PML. """Check if the provided point is within a PML.
Args: Args:

查看文件

@@ -783,7 +783,7 @@ class MPIGrid(FDTDGrid):
f" {self.lower_extent}, Upper extent: {self.upper_extent}" f" {self.lower_extent}, Upper extent: {self.upper_extent}"
) )
def within_bounds(self, p: Tuple[int, int, int]) -> bool: def within_bounds(self, p: npt.NDArray[np.int32]) -> bool:
"""Check a point is within the grid. """Check a point is within the grid.
Args: Args:
@@ -803,11 +803,11 @@ class MPIGrid(FDTDGrid):
if p[2] < 0 or p[2] > self.gz: if p[2] < 0 or p[2] > self.gz:
raise ValueError("z") raise ValueError("z")
local_point = self.global_to_local_coordinate(np.array(p, dtype=np.int32)) local_point = self.global_to_local_coordinate(p)
return all(local_point >= self.negative_halo_offset) and all(local_point <= self.size) return all(local_point >= self.negative_halo_offset) and all(local_point <= self.size)
def within_pml(self, p: Tuple[int, int, int]) -> bool: def within_pml(self, p: npt.NDArray[np.int32]) -> bool:
"""Check if the provided point is within a PML. """Check if the provided point is within a PML.
Args: Args:
@@ -816,13 +816,12 @@ class MPIGrid(FDTDGrid):
Returns: Returns:
within_pml: True if the point is within a PML. within_pml: True if the point is within a PML.
""" """
local_point = self.global_to_local_coordinate(np.array(p)) local_point = self.global_to_local_coordinate(p)
p = (local_point[0], local_point[1], local_point[2])
# within_pml check will only be valid if the point is also # within_pml check will only be valid if the point is also
# within the local grid # within the local grid
return ( return (
super().within_pml(p) super().within_pml(local_point)
and all(local_point >= self.negative_halo_offset) and all(local_point >= self.negative_halo_offset)
and all(local_point <= self.size) and all(local_point <= self.size)
) )

查看文件

@@ -49,7 +49,9 @@ class UserInput(Generic[GridType]):
def __init__(self, grid: GridType): def __init__(self, grid: GridType):
self.grid = grid self.grid = grid
def point_within_bounds(self, p, cmd_str, name, ignore_error=False) -> bool: def point_within_bounds(
self, p: npt.NDArray[np.int32], cmd_str: str, name: str = "", ignore_error=False
) -> bool:
try: try:
return self.grid.within_bounds(p) return self.grid.within_bounds(p)
except ValueError as err: except ValueError as err:
@@ -99,15 +101,15 @@ class MainGridUserInput(UserInput[GridType]):
self.point_within_bounds(p, cmd_str, name) self.point_within_bounds(p, cmd_str, name)
return p return p
def check_src_rx_point(self, p, cmd_str, name=""): def check_src_rx_point(self, p: npt.NDArray[np.int32], cmd_str: str, name: str = "") -> bool:
p = self.check_point(p, cmd_str, name) within_grid = self.point_within_bounds(p, cmd_str, name)
if self.grid.within_pml(p): if self.grid.within_pml(p):
logger.warning( logger.warning(
f"'{cmd_str}' sources and receivers should not normally be positioned within the PML." f"'{cmd_str}' sources and receivers should not normally be positioned within the PML."
) )
return p return within_grid
def check_box_points(self, p1, p2, cmd_str): def check_box_points(self, p1, p2, cmd_str):
p1 = self.check_point(p1, cmd_str, name="lower") p1 = self.check_point(p1, cmd_str, name="lower")

查看文件

@@ -20,7 +20,7 @@ import inspect
import logging import logging
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@@ -311,75 +311,67 @@ class VoltageSource(RotatableMixin, GridUserObject):
def hash(self): def hash(self):
return "#voltage_source" return "#voltage_source"
def __init__(self, **kwargs): def __init__(
super().__init__(**kwargs) self,
p1: Tuple[float, float, float],
polarisation: str,
resistance: float,
waveform_id: str,
start: Optional[float] = None,
stop: Optional[float] = None,
):
super().__init__(
p1=p1, polarisation=polarisation, resistance=resistance, waveform_id=waveform_id
)
self.point = p1
self.polarisation = polarisation.lower()
self.resistance = resistance
self.waveform_id = waveform_id
self.start = start
self.stop = stop
def _do_rotate(self, grid: FDTDGrid): def _do_rotate(self, grid: FDTDGrid):
"""Performs rotation.""" """Performs rotation."""
rot_pol_pts, self.kwargs["polarisation"] = rotate_polarisation( rot_pol_pts, self.polarisation = rotate_polarisation(
self.kwargs["p1"], self.kwargs["polarisation"], self.axis, self.angle, grid self.point, self.polarisation, self.axis, self.angle, grid
) )
rot_pts = rotate_2point_object(rot_pol_pts, self.axis, self.angle, self.origin) rot_pts = rotate_2point_object(rot_pol_pts, self.axis, self.angle, self.origin)
self.kwargs["p1"] = tuple(rot_pts[0, :]) self.point = tuple(rot_pts[0, :])
def build(self, grid: FDTDGrid): def build(self, grid: FDTDGrid):
try:
p1 = self.kwargs["p1"]
polarisation = self.kwargs["polarisation"].lower()
resistance = self.kwargs["resistance"]
waveform_id = self.kwargs["waveform_id"]
except KeyError:
logger.exception(self.params_str() + (" requires at least six parameters."))
raise
if self.do_rotate: if self.do_rotate:
self._do_rotate(grid) self._do_rotate(grid)
# Check polarity & position parameters
if polarisation not in ("x", "y", "z"):
logger.exception(self.params_str() + (" polarisation must be x, y, or z."))
raise ValueError
if "2D TMx" in config.get_model_config().mode and polarisation in [
"y",
"z",
]:
logger.exception(self.params_str() + (" polarisation must be x in 2D TMx mode."))
raise ValueError
elif "2D TMy" in config.get_model_config().mode and polarisation in [
"x",
"z",
]:
logger.exception(self.params_str() + (" polarisation must be y in 2D TMy mode."))
raise ValueError
elif "2D TMz" in config.get_model_config().mode and polarisation in [
"x",
"y",
]:
logger.exception(self.params_str() + (" polarisation must be z in 2D TMz mode."))
raise ValueError
uip = self._create_uip(grid) uip = self._create_uip(grid)
xcoord, ycoord, zcoord = uip.check_src_rx_point(p1, self.params_str()) discretised_point = uip.discretise_point(self.point)
p2 = uip.round_to_grid_static_point(p1) if not uip.check_src_rx_point(discretised_point, self.params_str()):
return
if resistance < 0: # Check polarity & position parameters
logger.exception( if self.polarisation not in ("x", "y", "z"):
self.params_str() + (" requires a source resistance of zero or greater.") raise ValueError(f"{self.params_str()} polarisation must be x, y, or z.")
if "2D TMx" in config.get_model_config().mode and self.polarisation in ["y", "z"]:
raise ValueError(f"{self.params_str()} polarisation must be x in 2D TMx mode.")
elif "2D TMy" in config.get_model_config().mode and self.polarisation in ["x", "z"]:
raise ValueError(f"{self.params_str()} polarisation must be y in 2D TMy mode.")
elif "2D TMz" in config.get_model_config().mode and self.polarisation in ["x", "y"]:
raise ValueError(f"{self.params_str()} polarisation must be z in 2D TMz mode.")
if self.resistance < 0:
raise ValueError(
f"{self.params_str()} requires a source resistance of zero or greater."
) )
raise ValueError
# Check if there is a waveformID in the waveforms list # Check if there is a waveformID in the waveforms list
if not any(x.ID == waveform_id for x in grid.waveforms): if not any(x.ID == self.waveform_id for x in grid.waveforms):
logger.exception( raise ValueError(
self.params_str() + (" there is no waveform with the identifier {waveform_id}.") f"{self.params_str()} there is no waveform with the identifier {self.waveform_id}."
) )
raise ValueError
v = VoltageSourceUser() v = VoltageSourceUser()
v.polarisation = polarisation v.polarisation = self.polarisation
v.xcoord = xcoord v.coord = discretised_point
v.ycoord = ycoord
v.zcoord = zcoord
v.ID = ( v.ID = (
v.__class__.__name__ v.__class__.__name__
+ "(" + "("
@@ -390,39 +382,36 @@ class VoltageSource(RotatableMixin, GridUserObject):
+ str(v.zcoord) + str(v.zcoord)
+ ")" + ")"
) )
v.resistance = resistance v.resistance = self.resistance
v.waveform = grid.get_waveform_by_id(waveform_id) v.waveform = grid.get_waveform_by_id(self.waveform_id)
try: if self.start is None or self.stop is None:
start = self.kwargs["start"]
stop = self.kwargs["stop"]
# Check source start & source remove time parameters
if start < 0:
logger.exception(
self.params_str()
+ (" delay of the initiation of the source should not be less than zero.")
)
raise ValueError
if stop < 0:
logger.exception(
self.params_str() + (" time to remove the source should not be less than zero.")
)
raise ValueError
if stop - start <= 0:
logger.exception(
self.params_str() + (" duration of the source should not be zero or less.")
)
raise ValueError
v.start = start
v.stop = min(stop, grid.timewindow)
startstop = f" start time {v.start:g} secs, finish time {v.stop:g} secs "
except KeyError:
v.start = 0 v.start = 0
v.stop = grid.timewindow v.stop = grid.timewindow
startstop = " " startstop = " "
else:
# Check source start & source remove time parameters
if self.start < 0:
raise ValueError(
f"{self.params_str()} delay of the initiation of the source should not be less"
" than zero."
)
if self.stop < 0:
raise ValueError(
f"{self.params_str()} time to remove the source should not be less than zero."
)
if self.stop - self.start <= 0:
raise ValueError(
f"{self.params_str()} duration of the source should not be zero or less."
)
v.start = self.start
v.stop = min(self.stop, grid.timewindow)
startstop = f" start time {v.start:g} secs, finish time {v.stop:g} secs "
v.calculate_waveform_values(grid.iterations, grid.dt) v.calculate_waveform_values(grid.iterations, grid.dt)
p2 = uip.discretised_to_continuous(discretised_point)
logger.info( logger.info(
f"{self.grid_name(grid)}Voltage source with polarity " f"{self.grid_name(grid)}Voltage source with polarity "
f"{v.polarisation} at {p2[0]:g}m, {p2[1]:g}m, {p2[2]:g}m, " f"{v.polarisation} at {p2[0]:g}m, {p2[1]:g}m, {p2[2]:g}m, "