From 644bd53a4a737b007861b7c0a5367ca58dc5a756 Mon Sep 17 00:00:00 2001 From: nmannall Date: Wed, 15 May 2024 17:29:28 +0100 Subject: [PATCH] Make UserInput generic --- gprMax/user_inputs.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/gprMax/user_inputs.py b/gprMax/user_inputs.py index c7314038..5004e0cb 100644 --- a/gprMax/user_inputs.py +++ b/gprMax/user_inputs.py @@ -18,8 +18,13 @@ from __future__ import annotations import logging +from typing import Generic import numpy as np +from typing_extensions import TypeVar + +from gprMax.grid.fdtd_grid import FDTDGrid +from gprMax.subgrids.grid import SubGridBaseGrid from .utilities.utilities import round_value @@ -34,11 +39,13 @@ logger = logging.getLogger(__name__) encapulsated here. """ +GT = TypeVar("GT", bound=FDTDGrid, default=FDTDGrid) -class UserInput: + +class UserInput(Generic[GT]): """Handles (x, y, z) points supplied by the user.""" - def __init__(self, grid): + def __init__(self, grid: GT): self.grid = grid def point_within_bounds(self, p, cmd_str, name): @@ -73,7 +80,7 @@ class UserInput: return p * self.grid.dl -class MainGridUserInput(UserInput): +class MainGridUserInput(UserInput[GT]): """Handles (x, y, z) points supplied by the user in the main grid.""" def __init__(self, grid): @@ -127,7 +134,7 @@ class MainGridUserInput(UserInput): return super().discretise_point(p) * self.grid.dl -class SubgridUserInput(MainGridUserInput): +class SubgridUserInput(MainGridUserInput[SubGridBaseGrid]): """Handles (x, y, z) points supplied by the user in the subgrid. This class autotranslates points from main grid to subgrid equivalent (within IS). Useful if material traverse is not required.