Fix snapshots always defaulting to .vti files

这个提交包含在:
nmannall
2024-07-18 16:59:51 +01:00
父节点 8af024b496
当前提交 7d0321715f
共有 2 个文件被更改,包括 23 次插入8 次删除

查看文件

@@ -43,6 +43,7 @@ from .materials import PeplinskiSoil as PeplinskiSoilUser
from .materials import RangeMaterial as RangeMaterialUser from .materials import RangeMaterial as RangeMaterialUser
from .pml import CFS, CFSParameter from .pml import CFS, CFSParameter
from .receivers import Rx as RxUser from .receivers import Rx as RxUser
from .snapshots import MPISnapshot as MPISnapshotUser
from .snapshots import Snapshot as SnapshotUser from .snapshots import Snapshot as SnapshotUser
from .sources import HertzianDipole as HertzianDipoleUser from .sources import HertzianDipole as HertzianDipoleUser
from .sources import MagneticDipole as MagneticDipoleUser from .sources import MagneticDipole as MagneticDipoleUser
@@ -1140,11 +1141,7 @@ class Snapshot(UserObjectMulti):
def build(self, model, uip): def build(self, model, uip):
grid = uip.grid grid = uip.grid
if isinstance(grid, MPIGrid):
logger.exception(
f"{self.params_str()} Snapshots are not currently compatible with MPI."
)
raise ValueError
if isinstance(grid, SubGridBaseGrid): if isinstance(grid, SubGridBaseGrid):
logger.exception(f"{self.params_str()} do not add snapshots to subgrids.") logger.exception(f"{self.params_str()} do not add snapshots to subgrids.")
raise ValueError raise ValueError
@@ -1197,6 +1194,12 @@ class Snapshot(UserObjectMulti):
except KeyError: except KeyError:
fileext = SnapshotUser.fileexts[0] fileext = SnapshotUser.fileexts[0]
if isinstance(grid, MPIGrid) and fileext != ".h5":
logger.exception(
f"{self.params_str()} Currently only '.h5' snapshots are compatible with MPI."
)
raise ValueError
try: try:
tmp = self.kwargs["outputs"] tmp = self.kwargs["outputs"]
outputs = dict.fromkeys(SnapshotUser.allowableoutputs, False) outputs = dict.fromkeys(SnapshotUser.allowableoutputs, False)
@@ -1228,7 +1231,12 @@ class Snapshot(UserObjectMulti):
logger.exception(f"{self.params_str()} time value is not valid.") logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError raise ValueError
s = SnapshotUser( if isinstance(grid, MPIGrid):
snapshot_type = MPISnapshotUser
else:
snapshot_type = SnapshotUser
s = snapshot_type(
xs, xs,
ys, ys,
zs, zs,
@@ -1242,6 +1250,8 @@ class Snapshot(UserObjectMulti):
filename, filename,
fileext=fileext, fileext=fileext,
outputs=outputs, outputs=outputs,
grid_dl=grid.dl,
grid_dt=grid.dt,
) )
logger.info( logger.info(

查看文件

@@ -279,14 +279,19 @@ def process_multicmds(multicmds):
p2 = (float(tmp[3]), float(tmp[4]), float(tmp[5])) p2 = (float(tmp[3]), float(tmp[4]), float(tmp[5]))
dl = (float(tmp[6]), float(tmp[7]), float(tmp[8])) dl = (float(tmp[6]), float(tmp[7]), float(tmp[8]))
filename = tmp[10] filename = tmp[10]
fileext = "." + filename.split(".")[-1]
try: try:
iterations = int(tmp[9]) iterations = int(tmp[9])
snapshot = Snapshot(p1=p1, p2=p2, dl=dl, iterations=iterations, filename=filename) snapshot = Snapshot(
p1=p1, p2=p2, dl=dl, iterations=iterations, filename=filename, fileext=fileext
)
except ValueError: except ValueError:
time = float(tmp[9]) time = float(tmp[9])
snapshot = Snapshot(p1=p1, p2=p2, dl=dl, time=time, filename=filename) snapshot = Snapshot(
p1=p1, p2=p2, dl=dl, time=time, filename=filename, fileext=fileext
)
scene_objects.append(snapshot) scene_objects.append(snapshot)