Add validation for snapshots to fix out of bounds memory access

这个提交包含在:
nmannall
2024-09-26 18:13:05 +01:00
父节点 aef7737299
当前提交 9dd0d1b255
共有 3 个文件被更改,包括 105 次插入和 36 次删除

查看文件

@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
import numpy as np
import numpy.typing as npt
from scipy import interpolate
import gprMax.config as config
@@ -1139,6 +1140,12 @@ class Snapshot(UserObjectMulti):
self.order = 9
self.hash = "#snapshot"
def _calculate_upper_bound(
    self, start: npt.NDArray, step: npt.NDArray, size: npt.NDArray
) -> npt.NDArray:
    """Round a span up to a whole number of steps and return its end.

    Evaluated elementwise as ``start + step * ceil(size / step)``, i.e.
    the first position reachable from ``start`` in whole ``step``
    increments that is at or beyond ``start + size``.

    Args:
        start: Lower corner of the span.
        step: Step length used to traverse the span.
        size: Extent of the span, measured from ``start``.

    Returns:
        The stepped upper bound. Because ``np.ceil`` is applied to a
        true division, integer array inputs yield a floating point
        array.
    """
    # With p2 = start + size and dl = step, this matches
    # upper_bound = p2 + dl - ((snapshot_size - 1) % dl) - 1
    n_steps = np.ceil(size / step)
    return start + step * n_steps
def build(self, model, uip):
grid = uip.grid
@@ -1154,17 +1161,69 @@ class Snapshot(UserObjectMulti):
logger.exception(f"{self.params_str()} requires exactly 11 parameters.")
raise
dl = np.array(uip.discretise_static_point(dl))
try:
p3 = uip.round_to_grid_static_point(p1)
p4 = uip.round_to_grid_static_point(p2)
p1, p2 = uip.check_box_points(p1, p2, self.params_str())
except ValueError:
logger.exception(f"{self.params_str()} point is outside the domain.")
raise
xs, ys, zs = p1
xf, yf, zf = p2
dx, dy, dz = uip.discretise_static_point(dl)
p1 = np.array(p1)
p2 = np.array(p2)
snapshot_size = p2 - p1
# If p2 does not line up with the set discretisation, the actual
# maximum element accessed in the grid will be this upper bound.
upper_bound = self._calculate_upper_bound(p1, dl, snapshot_size)
# Each coordinate may need a different method to correct p2.
# Therefore, this check needs to be repeated after each
# correction has been applied.
while any(p2 < upper_bound):
# Ideally extend p2 up to the correct upper bound. This will
# not change the snapshot output.
if uip.check_point_within_bounds(upper_bound):
p2 = upper_bound
p2_continuous = uip.descretised_to_continuous(p2)
logger.warning(
f"{self.params_str()} upper bound not aligned with discretisation. Updating 'p2'"
f" to {p2_continuous}"
)
# If the snapshot size cannot be increased, the
# discretisation may need reducing. E.g. for snapshots of 2D
# models.
elif any(dl > snapshot_size):
dl = np.where(dl > snapshot_size, snapshot_size, dl)
upper_bound = self._calculate_upper_bound(p1, dl, snapshot_size)
dl_continuous = uip.descretised_to_continuous(dl)
logger.warning(
f"{self.params_str()} current bounds and discretisation would go outside"
f" domain. As discretisation is larger than the snapshot size in at least one"
f" dimension, limiting 'dl' to {dl_continuous}"
)
# Otherwise, limit p2 to the discretisation step below the
# current snapshot size. This will reduce the size of the
# snapshot by 1 in the affected dimension(s), but avoid
# out-of-bounds memory access.
else:
p2 = np.where(uip.grid_upper_bound() < upper_bound, p2 - (snapshot_size % dl), p2)
snapshot_size = p2 - p1
upper_bound = self._calculate_upper_bound(p1, dl, snapshot_size)
p2_continuous = uip.descretised_to_continuous(p2)
logger.warning(
f"{self.params_str()} current bounds and discretisation would go outside"
f" domain. Limiting 'p2' to {p2_continuous}"
)
if any(dl < 0):
logger.exception(f"{self.params_str()} the step size should not be less than zero.")
raise ValueError
if any(dl < 1):
logger.exception(
f"{self.params_str()} the step size should not be less than the spatial discretisation."
)
raise ValueError
# If number of iterations given
try:
@@ -1182,6 +1241,10 @@ class Snapshot(UserObjectMulti):
logger.exception(f"{self.params_str()} time value must be greater than zero.")
raise ValueError
if iterations <= 0 or iterations > model.iterations:
logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError
try:
fileext = self.kwargs["fileext"]
if fileext not in SnapshotUser.fileexts:
@@ -1196,7 +1259,7 @@ class Snapshot(UserObjectMulti):
if isinstance(grid, MPIGrid) and fileext != ".h5":
logger.exception(
f"{self.params_str()} Currently only '.h5' snapshots are compatible with MPI."
f"{self.params_str()} currently only '.h5' snapshots are compatible with MPI."
)
raise ValueError
@@ -1219,23 +1282,15 @@ class Snapshot(UserObjectMulti):
# If outputs are not specified, use default
outputs = dict.fromkeys(SnapshotUser.allowableoutputs, True)
if dx < 0 or dy < 0 or dz < 0:
logger.exception(f"{self.params_str()} the step size should not be less than zero.")
raise ValueError
if dx < 1 or dy < 1 or dz < 1:
logger.exception(
f"{self.params_str()} the step size should not be less than the spatial discretisation."
)
raise ValueError
if iterations <= 0 or iterations > model.iterations:
logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError
if isinstance(grid, MPIGrid):
snapshot_type = MPISnapshotUser
else:
snapshot_type = SnapshotUser
xs, ys, zs = p1
xf, yf, zf = p2
dx, dy, dz = dl
s = snapshot_type(
xs,
ys,
@@ -1255,8 +1310,8 @@ class Snapshot(UserObjectMulti):
)
logger.info(
f"Snapshot from {p3[0]:g}m, {p3[1]:g}m, {p3[2]:g}m, to "
f"{p4[0]:g}m, {p4[1]:g}m, {p4[2]:g}m, discretisation "
f"Snapshot from {xs * grid.dx:g}m, {ys * grid.dy:g}m, {zs * grid.dz:g}m, to "
f"{xf * grid.dx:g}m, {yf * grid.dy:g}m, {zf * grid.dz:g}m, discretisation "
f"{dx * grid.dx:g}m, {dy * grid.dy:g}m, {dz * grid.dz:g}m, "
f"at {s.time * grid.dt:g} secs with field outputs "
f"{', '.join([k for k, v in outputs.items() if v])} and "

查看文件

@@ -274,13 +274,13 @@ class MPIGrid(FDTDGrid):
"""
if self.is_coordinator():
snapshots_by_rank: List[List[Optional[Snapshot]]] = [[] for _ in range(self.comm.size)]
for s in self.snapshots:
ranks = self.get_ranks_between_coordinates(s.start, s.stop + s.step)
for snapshot in self.snapshots:
ranks = self.get_ranks_between_coordinates(snapshot.start, snapshot.stop)
for rank in range(
self.comm.size
): # TODO: Loop over ranks in snapshot, not all ranks
if rank in ranks:
snapshots_by_rank[rank].append(s)
snapshots_by_rank[rank].append(snapshot)
else:
# All ranks need the same number of 'snapshots'
# (which may be None) to ensure snapshot
@@ -294,28 +294,32 @@ class MPIGrid(FDTDGrid):
snapshots_by_rank, root=self.COORDINATOR_RANK
)
for s in snapshots:
if s is None:
for snapshot in snapshots:
if snapshot is None:
self.comm.Split(MPI.UNDEFINED)
else:
comm = self.comm.Split()
assert isinstance(comm, MPI.Intracomm)
start = self.get_grid_coord_from_coordinate(s.start)
stop = self.get_grid_coord_from_coordinate(s.stop + s.step) + 1
s.comm = comm.Create_cart((stop - start).tolist())
start = self.get_grid_coord_from_coordinate(snapshot.start)
stop = self.get_grid_coord_from_coordinate(snapshot.stop) + 1
snapshot.comm = comm.Create_cart((stop - start).tolist())
s.start = self.global_to_local_coordinate(s.start)
snapshot.start = self.global_to_local_coordinate(snapshot.start)
# Calculate number of steps needed to bring the start
# into the local grid (and not in the negative halo)
s.offset = np.where(
s.start < self.negative_halo_offset,
np.abs((s.start - self.negative_halo_offset) // s.step),
s.offset,
snapshot.offset = np.where(
snapshot.start < self.negative_halo_offset,
np.abs((snapshot.start - self.negative_halo_offset) // snapshot.step),
snapshot.offset,
)
s.start += s.step * s.offset
snapshot.start += snapshot.step * snapshot.offset
s.stop = self.global_to_local_coordinate(s.stop)
s.stop = np.where(s.stop > self.size, self.size, s.stop)
snapshot.stop = self.global_to_local_coordinate(snapshot.stop)
snapshot.stop = np.where(
snapshot.stop > self.size,
self.size + ((snapshot.stop - self.size) % snapshot.step),
snapshot.stop,
)
self.snapshots = [s for s in snapshots if s is not None]

查看文件

@@ -64,6 +64,16 @@ class UserInput(Generic[GridType]):
logger.exception(s)
raise
def check_point_within_bounds(self, p) -> bool:
    """Report whether ``p`` passes the grid's bounds check.

    Calls ``self.grid.within_bounds(p)`` and turns its outcome into a
    boolean instead of letting the ``ValueError`` propagate.

    Args:
        p: Point handed unchanged to ``self.grid.within_bounds``.

    Returns:
        True if the check completes without raising ``ValueError``,
        False if it raises ``ValueError``. Any other exception type is
        not caught here.
    """
    try:
        self.grid.within_bounds(p)
    except ValueError:
        return False
    else:
        return True
def grid_upper_bound(self) -> list[int]:
    """Return the grid's ``nx``, ``ny`` and ``nz`` attributes as a list.

    Returns:
        ``[nx, ny, nz]`` read from ``self.grid``, in that order.
    """
    grid = self.grid
    return [grid.nx, grid.ny, grid.nz]
def discretise_point(self, p):
"""Gets the index of a continuous point with the grid."""
rv = np.vectorize(round_value)