Refactor Snapshot UserObject build process

这个提交包含在:
nmannall
2024-05-17 11:37:49 +01:00
父节点 98ee9b5328
当前提交 7a35bac24b
共有 3 个文件被更改,包括 26 次插入24 次删除

查看文件

@@ -1126,7 +1126,8 @@ class Snapshot(UserObjectMulti):
self.order = 9 self.order = 9
self.hash = "#snapshot" self.hash = "#snapshot"
def build(self, grid, uip): def build(self, model, uip):
grid = uip.grid
if isinstance(grid, SubGridBaseGrid): if isinstance(grid, SubGridBaseGrid):
logger.exception(f"{self.params_str()} do not add snapshots to subgrids.") logger.exception(f"{self.params_str()} do not add snapshots to subgrids.")
raise ValueError raise ValueError
@@ -1206,7 +1207,7 @@ class Snapshot(UserObjectMulti):
f"{self.params_str()} the step size should not be less than the spatial discretisation." f"{self.params_str()} the step size should not be less than the spatial discretisation."
) )
raise ValueError raise ValueError
if iterations <= 0 or iterations > grid.iterations: if iterations <= 0 or iterations > model.iterations:
logger.exception(f"{self.params_str()} time value is not valid.") logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError raise ValueError

查看文件

@@ -35,7 +35,8 @@ from gprMax.cython.yee_cell_build import build_electric_components, build_magnet
from gprMax.materials import Material, process_materials from gprMax.materials import Material, process_materials
from gprMax.pml import CFS, PML, build_pml, print_pml_info from gprMax.pml import CFS, PML, build_pml, print_pml_info
from gprMax.receivers import Rx from gprMax.receivers import Rx
from gprMax.sources import HertzianDipole, MagneticDipole, Source, VoltageSource from gprMax.snapshots import Snapshot
from gprMax.sources import HertzianDipole, MagneticDipole, Source, TransmissionLine, VoltageSource
# from gprMax.subgrids.grid import SubGridBaseGrid # from gprMax.subgrids.grid import SubGridBaseGrid
from gprMax.utilities.host_info import mem_check_build_all, mem_check_run_all from gprMax.utilities.host_info import mem_check_build_all, mem_check_run_all
@@ -85,11 +86,11 @@ class FDTDGrid:
self.voltagesources: List[VoltageSource] = [] self.voltagesources: List[VoltageSource] = []
self.hertziandipoles: List[HertzianDipole] = [] self.hertziandipoles: List[HertzianDipole] = []
self.magneticdipoles: List[MagneticDipole] = [] self.magneticdipoles: List[MagneticDipole] = []
self.transmissionlines = [] self.transmissionlines: List[TransmissionLine] = []
self.rxs: List[Rx] = [] self.rxs: List[Rx] = []
self.srcsteps: List[int] = [0, 0, 0] self.srcsteps: List[int] = [0, 0, 0]
self.rxsteps: List[int] = [0, 0, 0] self.rxsteps: List[int] = [0, 0, 0]
self.snapshots = [] self.snapshots: List[Snapshot] = []
@property @property
def dx(self) -> float: def dx(self) -> float:

查看文件

@@ -19,6 +19,7 @@
import logging import logging
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict
import h5py import h5py
import numpy as np import numpy as np
@@ -26,7 +27,6 @@ from evtk.hl import imageToVTK
from tqdm import tqdm from tqdm import tqdm
import gprMax.config as config import gprMax.config as config
from gprMax.grid.fdtd_grid import FDTDGrid
from ._version import __version__ from ._version import __version__
from .cython.snapshots import calculate_snapshot_fields from .cython.snapshots import calculate_snapshot_fields
@@ -35,7 +35,7 @@ from .utilities.utilities import get_terminal_width, round_value
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def save_snapshots(grid: FDTDGrid): def save_snapshots(grid):
"""Saves snapshots to file(s). """Saves snapshots to file(s).
Args: Args:
@@ -89,19 +89,19 @@ class Snapshot:
def __init__( def __init__(
self, self,
xs=None, xs: int,
ys=None, ys: int,
zs=None, zs: int,
xf=None, xf: int,
yf=None, yf: int,
zf=None, zf: int,
dx=None, dx: int,
dy=None, dy: int,
dz=None, dz: int,
time=None, time: int,
filename=None, filename: str,
fileext=None, fileext: str,
outputs=None, outputs: Dict[str, bool],
): ):
""" """
Args: Args:
@@ -149,7 +149,7 @@ class Snapshot:
(1, 1, 1), dtype=config.sim_config.dtypes["float_or_double"] (1, 1, 1), dtype=config.sim_config.dtypes["float_or_double"]
) )
def store(self, G: FDTDGrid): def store(self, G):
"""Store (in memory) electric and magnetic field values for snapshot. """Store (in memory) electric and magnetic field values for snapshot.
Args: Args:
@@ -191,7 +191,7 @@ class Snapshot:
self.snapfields["Hz"], self.snapfields["Hz"],
) )
def write_file(self, pbar: tqdm, G: FDTDGrid): def write_file(self, pbar: tqdm, G):
"""Writes snapshot file either as VTK ImageData (.vti) format """Writes snapshot file either as VTK ImageData (.vti) format
or HDF5 format (.h5) files or HDF5 format (.h5) files
@@ -205,7 +205,7 @@ class Snapshot:
elif self.fileext == ".h5": elif self.fileext == ".h5":
self.write_hdf5(pbar, G) self.write_hdf5(pbar, G)
def write_vtk(self, pbar: tqdm, G: FDTDGrid): def write_vtk(self, pbar: tqdm, G):
"""Writes snapshot file in VTK ImageData (.vti) format. """Writes snapshot file in VTK ImageData (.vti) format.
Args: Args:
@@ -238,7 +238,7 @@ class Snapshot:
* np.dtype(config.sim_config.dtypes["float_or_double"]).itemsize * np.dtype(config.sim_config.dtypes["float_or_double"]).itemsize
) )
def write_hdf5(self, pbar: tqdm, G: FDTDGrid): def write_hdf5(self, pbar: tqdm, G):
"""Writes snapshot file in HDF5 (.h5) format. """Writes snapshot file in HDF5 (.h5) format.
Args: Args: