你已经派生过 gprMax
镜像自地址
https://gitee.com/sunhf/gprMax.git
已同步 2025-08-08 07:24:19 +08:00
Make abstract Updates class generic and add grid attribute
这个提交包含在:
@@ -17,7 +17,6 @@
|
|||||||
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
|
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from typing import Generic
|
|
||||||
|
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
@@ -25,14 +24,11 @@ from gprMax import config
|
|||||||
from gprMax.cython.fields_updates_normal import update_electric as update_electric_cpu
|
from gprMax.cython.fields_updates_normal import update_electric as update_electric_cpu
|
||||||
from gprMax.cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
|
from gprMax.cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
|
||||||
from gprMax.fields_outputs import store_outputs as store_outputs_cpu
|
from gprMax.fields_outputs import store_outputs as store_outputs_cpu
|
||||||
from gprMax.grid.fdtd_grid import FDTDGrid
|
from gprMax.updates.updates import GridType, Updates
|
||||||
from gprMax.updates.updates import Updates
|
|
||||||
from gprMax.utilities.utilities import timer
|
from gprMax.utilities.utilities import timer
|
||||||
|
|
||||||
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
|
|
||||||
|
|
||||||
|
class CPUUpdates(Updates[GridType]):
|
||||||
class CPUUpdates(Generic[GridType], Updates):
|
|
||||||
"""Defines update functions for CPU-based solver."""
|
"""Defines update functions for CPU-based solver."""
|
||||||
|
|
||||||
def __init__(self, G: GridType):
|
def __init__(self, G: GridType):
|
||||||
@@ -40,8 +36,7 @@ class CPUUpdates(Generic[GridType], Updates):
|
|||||||
Args:
|
Args:
|
||||||
G: FDTDGrid class describing a grid in a model.
|
G: FDTDGrid class describing a grid in a model.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(G)
|
||||||
self.grid = G
|
|
||||||
|
|
||||||
def store_outputs(self, iteration):
|
def store_outputs(self, iteration):
|
||||||
"""Stores field component values for every receiver and transmission line."""
|
"""Stores field component values for every receiver and transmission line."""
|
||||||
|
@@ -40,7 +40,7 @@ from gprMax.utilities.utilities import round32
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CUDAUpdates(Updates):
|
class CUDAUpdates(Updates[CUDAGrid]):
|
||||||
"""Defines update functions for GPU-based (CUDA) solver."""
|
"""Defines update functions for GPU-based (CUDA) solver."""
|
||||||
|
|
||||||
def __init__(self, G: CUDAGrid):
|
def __init__(self, G: CUDAGrid):
|
||||||
@@ -48,8 +48,7 @@ class CUDAUpdates(Updates):
|
|||||||
Args:
|
Args:
|
||||||
G: CUDAGrid class describing a grid in a model.
|
G: CUDAGrid class describing a grid in a model.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(G)
|
||||||
self.grid = G
|
|
||||||
|
|
||||||
# Import PyCUDA modules
|
# Import PyCUDA modules
|
||||||
self.drv = import_module("pycuda.driver")
|
self.drv = import_module("pycuda.driver")
|
||||||
|
@@ -38,7 +38,7 @@ from gprMax.updates.updates import Updates
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenCLUpdates(Updates):
|
class OpenCLUpdates(Updates[OpenCLGrid]):
|
||||||
"""Defines update functions for OpenCL-based solver."""
|
"""Defines update functions for OpenCL-based solver."""
|
||||||
|
|
||||||
def __init__(self, G: OpenCLGrid):
|
def __init__(self, G: OpenCLGrid):
|
||||||
@@ -46,8 +46,7 @@ class OpenCLUpdates(Updates):
|
|||||||
Args:
|
Args:
|
||||||
G: OpenCLGrid class describing a grid in a model.
|
G: OpenCLGrid class describing a grid in a model.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(G)
|
||||||
self.grid = G
|
|
||||||
|
|
||||||
# Import pyopencl module
|
# Import pyopencl module
|
||||||
self.cl = import_module("pyopencl")
|
self.cl = import_module("pyopencl")
|
||||||
|
@@ -18,11 +18,26 @@
|
|||||||
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Generic
|
||||||
|
|
||||||
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
|
from gprMax.grid.fdtd_grid import FDTDGrid
|
||||||
|
|
||||||
|
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
|
||||||
|
|
||||||
|
|
||||||
class Updates(ABC):
|
class Updates(Generic[GridType], ABC):
|
||||||
"""Defines update functions for a solver."""
|
"""Defines update functions for a solver."""
|
||||||
|
|
||||||
|
def __init__(self, G: GridType):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
G: FDTDGrid class describing a grid in a model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.grid = G
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def store_outputs(self, iteration: int) -> None:
|
def store_outputs(self, iteration: int) -> None:
|
||||||
"""Stores field component values for every receiver and transmission line."""
|
"""Stores field component values for every receiver and transmission line."""
|
||||||
|
在新工单中引用
屏蔽一个用户