From 5b05b2af61c19ba694d90c67d31fee8c1d26911d Mon Sep 17 00:00:00 2001 From: nmannall Date: Mon, 20 Jan 2025 14:36:31 +0000 Subject: [PATCH] Discretised points should be ndarrays not Tuples --- gprMax/model.py | 2 +- gprMax/mpi_model.py | 5 +++-- gprMax/user_inputs.py | 15 ++++++++------- gprMax/user_objects/cmds_singleuse.py | 5 ++--- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/gprMax/model.py b/gprMax/model.py index b3481e1f..ae5a8878 100644 --- a/gprMax/model.py +++ b/gprMax/model.py @@ -170,7 +170,7 @@ class Model: return grid - def set_size(self, size: Tuple[int, int, int]): + def set_size(self, size: npt.NDArray[np.int32]): self.nx, self.ny, self.nz = size def build(self): diff --git a/gprMax/mpi_model.py b/gprMax/mpi_model.py index 73da3c28..953f659d 100644 --- a/gprMax/mpi_model.py +++ b/gprMax/mpi_model.py @@ -1,7 +1,8 @@ import logging -from typing import Optional, Tuple +from typing import Optional import numpy as np +import numpy.typing as npt from mpi4py import MPI from gprMax import config @@ -53,7 +54,7 @@ class MPIModel(Model): def is_coordinator(self): return self.rank == 0 - def set_size(self, size: Tuple[int, int, int]): + def set_size(self, size: npt.NDArray[np.int32]): super().set_size(size) self.G.calculate_local_extents() diff --git a/gprMax/user_inputs.py b/gprMax/user_inputs.py index 566d50a8..48e3b91c 100644 --- a/gprMax/user_inputs.py +++ b/gprMax/user_inputs.py @@ -21,12 +21,13 @@ import logging from typing import Generic, Tuple import numpy as np +import numpy.typing as npt from typing_extensions import TypeVar from gprMax.grid.fdtd_grid import FDTDGrid from gprMax.subgrids.grid import SubGridBaseGrid -from .utilities.utilities import round_value +from .utilities.utilities import round_int logger = logging.getLogger(__name__) @@ -70,18 +71,18 @@ class UserInput(Generic[GridType]): def grid_upper_bound(self) -> list[int]: return [self.grid.nx, self.grid.ny, self.grid.nz] - def discretise_point(self, p: Tuple[float, float, float]) -> Tuple[int, int, int]: + def discretise_point(self, p: Tuple[float, float, float]) -> npt.NDArray[np.int32]: """Gets the index of a continuous point with the grid.""" - rv = np.vectorize(round_value) + rv = np.vectorize(round_int, otypes=[np.int32]) return rv(p / self.grid.dl) - def round_to_grid(self, p): + def round_to_grid(self, p: Tuple[float, float, float]) -> npt.NDArray[np.float64]: """Gets the nearest continuous point on the grid from a continuous point in space. """ return self.discretise_point(p) * self.grid.dl - def descretised_to_continuous(self, p): + def descretised_to_continuous(self, p: npt.NDArray[np.int32]) -> npt.NDArray[np.float64]: """Returns a point given as indices to a continuous point in the real space.""" return p * self.grid.dl @@ -156,7 +157,7 @@ class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]): self.outer_bound = np.subtract([grid.nx, grid.ny, grid.nz], self.inner_bound) - def translate_to_gap(self, p): + def translate_to_gap(self, p) -> npt.NDArray[np.int32]: """Translates the user input point to the real point in the subgrid.""" p1 = (p[0] - self.grid.i0 * self.grid.ratio) + self.grid.n_boundary_cells_x @@ -165,7 +166,7 @@ class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]): return np.array([p1, p2, p3]) - def discretise_point(self, p): + def discretise_point(self, p) -> npt.NDArray[np.int32]: """Discretises a point. Does not provide any checks. The user enters coordinates relative to self.inner_bound. This function translate the user point to the correct index for building objects. diff --git a/gprMax/user_objects/cmds_singleuse.py b/gprMax/user_objects/cmds_singleuse.py index 5932b8e3..4c499045 100644 --- a/gprMax/user_objects/cmds_singleuse.py +++ b/gprMax/user_objects/cmds_singleuse.py @@ -125,7 +125,6 @@ class Domain(ModelUserObject): discretised_domain_size = uip.discretise_point(self.domain_size) - # TODO: Fix type hinting model.set_size(discretised_domain_size) if model.nx == 0 or model.ny == 0 or model.nz == 0: @@ -520,7 +519,7 @@ class SrcSteps(ModelUserObject): def build(self, model: Model): uip = self._create_uip(model.G) - model.srcsteps = np.array(uip.discretise_point(self.step_size), dtype=np.int32) + model.srcsteps = uip.discretise_point(self.step_size) logger.info( f"Simple sources will step {model.srcsteps[0] * model.dx:g}m, " @@ -557,7 +556,7 @@ class RxSteps(ModelUserObject): def build(self, model: Model): uip = self._create_uip(model.G) - model.rxsteps = np.array(uip.discretise_point(self.step_size), dtype=np.int32) + model.rxsteps = uip.discretise_point(self.step_size) logger.info( f"All receivers will step {model.rxsteps[0] * model.dx:g}m, "