Refactor UserInputs to make inheritence simpler

这个提交包含在:
nmannall
2025-01-31 14:48:48 +00:00
父节点 e4370abefd
当前提交 1d0cd4b980

查看文件

@@ -73,20 +73,71 @@ class UserInput(Generic[GridType]):
def grid_upper_bound(self) -> list[int]:
return [self.grid.nx, self.grid.ny, self.grid.nz]
def discretise_point(self, p: Tuple[float, float, float]) -> npt.NDArray[np.int32]:
"""Gets the index of a continuous point with the grid."""
rv = np.vectorize(round_int, otypes=[np.int32])
return rv(p / self.grid.dl)
def discretise_static_point(self, point: Tuple[float, float, float]) -> npt.NDArray[np.int32]:
"""Get the nearest grid index to a continuous static point.
def round_to_grid(self, p: Tuple[float, float, float]) -> npt.NDArray[np.float64]:
"""Gets the nearest continuous point on the grid from a continuous point
in space.
For a static point, the point of the origin of the grid is
ignored. I.e. it is assumed to be at (0, 0, 0). There are no
checks of the validity of the point such as bound checking.
Args:
point: x, y, z coordinates of the point in space.
Returns:
discretised_point: x, y, z indices of the point on the grid.
"""
return self.discretise_point(p) * self.grid.dl
rv = np.vectorize(round_int, otypes=[np.int32])
return rv(point / self.grid.dl)
def discretised_to_continuous(self, p: npt.NDArray[np.int32]) -> npt.NDArray[np.float64]:
"""Returns a point given as indices to a continuous point in the real space."""
return p * self.grid.dl
def round_to_grid_static_point(
self, point: Tuple[float, float, float]
) -> npt.NDArray[np.float64]:
"""Round a continuous static point to the nearest point on the grid.
For a static point, the point of the origin of the grid is
ignored. I.e. it is assumed to be at (0, 0, 0). There are no
checks of the validity of the point such as bound checking.
Args:
point: x, y, z coordinates of the point in space.
Returns:
rounded_point: x, y, z coordinates of the nearest continuous
point on the grid.
"""
return self.discretise_static_point(point) * self.grid.dl
def discretise_point(self, point: Tuple[float, float, float]) -> npt.NDArray[np.int32]:
"""Get the nearest grid index to a continuous static point.
This function translates user points to the correct index for
building objects. Points will be mapped from the user coordinate
space to the local coordinate space of the grid. There are no
checks of the validity of the point such as bound checking.
Args:
point: x, y, z coordinates of the point in space.
Returns:
discretised_point: x, y, z indices of the point on the grid.
"""
return self.discretise_static_point(point)
def round_to_grid(self, point: Tuple[float, float, float]) -> npt.NDArray[np.float64]:
"""Round a continuous static point to the nearest point on the grid.
The point will be mapped from the user coordinate space to the
local coordinate space of the grid. There are no checks of the
validity of the point such as bound checking.
Args:
point: x, y, z coordinates of the point in space.
Returns:
rounded_point: x, y, z coordinates of the nearest continuous
point on the grid.
"""
return self.discretise_point(point) * self.grid.dl
class MainGridUserInput(UserInput[GridType]):
@@ -130,18 +181,6 @@ class MainGridUserInput(UserInput[GridType]):
return p1, p2, p3
def discretise_static_point(self, p):
"""Gets the index of a continuous point regardless of the point of
origin of the grid.
"""
return super().discretise_point(p)
def round_to_grid_static_point(self, p):
"""Gets the index of a continuous point regardless of the point of
origin of the grid.
"""
return super().discretise_point(p) * self.grid.dl
class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
"""Handles (x, y, z) points supplied by the user in the subgrid.
@@ -168,20 +207,26 @@ class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
return np.array([p1, p2, p3])
def discretise_point(self, p) -> npt.NDArray[np.int32]:
"""Discretises a point. Does not provide any checks. The user enters
coordinates relative to self.inner_bound. This function translate
the user point to the correct index for building objects.
def discretise_point(self, point: Tuple[float, float, float]) -> npt.NDArray[np.int32]:
"""Get the nearest grid index to a continuous static point.
This function translates user points to the correct index for
building objects. The user enters coordinates relative to
self.inner_bound which are mapped to the local coordinate space
of the grid. There are no checks of the validity of the point
such as bound checking.
Args:
point: x, y, z coordinates of the point in space relative to
self.inner_bound.
Returns:
discretised_point: x, y, z indices of the point on the grid.
"""
p = super().discretise_point(p)
p_t = self.translate_to_gap(p)
return p_t
def round_to_grid(self, p):
p_t = self.discretise_point(p)
p_m = p_t * self.grid.dl
return p_m
discretised_point = super().discretise_point(point)
discretised_point = self.translate_to_gap(discretised_point)
return discretised_point
def check_point(self, p, cmd_str, name=""):
p_t = super().check_point(p, cmd_str, name)
@@ -193,13 +238,3 @@ class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
f"'{cmd_str}' this object traverses the Outer Surface. This is an advanced feature."
)
return p_t
def discretise_static_point(self, p):
"""Gets the index of a continuous point regardless of the point of
origin of the grid."""
return super().discretise_point(p)
def round_to_grid_static_point(self, p):
"""Gets the index of a continuous point regardless of the point of
origin of the grid."""
return super().discretise_point(p) * self.grid.dl