Refactor Snapshot UserObject build process

这个提交包含在:
nmannall
2024-05-17 11:37:49 +01:00
父节点 98ee9b5328
当前提交 7a35bac24b
共有 3 个文件被更改,包括 26 次插入24 次删除

查看文件

@@ -1126,7 +1126,8 @@ class Snapshot(UserObjectMulti):
self.order = 9
self.hash = "#snapshot"
def build(self, grid, uip):
def build(self, model, uip):
grid = uip.grid
if isinstance(grid, SubGridBaseGrid):
logger.exception(f"{self.params_str()} do not add snapshots to subgrids.")
raise ValueError
@@ -1206,7 +1207,7 @@ class Snapshot(UserObjectMulti):
f"{self.params_str()} the step size should not be less than the spatial discretisation."
)
raise ValueError
if iterations <= 0 or iterations > grid.iterations:
if iterations <= 0 or iterations > model.iterations:
logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError

查看文件

@@ -35,7 +35,8 @@ from gprMax.cython.yee_cell_build import build_electric_components, build_magnet
from gprMax.materials import Material, process_materials
from gprMax.pml import CFS, PML, build_pml, print_pml_info
from gprMax.receivers import Rx
from gprMax.sources import HertzianDipole, MagneticDipole, Source, VoltageSource
from gprMax.snapshots import Snapshot
from gprMax.sources import HertzianDipole, MagneticDipole, Source, TransmissionLine, VoltageSource
# from gprMax.subgrids.grid import SubGridBaseGrid
from gprMax.utilities.host_info import mem_check_build_all, mem_check_run_all
@@ -85,11 +86,11 @@ class FDTDGrid:
self.voltagesources: List[VoltageSource] = []
self.hertziandipoles: List[HertzianDipole] = []
self.magneticdipoles: List[MagneticDipole] = []
self.transmissionlines = []
self.transmissionlines: List[TransmissionLine] = []
self.rxs: List[Rx] = []
self.srcsteps: List[int] = [0, 0, 0]
self.rxsteps: List[int] = [0, 0, 0]
self.snapshots = []
self.snapshots: List[Snapshot] = []
@property
def dx(self) -> float:

查看文件

@@ -19,6 +19,7 @@
import logging
import sys
from pathlib import Path
from typing import Dict
import h5py
import numpy as np
@@ -26,7 +27,6 @@ from evtk.hl import imageToVTK
from tqdm import tqdm
import gprMax.config as config
from gprMax.grid.fdtd_grid import FDTDGrid
from ._version import __version__
from .cython.snapshots import calculate_snapshot_fields
@@ -35,7 +35,7 @@ from .utilities.utilities import get_terminal_width, round_value
logger = logging.getLogger(__name__)
def save_snapshots(grid: FDTDGrid):
def save_snapshots(grid):
"""Saves snapshots to file(s).
Args:
@@ -89,19 +89,19 @@ class Snapshot:
def __init__(
self,
xs=None,
ys=None,
zs=None,
xf=None,
yf=None,
zf=None,
dx=None,
dy=None,
dz=None,
time=None,
filename=None,
fileext=None,
outputs=None,
xs: int,
ys: int,
zs: int,
xf: int,
yf: int,
zf: int,
dx: int,
dy: int,
dz: int,
time: int,
filename: str,
fileext: str,
outputs: Dict[str, bool],
):
"""
Args:
@@ -149,7 +149,7 @@ class Snapshot:
(1, 1, 1), dtype=config.sim_config.dtypes["float_or_double"]
)
def store(self, G: FDTDGrid):
def store(self, G):
"""Store (in memory) electric and magnetic field values for snapshot.
Args:
@@ -191,7 +191,7 @@ class Snapshot:
self.snapfields["Hz"],
)
def write_file(self, pbar: tqdm, G: FDTDGrid):
def write_file(self, pbar: tqdm, G):
"""Writes snapshot file either as VTK ImageData (.vti) format
or HDF5 format (.h5) files
@@ -205,7 +205,7 @@ class Snapshot:
elif self.fileext == ".h5":
self.write_hdf5(pbar, G)
def write_vtk(self, pbar: tqdm, G: FDTDGrid):
def write_vtk(self, pbar: tqdm, G):
"""Writes snapshot file in VTK ImageData (.vti) format.
Args:
@@ -238,7 +238,7 @@ class Snapshot:
* np.dtype(config.sim_config.dtypes["float_or_double"]).itemsize
)
def write_hdf5(self, pbar: tqdm, G: FDTDGrid):
def write_hdf5(self, pbar: tqdm, G):
"""Writes snapshot file in HDF5 (.h5) format.
Args: