Add support for VTKHDF snapshots

这个提交包含在:
Nathan Mannall
2025-06-20 12:17:07 +01:00
父节点 3443071097
当前提交 652da89ad4
共有 27 个文件被更改,包括 95 次插入47 次删除

查看文件

@@ -24,7 +24,6 @@ from typing import Dict, Generic, List
import h5py
import numpy as np
from evtk.hl import imageToVTK
from mpi4py import MPI
from tqdm import tqdm
@@ -32,6 +31,7 @@ import gprMax.config as config
from gprMax.geometry_outputs.grid_view import GridType, GridView, MPIGridView
from gprMax.grid.mpi_grid import MPIGrid
from gprMax.utilities.mpi import Dim, Dir
from gprMax.vtkhdf_filehandlers.vtk_image_data import VtkImageData
from ._version import __version__
from .cython.snapshots import calculate_snapshot_fields
@@ -54,8 +54,7 @@ def save_snapshots(snapshots: List["Snapshot"]):
logger.info(f"Snapshot directory: {snapshotdir.resolve()}")
for i, snap in enumerate(snapshots):
fn = snapshotdir / snap.filename
snap.filename = fn.with_suffix(snap.fileext)
snap.filename = snapshotdir / snap.filename
pbar = tqdm(
total=snap.nbytes,
leave=True,
@@ -83,9 +82,9 @@ class Snapshot(Generic[GridType]):
"Hz": None,
}
# Snapshots can be output as VTK ImageData (.vti) format or
# Snapshots can be output as VTK ImageData (.vtkhdf) format or
# HDF5 format (.h5) files
fileexts = [".vti", ".h5"]
fileexts = [".vtkhdf", ".h5"]
# Dimensions of largest requested snapshot
nx_max = 0
@@ -124,12 +123,12 @@ class Snapshot(Generic[GridType]):
dx, dy, dz: ints for the spatial discretisation in cells.
time: int for the iteration number to take the snapshot on.
filename: string for the filename to save to.
fileext: optional string for the file extension.
outputs: optional dict of booleans for fields to use for snapshot.
fileext: string for the file extension.
outputs: dict of booleans for fields to use for snapshot.
"""
self.fileext = fileext
self.filename = Path(filename)
self.filename = Path(filename).with_suffix(fileext)
self.time = time
self.outputs = outputs
self.grid_view = self.GRID_VIEW_TYPE(grid, xs, ys, zs, xf, yf, zf, dx, dy, dz)
@@ -246,41 +245,26 @@ class Snapshot(Generic[GridType]):
G: FDTDGrid class describing a grid in a model.
"""
if self.fileext == ".vti":
if self.fileext == ".vtkhdf":
self.write_vtk(pbar)
elif self.fileext == ".h5":
self.write_hdf5(pbar)
def write_vtk(self, pbar: tqdm):
"""Writes snapshot file in VTK ImageData (.vti) format.
"""Writes snapshot file in VTK ImageData (.vtkhdf) format.
Args:
pbar: Progress bar class instance.
"""
celldata = {
k: self.snapfields[k]
for k in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
if self.outputs.get(k)
}
origin = self.grid_view.start * self.grid.dl
spacing = self.grid_view.step * self.grid.dl
imageToVTK(
str(self.filename.with_suffix("")),
origin=tuple(origin),
spacing=tuple(spacing),
cellData=celldata,
)
pbar.update(
n=len(celldata)
* self.nx
* self.ny
* self.nz
* np.dtype(config.sim_config.dtypes["float_or_double"]).itemsize
)
with VtkImageData(self.filename, self.grid_view.size, origin, spacing) as f:
for key in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]:
if self.outputs[key]:
f.add_cell_data(key, self.snapfields[key])
pbar.update(n=self.snapfields[key].nbytes)
def write_hdf5(self, pbar: tqdm):
"""Writes snapshot file in HDF5 (.h5) format.
@@ -527,6 +511,25 @@ class MPISnapshot(Snapshot[MPIGrid]):
self.snapfields["Hz"],
)
def write_vtk(self, pbar: tqdm):
"""Writes snapshot file in VTK ImageData (.vtkhdf) format.
Args:
pbar: Progress bar class instance.
"""
assert isinstance(self.grid_view, self.GRID_VIEW_TYPE)
origin = self.grid_view.global_start * self.grid.dl
spacing = self.grid_view.step * self.grid.dl
with VtkImageData(
self.filename, self.grid_view.global_size, origin, spacing, comm=self.comm
) as f:
for key in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]:
if self.outputs.get(key):
f.add_cell_data(key, self.snapfields[key], self.grid_view.offset)
pbar.update(n=self.snapfields[key].nbytes)
def write_hdf5(self, pbar: tqdm):
"""Writes snapshot file in HDF5 (.h5) format.

查看文件

@@ -1,5 +1,5 @@
# Copyright (C) 2015-2025: The University of Edinburgh, United Kingdom
# Authors: Craig Warren, Antonis Giannopoulos, John Hartley,
# Authors: Craig Warren, Antonis Giannopoulos, John Hartley,
# and Nathan Mannall
#
# This file is part of gprMax.
@@ -24,7 +24,6 @@ import numpy as np
import numpy.typing as npt
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.grid.mpi_grid import MPIGrid
from gprMax.model import Model
from gprMax.snapshots import Snapshot as SnapshotUser
from gprMax.subgrids.grid import SubGridBaseGrid
@@ -121,9 +120,7 @@ class Snapshot(OutputUserObject):
# correction has been applied.
while any(discretised_upper_bound < upper_bound):
try:
uip.point_within_bounds(
upper_bound, f"[{upper_bound[0]}, {upper_bound[1]}, {upper_bound[2]}]"
)
grid.within_bounds(upper_bound)
upper_bound_within_grid = True
except ValueError:
upper_bound_within_grid = False
@@ -214,12 +211,6 @@ class Snapshot(OutputUserObject):
f" Valid options are: {' '.join(SnapshotUser.fileexts)}."
)
# TODO: Allow VTKHDF files when they are implemented
if isinstance(grid, MPIGrid) and self.file_extension != ".h5":
raise ValueError(
f"{self.params_str()} currently only '.h5' snapshots are compatible with MPI."
)
if self.outputs is None:
outputs = dict.fromkeys(SnapshotUser.allowableoutputs, True)
else:
@@ -258,8 +249,7 @@ class Snapshot(OutputUserObject):
f" {dl[0]:g}m, {dl[1]:g}m, {dl[2]:g}m, at"
f" {snapshot.time * grid.dt:g} secs with field outputs"
f" {', '.join([k for k, v in outputs.items() if v])} "
f" and filename {snapshot.filename}{snapshot.fileext}"
" will be created."
f" and filename {snapshot.filename} will be created."
)

查看文件

@@ -67,7 +67,7 @@ class SnapshotMixin(GprMaxMixin):
Args:
snapshot: Name of the snapshot.
"""
return Path(f"{self.model}_snaps", snapshot).with_suffix(".h5")
return Path(f"{self.model}_snaps", snapshot)
@run_after("setup")
def add_snapshot_regression_checks(self):
@@ -82,7 +82,7 @@ class SnapshotMixin(GprMaxMixin):
for snapshot in self.snapshots:
snapshot_file = self.build_snapshot_filepath(snapshot)
reference_file = self.build_reference_filepath(snapshot)
reference_file = self.build_reference_filepath(snapshot, suffix=snapshot_file.suffix)
regression_check = SnapshotRegressionCheck(snapshot_file, reference_file)
self.regression_checks.append(regression_check)

查看文件

@@ -20,3 +20,18 @@
#snapshot: 0 0 0.025 0.100 0.100 0.026 0.01 0.01 0.01 2e-9 snapshot_z_25.h5
#snapshot: 0 0 0.055 0.100 0.100 0.056 0.01 0.01 0.01 2e-9 snapshot_z_55.h5
#snapshot: 0 0 0.055 0.100 0.100 0.086 0.01 0.01 0.01 2e-9 snapshot_z_85.h5
#snapshot: 0.005 0 0 0.006 0.100 0.100 0.01 0.01 0.01 2e-9 snapshot_x_05.vtkhdf
#snapshot: 0.035 0 0 0.036 0.100 0.100 0.01 0.01 0.01 2e-9 snapshot_x_35.vtkhdf
#snapshot: 0.065 0 0 0.066 0.100 0.100 0.01 0.01 0.01 2e-9 snapshot_x_65.vtkhdf
#snapshot: 0.095 0 0 0.096 0.100 0.100 0.01 0.01 0.01 2e-9 snapshot_x_95.vtkhdf
#snapshot: 0 0.015 0 0.100 0.016 0.100 0.01 0.01 0.01 2e-9 snapshot_y_15.vtkhdf
#snapshot: 0 0.040 0 0.100 0.050 0.100 0.01 0.01 0.01 2e-9 snapshot_y_40.vtkhdf
#snapshot: 0 0.045 0 0.100 0.046 0.100 0.01 0.01 0.01 2e-9 snapshot_y_45.vtkhdf
#snapshot: 0 0.050 0 0.100 0.051 0.100 0.01 0.01 0.01 2e-9 snapshot_y_50.vtkhdf
#snapshot: 0 0.075 0 0.100 0.076 0.100 0.01 0.01 0.01 2e-9 snapshot_y_75.vtkhdf
#snapshot: 0 0 0.025 0.100 0.100 0.026 0.01 0.01 0.01 2e-9 snapshot_z_25.vtkhdf
#snapshot: 0 0 0.055 0.100 0.100 0.056 0.01 0.01 0.01 2e-9 snapshot_z_55.vtkhdf
#snapshot: 0 0 0.055 0.100 0.100 0.086 0.01 0.01 0.01 2e-9 snapshot_z_85.vtkhdf

查看文件

@@ -10,3 +10,8 @@
#snapshot: 0 0 0 0.100 0.100 0.100 0.01 0.01 0.01 1e-9 snapshot_1.h5
#snapshot: 0 0 0 0.100 0.100 0.100 0.01 0.01 0.01 2e-9 snapshot_2.h5
#snapshot: 0 0 0 0.100 0.100 0.100 0.01 0.01 0.01 3e-9 snapshot_3.h5
#snapshot: 0 0 0 0.100 0.100 0.100 0.01 0.01 0.01 1 snapshot_0.vtkhdf
#snapshot: 0 0 0 0.100 0.100 0.100 0.01 0.01 0.01 1e-9 snapshot_1.vtkhdf
#snapshot: 0 0 0 0.100 0.100 0.100 0.01 0.01 0.01 2e-9 snapshot_2.vtkhdf
#snapshot: 0 0 0 0.100 0.100 0.100 0.01 0.01 0.01 3e-9 snapshot_3.vtkhdf

查看文件

@@ -10,3 +10,8 @@
#snapshot: 0 0 0 0.100 0.100 0.001 0.01 0.01 0.01 1e-9 snapshot_1.h5
#snapshot: 0 0 0 0.100 0.100 0.001 0.01 0.01 0.01 2e-9 snapshot_2.h5
#snapshot: 0 0 0 0.100 0.100 0.001 0.01 0.01 0.01 3e-9 snapshot_3.h5
#snapshot: 0 0 0 0.100 0.100 0.001 0.01 0.01 0.01 1 snapshot_0.vtkhdf
#snapshot: 0 0 0 0.100 0.100 0.001 0.01 0.01 0.01 1e-9 snapshot_1.vtkhdf
#snapshot: 0 0 0 0.100 0.100 0.001 0.01 0.01 0.01 2e-9 snapshot_2.vtkhdf
#snapshot: 0 0 0 0.100 0.100 0.001 0.01 0.01 0.01 3e-9 snapshot_3.vtkhdf

查看文件

@@ -10,7 +10,16 @@ class Test2DSnapshot(GprMaxSnapshotTest):
tags = {"test", "serial", "2d", "waveform", "hertzian_dipole", "snapshot"}
sourcesdir = "src/snapshot_tests"
model = parameter(["whole_domain_2d"])
snapshots = ["snapshot_0.h5", "snapshot_1.h5", "snapshot_2.h5", "snapshot_3.h5"]
snapshots = [
"snapshot_0.h5",
"snapshot_1.h5",
"snapshot_2.h5",
"snapshot_3.h5",
"snapshot_0.vtkhdf",
"snapshot_1.vtkhdf",
"snapshot_2.vtkhdf",
"snapshot_3.vtkhdf",
]
@rfm.simple_test
@@ -18,7 +27,16 @@ class TestSnapshot(GprMaxSnapshotTest):
tags = {"test", "serial", "2d", "waveform", "hertzian_dipole", "snapshot"}
sourcesdir = "src/snapshot_tests"
model = parameter(["whole_domain"])
snapshots = ["snapshot_0.h5", "snapshot_1.h5", "snapshot_2.h5", "snapshot_3.h5"]
snapshots = [
"snapshot_0.h5",
"snapshot_1.h5",
"snapshot_2.h5",
"snapshot_3.h5",
"snapshot_0.vtkhdf",
"snapshot_1.vtkhdf",
"snapshot_2.vtkhdf",
"snapshot_3.vtkhdf",
]
@rfm.simple_test
@@ -39,6 +57,18 @@ class Test2DSliceSnapshot(GprMaxSnapshotTest):
"snapshot_z_25.h5",
"snapshot_z_55.h5",
"snapshot_z_85.h5",
"snapshot_x_05.vtkhdf",
"snapshot_x_35.vtkhdf",
"snapshot_x_65.vtkhdf",
"snapshot_x_95.vtkhdf",
"snapshot_y_15.vtkhdf",
"snapshot_y_40.vtkhdf",
"snapshot_y_45.vtkhdf",
"snapshot_y_50.vtkhdf",
"snapshot_y_75.vtkhdf",
"snapshot_z_25.vtkhdf",
"snapshot_z_55.vtkhdf",
"snapshot_z_85.vtkhdf",
]