Fix subgrid updater needing iteration count

SubgridUpdater now keep track of the number of iterations. Override
update methods to pass the current iteration and increment it in the
correct place. Pass the Model to create_updates as need access to the
main grid as well as subgrids.
这个提交包含在:
nmannall
2024-05-21 12:01:29 +01:00
父节点 6033d19e74
当前提交 bf386a15b4
共有 4 个文件被更改,包括 36 次插入16 次删除

查看文件

@@ -111,7 +111,7 @@ class Context:
model.build()
if not config.sim_config.geometry_only:
solver = create_solver(model.G)
solver = create_solver(model)
model.solve(solver)
del solver

查看文件

@@ -17,6 +17,7 @@
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
import gprMax.config as config
from gprMax.model import Model
from .grid.cuda_grid import CUDAGrid
from .grid.fdtd_grid import FDTDGrid
@@ -75,7 +76,7 @@ class Solver:
self.updates.cleanup()
def create_solver(grid: FDTDGrid) -> Solver:
def create_solver(model: Model) -> Solver:
"""Create configured solver object.
N.B. A large range of different functions exist to advance the time
@@ -87,14 +88,14 @@ def create_solver(grid: FDTDGrid) -> Solver:
substitution at runtime.
Args:
G: FDTDGrid class describing a grid in a model.
model: model containing the main grid and subgrids.
Returns:
solver: Solver object.
"""
grid = model.G
if config.sim_config.general["subgrid"]:
updates = create_subgrid_updates(grid)
updates = create_subgrid_updates(model)
if config.get_model_config().materials["maxpoles"] != 0:
# Set dispersive update functions for both SubgridUpdates and
# SubgridUpdaters subclasses

查看文件

@@ -18,6 +18,10 @@
import logging
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.model import Model
from gprMax.subgrids.grid import SubGridBaseGrid
from ..updates.cpu_updates import CPUUpdates
from .precursor_nodes import PrecursorNodes, PrecursorNodesFiltered
from .subgrid_hsg import SubGridHSG
@@ -25,24 +29,24 @@ from .subgrid_hsg import SubGridHSG
logger = logging.getLogger(__name__)
def create_updates(G):
def create_updates(model: Model):
"""Return the solver for the given subgrids."""
updaters = []
for sg in G.subgrids:
for sg in model.subgrids:
sg_type = type(sg)
if sg_type == SubGridHSG and sg.filter:
precursors = PrecursorNodesFiltered(G, sg)
precursors = PrecursorNodesFiltered(model.G, sg)
elif sg_type == SubGridHSG:
precursors = PrecursorNodes(G, sg)
precursors = PrecursorNodes(model.G, sg)
else:
logger.exception(f"{str(sg)} is not a subgrid type")
raise ValueError
sgu = SubgridUpdater(sg, precursors, G)
sgu = SubgridUpdater(sg, precursors, model.G)
updaters.append(sgu)
updates = SubgridUpdates(G, updaters)
updates = SubgridUpdates(model.G, updaters)
return updates
@@ -64,13 +68,13 @@ class SubgridUpdates(CPUUpdates):
sg_updater.hsg_2()
class SubgridUpdater(CPUUpdates):
class SubgridUpdater(CPUUpdates[SubGridBaseGrid]):
"""Handles updating the electric and magnetic fields of an HSG subgrid.
The IS, OS, subgrid region and the electric/magnetic sources are updated
using the precursor regions.
"""
def __init__(self, subgrid, precursors, G):
def __init__(self, subgrid: SubGridBaseGrid, precursors: PrecursorNodes, G: FDTDGrid):
"""
Args:
subgrid: SubGrid3d instance to be updated.
@@ -82,7 +86,17 @@ class SubgridUpdater(CPUUpdates):
super().__init__(subgrid)
self.precursors = precursors
self.G = G
self.source_iteration = 0
self.iteration = 0
def store_outputs(self):
return super().store_outputs(self.iteration)
def update_electric_sources(self):
super().update_electric_sources(self.iteration)
self.iteration += 1
def update_magnetic_sources(self):
return super().update_magnetic_sources(self.iteration)
def hsg_1(self):
"""First half of the subgrid update. Takes the time step up to the main

查看文件

@@ -17,6 +17,9 @@
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
from importlib import import_module
from typing import Generic
from typing_extensions import TypeVar
from gprMax import config
from gprMax.cython.fields_updates_normal import update_electric as update_electric_cpu
@@ -26,11 +29,13 @@ from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.updates.updates import Updates
from gprMax.utilities.utilities import timer
GridType = TypeVar("GridType", bound=FDTDGrid, default=FDTDGrid)
class CPUUpdates(Updates):
class CPUUpdates(Generic[GridType], Updates):
"""Defines update functions for CPU-based solver."""
def __init__(self, G: FDTDGrid):
def __init__(self, G: GridType):
"""
Args:
G: FDTDGrid class describing a grid in a model.