From 619707a9d18745904c3a45b14893ccb1644eb0ef Mon Sep 17 00:00:00 2001 From: nmannall Date: Mon, 20 Jan 2025 15:14:33 +0000 Subject: [PATCH] Update VoltageSource UserObject for parallel build --- gprMax/grid/fdtd_grid.py | 4 +- gprMax/grid/mpi_grid.py | 11 +- gprMax/user_inputs.py | 10 +- gprMax/user_objects/cmds_multiuse.py | 145 +++++++++++++-------------- 4 files changed, 80 insertions(+), 90 deletions(-) diff --git a/gprMax/grid/fdtd_grid.py b/gprMax/grid/fdtd_grid.py index 331c3ebf..92cef7b7 100644 --- a/gprMax/grid/fdtd_grid.py +++ b/gprMax/grid/fdtd_grid.py @@ -448,7 +448,7 @@ class FDTDGrid: logger.exception("Receiver(s) will be stepped to a position outside the domain.") raise ValueError from e - def within_bounds(self, p: Tuple[int, int, int]) -> bool: + def within_bounds(self, p: npt.NDArray[np.int32]) -> bool: """Check a point is within the grid. Args: @@ -496,7 +496,7 @@ class FDTDGrid: p_r = (p[0] * self.dx, p[1] * self.dy, p[2] * self.dz) return p_r - def within_pml(self, p: Tuple[int, int, int]) -> bool: + def within_pml(self, p: npt.NDArray[np.int32]) -> bool: """Check if the provided point is within a PML. Args: diff --git a/gprMax/grid/mpi_grid.py b/gprMax/grid/mpi_grid.py index c3e3d1a8..c22efecb 100644 --- a/gprMax/grid/mpi_grid.py +++ b/gprMax/grid/mpi_grid.py @@ -783,7 +783,7 @@ class MPIGrid(FDTDGrid): f" {self.lower_extent}, Upper extent: {self.upper_extent}" ) - def within_bounds(self, p: Tuple[int, int, int]) -> bool: + def within_bounds(self, p: npt.NDArray[np.int32]) -> bool: """Check a point is within the grid. Args: @@ -803,11 +803,11 @@ class MPIGrid(FDTDGrid): if p[2] < 0 or p[2] > self.gz: raise ValueError("z") - local_point = self.global_to_local_coordinate(np.array(p, dtype=np.int32)) + local_point = self.global_to_local_coordinate(p) return all(local_point >= self.negative_halo_offset) and all(local_point <= self.size) - def within_pml(self, p: Tuple[int, int, int]) -> bool: + def within_pml(self, p: npt.NDArray[np.int32]) -> bool: """Check if the provided point is within a PML. Args: @@ -816,13 +816,12 @@ class MPIGrid(FDTDGrid): Returns: within_pml: True if the point is within a PML. """ - local_point = self.global_to_local_coordinate(np.array(p)) - p = (local_point[0], local_point[1], local_point[2]) + local_point = self.global_to_local_coordinate(p) # within_pml check will only be valid if the point is also # within the local grid return ( - super().within_pml(p) + super().within_pml(local_point) and all(local_point >= self.negative_halo_offset) and all(local_point <= self.size) ) diff --git a/gprMax/user_inputs.py b/gprMax/user_inputs.py index bd53c6fc..e387117e 100644 --- a/gprMax/user_inputs.py +++ b/gprMax/user_inputs.py @@ -49,7 +49,9 @@ class UserInput(Generic[GridType]): def __init__(self, grid: GridType): self.grid = grid - def point_within_bounds(self, p, cmd_str, name, ignore_error=False) -> bool: + def point_within_bounds( + self, p: npt.NDArray[np.int32], cmd_str: str, name: str = "", ignore_error=False + ) -> bool: try: return self.grid.within_bounds(p) except ValueError as err: @@ -99,15 +101,15 @@ class MainGridUserInput(UserInput[GridType]): self.point_within_bounds(p, cmd_str, name) return p - def check_src_rx_point(self, p, cmd_str, name=""): - p = self.check_point(p, cmd_str, name) + def check_src_rx_point(self, p: npt.NDArray[np.int32], cmd_str: str, name: str = "") -> bool: + within_grid = self.point_within_bounds(p, cmd_str, name) if self.grid.within_pml(p): logger.warning( f"'{cmd_str}' sources and receivers should not normally be positioned within the PML." ) - return p + return within_grid def check_box_points(self, p1, p2, cmd_str): p1 = self.check_point(p1, cmd_str, name="lower") diff --git a/gprMax/user_objects/cmds_multiuse.py b/gprMax/user_objects/cmds_multiuse.py index 48940433..bc1e9570 100644 --- a/gprMax/user_objects/cmds_multiuse.py +++ b/gprMax/user_objects/cmds_multiuse.py @@ -20,7 +20,7 @@ import inspect import logging from os import PathLike from pathlib import Path -from typing import Optional, Union +from typing import Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -311,75 +311,67 @@ class VoltageSource(RotatableMixin, GridUserObject): def hash(self): return "#voltage_source" - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + p1: Tuple[float, float, float], + polarisation: str, + resistance: float, + waveform_id: str, + start: Optional[float] = None, + stop: Optional[float] = None, + ): + super().__init__( + p1=p1, polarisation=polarisation, resistance=resistance, waveform_id=waveform_id + ) + + self.point = p1 + self.polarisation = polarisation.lower() + self.resistance = resistance + self.waveform_id = waveform_id + self.start = start + self.stop = stop def _do_rotate(self, grid: FDTDGrid): """Performs rotation.""" - rot_pol_pts, self.kwargs["polarisation"] = rotate_polarisation( - self.kwargs["p1"], self.kwargs["polarisation"], self.axis, self.angle, grid + rot_pol_pts, self.polarisation = rotate_polarisation( + self.point, self.polarisation, self.axis, self.angle, grid ) rot_pts = rotate_2point_object(rot_pol_pts, self.axis, self.angle, self.origin) - self.kwargs["p1"] = tuple(rot_pts[0, :]) + self.point = tuple(rot_pts[0, :]) def build(self, grid: FDTDGrid): - try: - p1 = self.kwargs["p1"] - polarisation = self.kwargs["polarisation"].lower() - resistance = self.kwargs["resistance"] - waveform_id = self.kwargs["waveform_id"] - except KeyError: - logger.exception(self.params_str() + (" requires at least six parameters.")) - raise - if self.do_rotate: self._do_rotate(grid) - # Check polarity & position parameters - if polarisation not in ("x", "y", "z"): - logger.exception(self.params_str() + (" polarisation must be x, y, or z.")) - raise ValueError - if "2D TMx" in config.get_model_config().mode and polarisation in [ - "y", - "z", - ]: - logger.exception(self.params_str() + (" polarisation must be x in 2D TMx mode.")) - raise ValueError - elif "2D TMy" in config.get_model_config().mode and polarisation in [ - "x", - "z", - ]: - logger.exception(self.params_str() + (" polarisation must be y in 2D TMy mode.")) - raise ValueError - elif "2D TMz" in config.get_model_config().mode and polarisation in [ - "x", - "y", - ]: - logger.exception(self.params_str() + (" polarisation must be z in 2D TMz mode.")) - raise ValueError - uip = self._create_uip(grid) - xcoord, ycoord, zcoord = uip.check_src_rx_point(p1, self.params_str()) - p2 = uip.round_to_grid_static_point(p1) + discretised_point = uip.discretise_point(self.point) + if not uip.check_src_rx_point(discretised_point, self.params_str()): + return - if resistance < 0: - logger.exception( - self.params_str() + (" requires a source resistance of zero or greater.") + # Check polarity & position parameters + if self.polarisation not in ("x", "y", "z"): + raise ValueError(f"{self.params_str()} polarisation must be x, y, or z.") + if "2D TMx" in config.get_model_config().mode and self.polarisation in ["y", "z"]: + raise ValueError(f"{self.params_str()} polarisation must be x in 2D TMx mode.") + elif "2D TMy" in config.get_model_config().mode and self.polarisation in ["x", "z"]: + raise ValueError(f"{self.params_str()} polarisation must be y in 2D TMy mode.") + elif "2D TMz" in config.get_model_config().mode and self.polarisation in ["x", "y"]: + raise ValueError(f"{self.params_str()} polarisation must be z in 2D TMz mode.") + + if self.resistance < 0: + raise ValueError( + f"{self.params_str()} requires a source resistance of zero or greater." ) - raise ValueError # Check if there is a waveformID in the waveforms list - if not any(x.ID == waveform_id for x in grid.waveforms): - logger.exception( - self.params_str() + (" there is no waveform with the identifier {waveform_id}.") + if not any(x.ID == self.waveform_id for x in grid.waveforms): + raise ValueError( + f"{self.params_str()} there is no waveform with the identifier {self.waveform_id}." ) - raise ValueError v = VoltageSourceUser() - v.polarisation = polarisation - v.xcoord = xcoord - v.ycoord = ycoord - v.zcoord = zcoord + v.polarisation = self.polarisation + v.coord = discretised_point v.ID = ( v.__class__.__name__ + "(" @@ -390,39 +382,36 @@ class VoltageSource(RotatableMixin, GridUserObject): + str(v.zcoord) + ")" ) - v.resistance = resistance - v.waveform = grid.get_waveform_by_id(waveform_id) + v.resistance = self.resistance + v.waveform = grid.get_waveform_by_id(self.waveform_id) - try: - start = self.kwargs["start"] - stop = self.kwargs["stop"] - # Check source start & source remove time parameters - if start < 0: - logger.exception( - self.params_str() - + (" delay of the initiation of the source should not be less than zero.") - ) - raise ValueError - if stop < 0: - logger.exception( - self.params_str() + (" time to remove the source should not be less than zero.") - ) - raise ValueError - if stop - start <= 0: - logger.exception( - self.params_str() + (" duration of the source should not be zero or less.") - ) - raise ValueError - v.start = start - v.stop = min(stop, grid.timewindow) - startstop = f" start time {v.start:g} secs, finish time {v.stop:g} secs " - except KeyError: + if self.start is None or self.stop is None: v.start = 0 v.stop = grid.timewindow startstop = " " + else: + # Check source start & source remove time parameters + if self.start < 0: + raise ValueError( + f"{self.params_str()} delay of the initiation of the source should not be less" + " than zero." + ) + if self.stop < 0: + raise ValueError( + f"{self.params_str()} time to remove the source should not be less than zero." + ) + if self.stop - self.start <= 0: + raise ValueError( + f"{self.params_str()} duration of the source should not be zero or less." + ) + v.start = self.start + v.stop = min(self.stop, grid.timewindow) + startstop = f" start time {v.start:g} secs, finish time {v.stop:g} secs " v.calculate_waveform_values(grid.iterations, grid.dt) + p2 = uip.discretised_to_continuous(discretised_point) + logger.info( f"{self.grid_name(grid)}Voltage source with polarity " f"{v.polarisation} at {p2[0]:g}m, {p2[1]:g}m, {p2[2]:g}m, "