diff --git a/gprMax/user_inputs.py b/gprMax/user_inputs.py index c7314038..5004e0cb 100644 --- a/gprMax/user_inputs.py +++ b/gprMax/user_inputs.py @@ -18,8 +18,13 @@ from __future__ import annotations import logging +from typing import Generic import numpy as np +from typing_extensions import TypeVar + +from gprMax.grid.fdtd_grid import FDTDGrid +from gprMax.subgrids.grid import SubGridBaseGrid from .utilities.utilities import round_value @@ -34,11 +39,13 @@ logger = logging.getLogger(__name__) encapulsated here. """ +GT = TypeVar("GT", bound=FDTDGrid, default=FDTDGrid) -class UserInput: + +class UserInput(Generic[GT]): """Handles (x, y, z) points supplied by the user.""" - def __init__(self, grid): + def __init__(self, grid: GT): self.grid = grid def point_within_bounds(self, p, cmd_str, name): @@ -73,7 +80,7 @@ class UserInput: return p * self.grid.dl -class MainGridUserInput(UserInput): +class MainGridUserInput(UserInput[GT]): """Handles (x, y, z) points supplied by the user in the main grid.""" def __init__(self, grid): @@ -127,7 +134,7 @@ class MainGridUserInput(UserInput): return super().discretise_point(p) * self.grid.dl -class SubgridUserInput(MainGridUserInput): +class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]): """Handles (x, y, z) points supplied by the user in the subgrid. This class autotranslates points from main grid to subgrid equivalent (within IS). Useful if material traverse is not required.