Add validation for snapshots to fix out of bounds memory access

这个提交包含在:
nmannall
2024-09-26 18:13:05 +01:00
父节点 aef7737299
当前提交 9dd0d1b255
共有 3 个文件被更改,包括 105 次插入36 次删除

查看文件

@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import numpy.typing as npt
from scipy import interpolate from scipy import interpolate
import gprMax.config as config import gprMax.config as config
@@ -1139,6 +1140,12 @@ class Snapshot(UserObjectMulti):
self.order = 9 self.order = 9
self.hash = "#snapshot" self.hash = "#snapshot"
def _calculate_upper_bound(
self, start: npt.NDArray, step: npt.NDArray, size: npt.NDArray
) -> npt.NDArray:
# upper_bound = p2 + dl - ((snapshot_size - 1) % dl) - 1
return start + step * np.ceil(size / step)
def build(self, model, uip): def build(self, model, uip):
grid = uip.grid grid = uip.grid
@@ -1154,17 +1161,69 @@ class Snapshot(UserObjectMulti):
logger.exception(f"{self.params_str()} requires exactly 11 parameters.") logger.exception(f"{self.params_str()} requires exactly 11 parameters.")
raise raise
dl = np.array(uip.discretise_static_point(dl))
try: try:
p3 = uip.round_to_grid_static_point(p1)
p4 = uip.round_to_grid_static_point(p2)
p1, p2 = uip.check_box_points(p1, p2, self.params_str()) p1, p2 = uip.check_box_points(p1, p2, self.params_str())
except ValueError: except ValueError:
logger.exception(f"{self.params_str()} point is outside the domain.") logger.exception(f"{self.params_str()} point is outside the domain.")
raise raise
xs, ys, zs = p1
xf, yf, zf = p2
dx, dy, dz = uip.discretise_static_point(dl) p1 = np.array(p1)
p2 = np.array(p2)
snapshot_size = p2 - p1
# If p2 does not line up with the set discretisation, the actual
# maximum element accessed in the grid will be this upper bound.
upper_bound = self._calculate_upper_bound(p1, dl, snapshot_size)
# Each coordinate may need a different method to correct p2.
# Therefore, this check needs to be repeated after each
# correction has been applied.
while any(p2 < upper_bound):
# Ideally extend p2 up to the correct upper bound. This will
# not change the snapshot output.
if uip.check_point_within_bounds(upper_bound):
p2 = upper_bound
p2_continuous = uip.descretised_to_continuous(p2)
logger.warning(
f"{self.params_str()} upper bound not aligned with discretisation. Updating 'p2'"
f" to {p2_continuous}"
)
# If the snapshot size cannot be increased, the
# discretisation may need reducing. E.g. for snapshots of 2D
# models.
elif any(dl > snapshot_size):
dl = np.where(dl > snapshot_size, snapshot_size, dl)
upper_bound = self._calculate_upper_bound(p1, dl, snapshot_size)
dl_continuous = uip.descretised_to_continuous(dl)
logger.warning(
f"{self.params_str()} current bounds and discretisation would go outside"
f" domain. As discretisation is larger than the snapshot size in at least one"
f" dimension, limiting 'dl' to {dl_continuous}"
)
# Otherwise, limit p2 to the discretisation step below the
# current snapshot size. This will reduce the size of the
# snapshot by 1 in the effected dimension(s), but avoid out
# of memory access.
else:
p2 = np.where(uip.grid_upper_bound() < upper_bound, p2 - (snapshot_size % dl), p2)
snapshot_size = p2 - p1
upper_bound = self._calculate_upper_bound(p1, dl, snapshot_size)
p2_continuous = uip.descretised_to_continuous(p2)
logger.warning(
f"{self.params_str()} current bounds and discretisation would go outside"
f" domain. Limiting 'p2' to {p2_continuous}"
)
if any(dl < 0):
logger.exception(f"{self.params_str()} the step size should not be less than zero.")
raise ValueError
if any(dl < 1):
logger.exception(
f"{self.params_str()} the step size should not be less than the spatial discretisation."
)
raise ValueError
# If number of iterations given # If number of iterations given
try: try:
@@ -1182,6 +1241,10 @@ class Snapshot(UserObjectMulti):
logger.exception(f"{self.params_str()} time value must be greater than zero.") logger.exception(f"{self.params_str()} time value must be greater than zero.")
raise ValueError raise ValueError
if iterations <= 0 or iterations > model.iterations:
logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError
try: try:
fileext = self.kwargs["fileext"] fileext = self.kwargs["fileext"]
if fileext not in SnapshotUser.fileexts: if fileext not in SnapshotUser.fileexts:
@@ -1196,7 +1259,7 @@ class Snapshot(UserObjectMulti):
if isinstance(grid, MPIGrid) and fileext != ".h5": if isinstance(grid, MPIGrid) and fileext != ".h5":
logger.exception( logger.exception(
f"{self.params_str()} Currently only '.h5' snapshots are compatible with MPI." f"{self.params_str()} currently only '.h5' snapshots are compatible with MPI."
) )
raise ValueError raise ValueError
@@ -1219,23 +1282,15 @@ class Snapshot(UserObjectMulti):
# If outputs are not specified, use default # If outputs are not specified, use default
outputs = dict.fromkeys(SnapshotUser.allowableoutputs, True) outputs = dict.fromkeys(SnapshotUser.allowableoutputs, True)
if dx < 0 or dy < 0 or dz < 0:
logger.exception(f"{self.params_str()} the step size should not be less than zero.")
raise ValueError
if dx < 1 or dy < 1 or dz < 1:
logger.exception(
f"{self.params_str()} the step size should not be less than the spatial discretisation."
)
raise ValueError
if iterations <= 0 or iterations > model.iterations:
logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError
if isinstance(grid, MPIGrid): if isinstance(grid, MPIGrid):
snapshot_type = MPISnapshotUser snapshot_type = MPISnapshotUser
else: else:
snapshot_type = SnapshotUser snapshot_type = SnapshotUser
xs, ys, zs = p1
xf, yf, zf = p2
dx, dy, dz = dl
s = snapshot_type( s = snapshot_type(
xs, xs,
ys, ys,
@@ -1255,8 +1310,8 @@ class Snapshot(UserObjectMulti):
) )
logger.info( logger.info(
f"Snapshot from {p3[0]:g}m, {p3[1]:g}m, {p3[2]:g}m, to " f"Snapshot from {xs * grid.dx:g}m, {ys * grid.dy:g}m, {zs * grid.dz:g}m, to "
f"{p4[0]:g}m, {p4[1]:g}m, {p4[2]:g}m, discretisation " f"{xf * grid.dx:g}m, {yf * grid.dy:g}m, {zf * grid.dz:g}m, discretisation "
f"{dx * grid.dx:g}m, {dy * grid.dy:g}m, {dz * grid.dz:g}m, " f"{dx * grid.dx:g}m, {dy * grid.dy:g}m, {dz * grid.dz:g}m, "
f"at {s.time * grid.dt:g} secs with field outputs " f"at {s.time * grid.dt:g} secs with field outputs "
f"{', '.join([k for k, v in outputs.items() if v])} and " f"{', '.join([k for k, v in outputs.items() if v])} and "

查看文件

@@ -274,13 +274,13 @@ class MPIGrid(FDTDGrid):
""" """
if self.is_coordinator(): if self.is_coordinator():
snapshots_by_rank: List[List[Optional[Snapshot]]] = [[] for _ in range(self.comm.size)] snapshots_by_rank: List[List[Optional[Snapshot]]] = [[] for _ in range(self.comm.size)]
for s in self.snapshots: for snapshot in self.snapshots:
ranks = self.get_ranks_between_coordinates(s.start, s.stop + s.step) ranks = self.get_ranks_between_coordinates(snapshot.start, snapshot.stop)
for rank in range( for rank in range(
self.comm.size self.comm.size
): # TODO: Loop over ranks in snapshot, not all ranks ): # TODO: Loop over ranks in snapshot, not all ranks
if rank in ranks: if rank in ranks:
snapshots_by_rank[rank].append(s) snapshots_by_rank[rank].append(snapshot)
else: else:
# All ranks need the same number of 'snapshots' # All ranks need the same number of 'snapshots'
# (which may be None) to ensure snapshot # (which may be None) to ensure snapshot
@@ -294,28 +294,32 @@ class MPIGrid(FDTDGrid):
snapshots_by_rank, root=self.COORDINATOR_RANK snapshots_by_rank, root=self.COORDINATOR_RANK
) )
for s in snapshots: for snapshot in snapshots:
if s is None: if snapshot is None:
self.comm.Split(MPI.UNDEFINED) self.comm.Split(MPI.UNDEFINED)
else: else:
comm = self.comm.Split() comm = self.comm.Split()
assert isinstance(comm, MPI.Intracomm) assert isinstance(comm, MPI.Intracomm)
start = self.get_grid_coord_from_coordinate(s.start) start = self.get_grid_coord_from_coordinate(snapshot.start)
stop = self.get_grid_coord_from_coordinate(s.stop + s.step) + 1 stop = self.get_grid_coord_from_coordinate(snapshot.stop) + 1
s.comm = comm.Create_cart((stop - start).tolist()) snapshot.comm = comm.Create_cart((stop - start).tolist())
s.start = self.global_to_local_coordinate(s.start) snapshot.start = self.global_to_local_coordinate(snapshot.start)
# Calculate number of steps needed to bring the start # Calculate number of steps needed to bring the start
# into the local grid (and not in the negative halo) # into the local grid (and not in the negative halo)
s.offset = np.where( snapshot.offset = np.where(
s.start < self.negative_halo_offset, snapshot.start < self.negative_halo_offset,
np.abs((s.start - self.negative_halo_offset) // s.step), np.abs((snapshot.start - self.negative_halo_offset) // snapshot.step),
s.offset, snapshot.offset,
) )
s.start += s.step * s.offset snapshot.start += snapshot.step * snapshot.offset
s.stop = self.global_to_local_coordinate(s.stop) snapshot.stop = self.global_to_local_coordinate(snapshot.stop)
s.stop = np.where(s.stop > self.size, self.size, s.stop) snapshot.stop = np.where(
snapshot.stop > self.size,
self.size + ((snapshot.stop - self.size) % snapshot.step),
snapshot.stop,
)
self.snapshots = [s for s in snapshots if s is not None] self.snapshots = [s for s in snapshots if s is not None]

查看文件

@@ -64,6 +64,16 @@ class UserInput(Generic[GridType]):
logger.exception(s) logger.exception(s)
raise raise
def check_point_within_bounds(self, p) -> bool:
try:
self.grid.within_bounds(p)
return True
except ValueError:
return False
def grid_upper_bound(self) -> list[int]:
return [self.grid.nx, self.grid.ny, self.grid.nz]
def discretise_point(self, p): def discretise_point(self, p):
"""Gets the index of a continuous point with the grid.""" """Gets the index of a continuous point with the grid."""
rv = np.vectorize(round_value) rv = np.vectorize(round_value)