From 7d0321715f111f8522ac362e263213c7e674b89a Mon Sep 17 00:00:00 2001 From: nmannall Date: Thu, 18 Jul 2024 16:59:51 +0100 Subject: [PATCH] Fix snapshots always defaulting to .vti files --- gprMax/cmds_multiuse.py | 22 ++++++++++++++++------ gprMax/hash_cmds_multiuse.py | 9 +++++++-- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/gprMax/cmds_multiuse.py b/gprMax/cmds_multiuse.py index c4f82663..6e835532 100644 --- a/gprMax/cmds_multiuse.py +++ b/gprMax/cmds_multiuse.py @@ -43,6 +43,7 @@ from .materials import PeplinskiSoil as PeplinskiSoilUser from .materials import RangeMaterial as RangeMaterialUser from .pml import CFS, CFSParameter from .receivers import Rx as RxUser +from .snapshots import MPISnapshot as MPISnapshotUser from .snapshots import Snapshot as SnapshotUser from .sources import HertzianDipole as HertzianDipoleUser from .sources import MagneticDipole as MagneticDipoleUser @@ -1140,11 +1141,7 @@ class Snapshot(UserObjectMulti): def build(self, model, uip): grid = uip.grid - if isinstance(grid, MPIGrid): - logger.exception( - f"{self.params_str()} Snapshots are not currently compatible with MPI." - ) - raise ValueError + if isinstance(grid, SubGridBaseGrid): logger.exception(f"{self.params_str()} do not add snapshots to subgrids.") raise ValueError @@ -1197,6 +1194,12 @@ class Snapshot(UserObjectMulti): except KeyError: fileext = SnapshotUser.fileexts[0] + if isinstance(grid, MPIGrid) and fileext != ".h5": + logger.exception( + f"{self.params_str()} Currently only '.h5' snapshots are compatible with MPI." + ) + raise ValueError + try: tmp = self.kwargs["outputs"] outputs = dict.fromkeys(SnapshotUser.allowableoutputs, False) @@ -1228,7 +1231,12 @@ class Snapshot(UserObjectMulti): logger.exception(f"{self.params_str()} time value is not valid.") raise ValueError - s = SnapshotUser( + if isinstance(grid, MPIGrid): + snapshot_type = MPISnapshotUser + else: + snapshot_type = SnapshotUser + + s = snapshot_type( xs, ys, zs, @@ -1242,6 +1250,8 @@ class Snapshot(UserObjectMulti): filename, fileext=fileext, outputs=outputs, + grid_dl=grid.dl, + grid_dt=grid.dt, ) logger.info( diff --git a/gprMax/hash_cmds_multiuse.py b/gprMax/hash_cmds_multiuse.py index a6ee5612..c933b7c4 100644 --- a/gprMax/hash_cmds_multiuse.py +++ b/gprMax/hash_cmds_multiuse.py @@ -279,14 +279,19 @@ def process_multicmds(multicmds): p2 = (float(tmp[3]), float(tmp[4]), float(tmp[5])) dl = (float(tmp[6]), float(tmp[7]), float(tmp[8])) filename = tmp[10] + fileext = "." + filename.split(".")[-1] try: iterations = int(tmp[9]) - snapshot = Snapshot(p1=p1, p2=p2, dl=dl, iterations=iterations, filename=filename) + snapshot = Snapshot( + p1=p1, p2=p2, dl=dl, iterations=iterations, filename=filename, fileext=fileext + ) except ValueError: time = float(tmp[9]) - snapshot = Snapshot(p1=p1, p2=p2, dl=dl, time=time, filename=filename) + snapshot = Snapshot( + p1=p1, p2=p2, dl=dl, time=time, filename=filename, fileext=fileext + ) scene_objects.append(snapshot)