你已经派生过 gprMax
镜像自地址
https://gitee.com/sunhf/gprMax.git
已同步 2025-08-08 07:24:19 +08:00
Make abstract Updates class generic and add grid attribute
这个提交包含在:
@@ -17,7 +17,6 @@
|
||||
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
@@ -25,14 +24,11 @@ from gprMax import config
|
||||
from gprMax.cython.fields_updates_normal import update_electric as update_electric_cpu
|
||||
from gprMax.cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
|
||||
from gprMax.fields_outputs import store_outputs as store_outputs_cpu
|
||||
from gprMax.grid.fdtd_grid import FDTDGrid
|
||||
from gprMax.updates.updates import Updates
|
||||
from gprMax.updates.updates import GridType, Updates
|
||||
from gprMax.utilities.utilities import timer
|
||||
|
||||
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
|
||||
|
||||
|
||||
class CPUUpdates(Generic[GridType], Updates):
|
||||
class CPUUpdates(Updates[GridType]):
|
||||
"""Defines update functions for CPU-based solver."""
|
||||
|
||||
def __init__(self, G: GridType):
|
||||
@@ -40,8 +36,7 @@ class CPUUpdates(Generic[GridType], Updates):
|
||||
Args:
|
||||
G: FDTDGrid class describing a grid in a model.
|
||||
"""
|
||||
|
||||
self.grid = G
|
||||
super().__init__(G)
|
||||
|
||||
def store_outputs(self, iteration):
|
||||
"""Stores field component values for every receiver and transmission line."""
|
||||
|
@@ -40,7 +40,7 @@ from gprMax.utilities.utilities import round32
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CUDAUpdates(Updates):
|
||||
class CUDAUpdates(Updates[CUDAGrid]):
|
||||
"""Defines update functions for GPU-based (CUDA) solver."""
|
||||
|
||||
def __init__(self, G: CUDAGrid):
|
||||
@@ -48,8 +48,7 @@ class CUDAUpdates(Updates):
|
||||
Args:
|
||||
G: CUDAGrid class describing a grid in a model.
|
||||
"""
|
||||
|
||||
self.grid = G
|
||||
super().__init__(G)
|
||||
|
||||
# Import PyCUDA modules
|
||||
self.drv = import_module("pycuda.driver")
|
||||
|
@@ -38,7 +38,7 @@ from gprMax.updates.updates import Updates
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenCLUpdates(Updates):
|
||||
class OpenCLUpdates(Updates[OpenCLGrid]):
|
||||
"""Defines update functions for OpenCL-based solver."""
|
||||
|
||||
def __init__(self, G: OpenCLGrid):
|
||||
@@ -46,8 +46,7 @@ class OpenCLUpdates(Updates):
|
||||
Args:
|
||||
G: OpenCLGrid class describing a grid in a model.
|
||||
"""
|
||||
|
||||
self.grid = G
|
||||
super().__init__(G)
|
||||
|
||||
# Import pyopencl module
|
||||
self.cl = import_module("pyopencl")
|
||||
|
@@ -18,11 +18,26 @@
|
||||
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from gprMax.grid.fdtd_grid import FDTDGrid
|
||||
|
||||
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
|
||||
|
||||
|
||||
class Updates(ABC):
|
||||
class Updates(Generic[GridType], ABC):
|
||||
"""Defines update functions for a solver."""
|
||||
|
||||
def __init__(self, G: GridType):
|
||||
"""
|
||||
Args:
|
||||
G: FDTDGrid class describing a grid in a model.
|
||||
"""
|
||||
|
||||
self.grid = G
|
||||
|
||||
@abstractmethod
|
||||
def store_outputs(self, iteration: int) -> None:
|
||||
"""Stores field component values for every receiver and transmission line."""
|
||||
|
在新工单中引用
屏蔽一个用户