diff --git a/gprMax/vtkhdf.py b/gprMax/vtkhdf.py index f9403de9..0a2e6278 100644 --- a/gprMax/vtkhdf.py +++ b/gprMax/vtkhdf.py @@ -225,6 +225,8 @@ class VtkImageData(VtkHdfFile): SPACING_ATTR = "Spacing" WHOLE_EXTENT_ATTR = "WholeExtent" + DIMENSIONS = 3 + @property def TYPE(self) -> Literal["ImageData"]: return "ImageData" @@ -236,26 +238,33 @@ class VtkImageData(VtkHdfFile): origin: Optional[npt.NDArray[np.single]] = None, spacing: Optional[npt.NDArray[np.single]] = None, direction: Optional[npt.NDArray[np.single]] = None, - comm: Optional[MPI.Cartcomm] = None, + comm: Optional[MPI.Comm] = None, ) -> None: super().__init__(filename, comm) + if len(shape) == 0: + raise ValueError(f"Shape must not be empty.") + if len(shape) > self.DIMENSIONS: + raise ValueError(f"Shape must not have more than {self.DIMENSIONS} dimensions.") + elif len(shape) < self.DIMENSIONS: + shape = np.concatenate((shape, np.ones(self.DIMENSIONS - len(shape), dtype=np.intc))) + self.shape = shape - self.points_shape = shape + 1 - whole_extent = np.zeros(2 * len(self.shape), dtype=np.intc) + + whole_extent = np.zeros(2 * self.DIMENSIONS, dtype=np.intc) whole_extent[1::2] = self.shape self._set_root_attribute(self.WHOLE_EXTENT_ATTR, whole_extent) if origin is None: - origin = np.zeros(len(self.shape), dtype=np.single) + origin = np.zeros(self.DIMENSIONS, dtype=np.single) self.set_origin(origin) if spacing is None: - spacing = np.ones(len(self.shape), dtype=np.single) + spacing = np.ones(self.DIMENSIONS, dtype=np.single) self.set_spacing(spacing) if direction is None: - direction = np.diag(np.ones(len(self.shape), dtype=np.single)).flatten() + direction = np.diag(np.ones(self.DIMENSIONS, dtype=np.single)).flatten() self.set_direction(direction) @property @@ -275,30 +284,39 @@ class VtkImageData(VtkHdfFile): return self._get_root_attribute(self.DIRECTION_ATTR) def set_origin(self, origin: npt.NDArray[np.single]): + if len(origin) != self.DIMENSIONS: + raise ValueError(f"Origin attribute must have {self.DIMENSIONS} dimensions.") self._set_root_attribute(self.ORIGIN_ATTR, origin) def set_spacing(self, spacing: npt.NDArray[np.single]): + if len(spacing) != self.DIMENSIONS: + raise ValueError(f"Spacing attribute must have {self.DIMENSIONS} dimensions.") self._set_root_attribute(self.SPACING_ATTR, spacing) def set_direction(self, direction: npt.NDArray[np.single]): + if len(direction) != self.DIMENSIONS * self.DIMENSIONS: + raise ValueError( + f"Direction attribute must have {self.DIMENSIONS * self.DIMENSIONS} dimensions." + ) self._set_root_attribute(self.DIRECTION_ATTR, direction) def add_point_data( self, name: str, data: npt.NDArray, offset: Optional[npt.NDArray[np.intc]] = None ): - if offset is None and any(data.shape != self.points_shape): # type: ignore + points_shape = self.shape + 1 + if offset is None and any(data.shape != points_shape): # type: ignore raise ValueError( - f"If no offset is specified, data.shape {data.shape} must match the shape of the" - f" VtkImageData point datasets {self.points_shape}" + "If no offset is specified, data.shape must be one larger in each dimension than" + f" this vtkImageData object. {data.shape} != {points_shape}" ) - return super().add_point_data(name, data, self.points_shape, offset) + return super().add_point_data(name, data, points_shape, offset) def add_cell_data( self, name: str, data: npt.NDArray, offset: Optional[npt.NDArray[np.intc]] = None ): if offset is None and any(data.shape != self.shape): # type: ignore raise ValueError( - f"If no offset is specified, data.shape {data.shape} must match the shape of the" - f" VtkImageData {self.shape}" + "If no offset is specified, data.shape must match the dimensions of this" + f" vtkImageData object. {data.shape} != {self.shape}" ) return super().add_cell_data(name, data, self.shape, offset)