diff --git a/gprMax/grid/fdtd_grid.py b/gprMax/grid/fdtd_grid.py index b73de81e..c1211357 100644 --- a/gprMax/grid/fdtd_grid.py +++ b/gprMax/grid/fdtd_grid.py @@ -21,7 +21,7 @@ import itertools import logging import sys from collections import OrderedDict -from typing import Any, Iterable, List, Literal, Tuple, Union +from typing import Any, Iterable, List, Tuple, Union import numpy as np import numpy.typing as npt @@ -424,6 +424,16 @@ class FDTDGrid: logger.info("") logger.info(f"Materials [{self.name}]:\n{materialstable.table}\n") + def update_sources_and_recievers(self): + """Update position of sources and receivers.""" + + # Adjust position of simple sources and receivers if required + model_num = config.sim_config.current_model + if any(self.srcsteps != 0): + self.update_simple_source_positions(model_num) + if any(self.rxsteps != 0): + self.update_receiver_positions(model_num) + def _update_positions( self, items: Iterable[Union[Source, Rx]], step_size: npt.NDArray[np.int32], step_number: int ) -> None: @@ -438,31 +448,22 @@ class FDTDGrid: ValueError: Raised if any of the items would be stepped outside of the grid. """ - if step_size[0] != 0 or step_size[1] != 0 or step_size[2] != 0: + if any(step_size > 0): for item in items: if step_number == 0: - if ( - item.xcoord + step_size[0] * config.sim_config.model_end < 0 - or item.xcoord + step_size[0] * config.sim_config.model_end > self.nx - or item.ycoord + step_size[1] * config.sim_config.model_end < 0 - or item.ycoord + step_size[1] * config.sim_config.model_end > self.ny - or item.zcoord + step_size[2] * config.sim_config.model_end < 0 - or item.zcoord + step_size[2] * config.sim_config.model_end > self.nz - ): - raise ValueError - item.coord = item.coordorigin + step_number * step_size + # Check item won't be stepped outside of the grid + end_coord = item.coord + step_size * config.sim_config.model_end + self.within_bounds(end_coord) + else: + item.coord = item.coordorigin + step_number * step_size - def update_simple_source_positions( - self, step_size: npt.NDArray[np.int32], step: int = 0 - ) -> None: + def update_simple_source_positions(self, step: int = 0) -> None: """Update the positions of sources in the grid. Move hertzian dipole and magnetic dipole sources. Transmission line sources and voltage sources will not be moved. Args: - step_size: Number of grid cells to move the sources each - step. step: Number of steps to move the sources by. Raises: @@ -471,18 +472,16 @@ class FDTDGrid: """ try: self._update_positions( - itertools.chain(self.hertziandipoles, self.magneticdipoles), step_size, step + itertools.chain(self.hertziandipoles, self.magneticdipoles), self.srcsteps, step ) except ValueError as e: logger.exception("Source(s) will be stepped to a position outside the domain.") raise ValueError from e - def update_receiver_positions(self, step_size: npt.NDArray[np.int32], step: int = 0) -> None: + def update_receiver_positions(self, step: int = 0) -> None: """Update the positions of receivers in the grid. Args: - step_size: Number of grid cells to move the receivers each - step. step: Number of steps to move the receivers by. Raises: @@ -490,7 +489,7 @@ class FDTDGrid: outside of the grid. """ try: - self._update_positions(self.rxs, step_size, step) + self._update_positions(self.rxs, self.rxsteps, step) except ValueError as e: logger.exception("Receiver(s) will be stepped to a position outside the domain.") raise ValueError from e diff --git a/gprMax/grid/mpi_grid.py b/gprMax/grid/mpi_grid.py index da9664e7..2a83fa51 100644 --- a/gprMax/grid/mpi_grid.py +++ b/gprMax/grid/mpi_grid.py @@ -24,7 +24,7 @@ from typing import List, Optional, Tuple, TypeVar, Union import numpy as np import numpy.typing as npt from mpi4py import MPI -from numpy import empty, ndarray +from numpy import ndarray from gprMax import config from gprMax.cython.pml_build import pml_sum_er_mr @@ -688,6 +688,88 @@ class MPIGrid(FDTDGrid): super().build() + def update_sources_and_recievers(self): + """Update position of sources and receivers. + + If any sources or receivers have stepped outside of the local + grid, they will be moved to the correct MPI rank. + """ + super().update_sources_and_recievers() + + # Check it is possible for sources and receivers to have moved + model_num = config.sim_config.current_model + if (all(self.srcsteps == 0) and all(self.rxsteps == 0)) or model_num == 0: + return + + # Get items that are outside the local bounds of the grid + items_to_send = list( + itertools.filterfalse( + lambda x: self.within_bounds(x.coord), + itertools.chain( + self.voltagesources, + self.hertziandipoles, + self.magneticdipoles, + self.transmissionlines, + self.discreteplanewaves, + self.rxs, + ), + ) + ) + + # Map items being sent to the global coordinate space + for item in items_to_send: + item.coord = self.local_to_global_coordinate(item.coord) + + send_count_by_rank = np.zeros(self.comm.size, dtype=np.int32) + + # Send items to correct rank + for rank, items in itertools.groupby( + items_to_send, lambda x: self.get_rank_from_coordinate(x.coord) + ): + self.comm.isend(list(items), rank) + send_count_by_rank[rank] += 1 + + # Communicate the number of messages sent to each rank + if self.is_coordinator(): + self.comm.Reduce(MPI.IN_PLACE, [send_count_by_rank, MPI.INT32_T], op=MPI.SUM) + else: + self.comm.Reduce([send_count_by_rank, MPI.INT32_T], None, op=MPI.SUM) + + # Get number of messages this rank will receive + messages_to_receive = np.zeros(1, dtype=np.int32) + if self.is_coordinator(): + self.comm.Scatter([send_count_by_rank, MPI.INT32_T], [messages_to_receive, MPI.INT32_T]) + else: + self.comm.Scatter(None, [messages_to_receive, MPI.INT32_T]) + + # Receive new items for this rank + for _ in range(messages_to_receive[0]): + new_items = self.comm.recv(None, MPI.ANY_SOURCE) + for item in new_items: + item.coord = self.global_to_local_coordinate(item.coord) + if isinstance(item, Rx): + self.add_receiver(item) + else: + self.add_source(item) + + # If this rank sent any items, remove them from our source and + # receiver lists + if len(items_to_send) > 0: + # Map items sent back to the local coordinate space + for item in items_to_send: + item.coord = self.global_to_local_coordinate(item.coord) + + filter_items = lambda items: list( + filter(lambda item: self.within_bounds(item.coord), items) + ) + + self.voltagesources = filter_items(self.voltagesources) + self.hertziandipoles = filter_items(self.hertziandipoles) + self.magneticdipoles = filter_items(self.magneticdipoles) + self.transmissionlines = filter_items(self.transmissionlines) + self.discreteplanewaves = filter_items(self.discreteplanewaves) + self.rxs = filter_items(self.rxs) + def has_neighbour(self, dim: Dim, dir: Dir) -> bool: """Test if the current rank has a specified neighbour. diff --git a/gprMax/model.py b/gprMax/model.py index 219012b5..29f1ef3a 100644 --- a/gprMax/model.py +++ b/gprMax/model.py @@ -333,8 +333,6 @@ class Model: def build(self): """Builds the Yee cells for a model.""" - G = self.G - # Monitor memory usage self.p = psutil.Process() @@ -348,10 +346,7 @@ class Model: f"Output directory: {config.get_model_config().output_file_path.parent.resolve()}\n" ) - # Adjust position of simple sources and receivers if required - model_num = config.sim_config.current_model - G.update_simple_source_positions(self.srcsteps, step=model_num) - G.update_receiver_positions(self.rxsteps, step=model_num) + self.G.update_sources_and_recievers() self._output_geometry()