Make abstract Updates class generic and add grid attribute

这个提交包含在:
nmannall
2024-07-02 15:07:30 +01:00
父节点 0cde41cdca
当前提交 816833e254
共有 4 个文件被更改,包括 23 次插入15 次删除

查看文件

@@ -17,7 +17,6 @@
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
from importlib import import_module
from typing import Generic
from typing_extensions import TypeVar
@@ -25,14 +24,11 @@ from gprMax import config
from gprMax.cython.fields_updates_normal import update_electric as update_electric_cpu
from gprMax.cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
from gprMax.fields_outputs import store_outputs as store_outputs_cpu
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.updates.updates import Updates
from gprMax.updates.updates import GridType, Updates
from gprMax.utilities.utilities import timer
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
class CPUUpdates(Generic[GridType], Updates):
class CPUUpdates(Updates[GridType]):
"""Defines update functions for CPU-based solver."""
def __init__(self, G: GridType):
@@ -40,8 +36,7 @@ class CPUUpdates(Generic[GridType], Updates):
Args:
G: FDTDGrid class describing a grid in a model.
"""
self.grid = G
super().__init__(G)
def store_outputs(self, iteration):
"""Stores field component values for every receiver and transmission line."""

查看文件

@@ -40,7 +40,7 @@ from gprMax.utilities.utilities import round32
logger = logging.getLogger(__name__)
class CUDAUpdates(Updates):
class CUDAUpdates(Updates[CUDAGrid]):
"""Defines update functions for GPU-based (CUDA) solver."""
def __init__(self, G: CUDAGrid):
@@ -48,8 +48,7 @@ class CUDAUpdates(Updates):
Args:
G: CUDAGrid class describing a grid in a model.
"""
self.grid = G
super().__init__(G)
# Import PyCUDA modules
self.drv = import_module("pycuda.driver")

查看文件

@@ -38,7 +38,7 @@ from gprMax.updates.updates import Updates
logger = logging.getLogger(__name__)
class OpenCLUpdates(Updates):
class OpenCLUpdates(Updates[OpenCLGrid]):
"""Defines update functions for OpenCL-based solver."""
def __init__(self, G: OpenCLGrid):
@@ -46,8 +46,7 @@ class OpenCLUpdates(Updates):
Args:
G: OpenCLGrid class describing a grid in a model.
"""
self.grid = G
super().__init__(G)
# Import pyopencl module
self.cl = import_module("pyopencl")

查看文件

@@ -18,11 +18,26 @@
from abc import ABC, abstractmethod
from typing import Generic
from typing_extensions import TypeVar
from gprMax.grid.fdtd_grid import FDTDGrid
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
class Updates(ABC):
class Updates(Generic[GridType], ABC):
"""Defines update functions for a solver."""
def __init__(self, G: GridType):
"""
Args:
G: FDTDGrid class describing a grid in a model.
"""
self.grid = G
@abstractmethod
def store_outputs(self, iteration: int) -> None:
"""Stores field component values for every receiver and transmission line."""