Make UserInput generic

这个提交包含在:
nmannall
2024-05-15 17:29:28 +01:00
父节点 a31be536d6
当前提交 644bd53a4a

查看文件

@@ -18,8 +18,13 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Generic
import numpy as np import numpy as np
from typing_extensions import TypeVar
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.subgrids.grid import SubGridBaseGrid
from .utilities.utilities import round_value from .utilities.utilities import round_value
@@ -34,11 +39,13 @@ logger = logging.getLogger(__name__)
encapulsated here. encapulsated here.
""" """
GT = TypeVar("GT", bound=FDTDGrid, default=FDTDGrid)
class UserInput:
class UserInput(Generic[GT]):
"""Handles (x, y, z) points supplied by the user.""" """Handles (x, y, z) points supplied by the user."""
def __init__(self, grid): def __init__(self, grid: GT):
self.grid = grid self.grid = grid
def point_within_bounds(self, p, cmd_str, name): def point_within_bounds(self, p, cmd_str, name):
@@ -73,7 +80,7 @@ class UserInput:
return p * self.grid.dl return p * self.grid.dl
class MainGridUserInput(UserInput): class MainGridUserInput(UserInput[GT]):
"""Handles (x, y, z) points supplied by the user in the main grid.""" """Handles (x, y, z) points supplied by the user in the main grid."""
def __init__(self, grid): def __init__(self, grid):
@@ -127,7 +134,7 @@ class MainGridUserInput(UserInput):
return super().discretise_point(p) * self.grid.dl return super().discretise_point(p) * self.grid.dl
class SubgridUserInput(MainGridUserInput): class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]):
"""Handles (x, y, z) points supplied by the user in the subgrid. """Handles (x, y, z) points supplied by the user in the subgrid.
This class autotranslates points from main grid to subgrid equivalent This class autotranslates points from main grid to subgrid equivalent
(within IS). Useful if material traverse is not required. (within IS). Useful if material traverse is not required.