Discretised points should be ndarrays not Tuples

这个提交包含在:
nmannall
2025-01-20 14:36:31 +00:00
父节点 b982e37195
当前提交 5b05b2af61
共有 4 个文件被更改,包括 14 次插入13 次删除

查看文件

@@ -170,7 +170,7 @@ class Model:
return grid return grid
def set_size(self, size: Tuple[int, int, int]): def set_size(self, size: npt.NDArray[np.int32]):
self.nx, self.ny, self.nz = size self.nx, self.ny, self.nz = size
def build(self): def build(self):

查看文件

@@ -1,7 +1,8 @@
import logging import logging
from typing import Optional, Tuple from typing import Optional
import numpy as np import numpy as np
import numpy.typing as npt
from mpi4py import MPI from mpi4py import MPI
from gprMax import config from gprMax import config
@@ -53,7 +54,7 @@ class MPIModel(Model):
def is_coordinator(self): def is_coordinator(self):
return self.rank == 0 return self.rank == 0
def set_size(self, size: Tuple[int, int, int]): def set_size(self, size: npt.NDArray[np.int32]):
super().set_size(size) super().set_size(size)
self.G.calculate_local_extents() self.G.calculate_local_extents()

查看文件

@@ -21,12 +21,13 @@ import logging
from typing import Generic, Tuple from typing import Generic, Tuple
import numpy as np import numpy as np
import numpy.typing as npt
from typing_extensions import TypeVar from typing_extensions import TypeVar
from gprMax.grid.fdtd_grid import FDTDGrid from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.subgrids.grid import SubGridBaseGrid from gprMax.subgrids.grid import SubGridBaseGrid
from .utilities.utilities import round_value from .utilities.utilities import round_int
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -70,18 +71,18 @@ class UserInput(Generic[GridType]):
def grid_upper_bound(self) -> list[int]: def grid_upper_bound(self) -> list[int]:
return [self.grid.nx, self.grid.ny, self.grid.nz] return [self.grid.nx, self.grid.ny, self.grid.nz]
def discretise_point(self, p: Tuple[float, float, float]) -> Tuple[int, int, int]: def discretise_point(self, p: Tuple[float, float, float]) -> npt.NDArray[np.int32]:
"""Gets the index of a continuous point with the grid.""" """Gets the index of a continuous point with the grid."""
rv = np.vectorize(round_value) rv = np.vectorize(round_int, otypes=[np.int32])
return rv(p / self.grid.dl) return rv(p / self.grid.dl)
def round_to_grid(self, p): def round_to_grid(self, p: Tuple[float, float, float]) -> npt.NDArray[np.float64]:
"""Gets the nearest continuous point on the grid from a continuous point """Gets the nearest continuous point on the grid from a continuous point
in space. in space.
""" """
return self.discretise_point(p) * self.grid.dl return self.discretise_point(p) * self.grid.dl
def descretised_to_continuous(self, p): def descretised_to_continuous(self, p: npt.NDArray[np.int32]) -> npt.NDArray[np.float64]:
"""Returns a point given as indices to a continuous point in the real space.""" """Returns a point given as indices to a continuous point in the real space."""
return p * self.grid.dl return p * self.grid.dl
@@ -156,7 +157,7 @@ class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
self.outer_bound = np.subtract([grid.nx, grid.ny, grid.nz], self.inner_bound) self.outer_bound = np.subtract([grid.nx, grid.ny, grid.nz], self.inner_bound)
def translate_to_gap(self, p): def translate_to_gap(self, p) -> npt.NDArray[np.int32]:
"""Translates the user input point to the real point in the subgrid.""" """Translates the user input point to the real point in the subgrid."""
p1 = (p[0] - self.grid.i0 * self.grid.ratio) + self.grid.n_boundary_cells_x p1 = (p[0] - self.grid.i0 * self.grid.ratio) + self.grid.n_boundary_cells_x
@@ -165,7 +166,7 @@ class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
return np.array([p1, p2, p3]) return np.array([p1, p2, p3])
def discretise_point(self, p): def discretise_point(self, p) -> npt.NDArray[np.int32]:
"""Discretises a point. Does not provide any checks. The user enters """Discretises a point. Does not provide any checks. The user enters
coordinates relative to self.inner_bound. This function translate coordinates relative to self.inner_bound. This function translate
the user point to the correct index for building objects. the user point to the correct index for building objects.

查看文件

@@ -125,7 +125,6 @@ class Domain(ModelUserObject):
discretised_domain_size = uip.discretise_point(self.domain_size) discretised_domain_size = uip.discretise_point(self.domain_size)
# TODO: Fix type hinting
model.set_size(discretised_domain_size) model.set_size(discretised_domain_size)
if model.nx == 0 or model.ny == 0 or model.nz == 0: if model.nx == 0 or model.ny == 0 or model.nz == 0:
@@ -520,7 +519,7 @@ class SrcSteps(ModelUserObject):
def build(self, model: Model): def build(self, model: Model):
uip = self._create_uip(model.G) uip = self._create_uip(model.G)
model.srcsteps = np.array(uip.discretise_point(self.step_size), dtype=np.int32) model.srcsteps = uip.discretise_point(self.step_size)
logger.info( logger.info(
f"Simple sources will step {model.srcsteps[0] * model.dx:g}m, " f"Simple sources will step {model.srcsteps[0] * model.dx:g}m, "
@@ -557,7 +556,7 @@ class RxSteps(ModelUserObject):
def build(self, model: Model): def build(self, model: Model):
uip = self._create_uip(model.G) uip = self._create_uip(model.G)
model.rxsteps = np.array(uip.discretise_point(self.step_size), dtype=np.int32) model.rxsteps = uip.discretise_point(self.step_size)
logger.info( logger.info(
f"All receivers will step {model.rxsteps[0] * model.dx:g}m, " f"All receivers will step {model.rxsteps[0] * model.dx:g}m, "