Fix Bscan models when running with MPI

这个提交包含在:
Nathan Mannall
2025-03-21 16:12:32 +00:00
父节点 b185f5506c
当前提交 3da372cc82
共有 3 个文件被更改,包括 105 次插入29 次删除

查看文件

@@ -21,7 +21,7 @@ import itertools
import logging import logging
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Iterable, List, Literal, Tuple, Union from typing import Any, Iterable, List, Tuple, Union
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@@ -424,6 +424,16 @@ class FDTDGrid:
logger.info("") logger.info("")
logger.info(f"Materials [{self.name}]:\n{materialstable.table}\n") logger.info(f"Materials [{self.name}]:\n{materialstable.table}\n")
def update_sources_and_recievers(self):
"""Update position of sources and receivers."""
# Adjust position of simple sources and receivers if required
model_num = config.sim_config.current_model
if any(self.srcsteps != 0):
self.update_simple_source_positions(model_num)
if any(self.rxsteps != 0):
self.update_receiver_positions(model_num)
def _update_positions( def _update_positions(
self, items: Iterable[Union[Source, Rx]], step_size: npt.NDArray[np.int32], step_number: int self, items: Iterable[Union[Source, Rx]], step_size: npt.NDArray[np.int32], step_number: int
) -> None: ) -> None:
@@ -438,31 +448,22 @@ class FDTDGrid:
ValueError: Raised if any of the items would be stepped ValueError: Raised if any of the items would be stepped
outside of the grid. outside of the grid.
""" """
if step_size[0] != 0 or step_size[1] != 0 or step_size[2] != 0: if any(step_size > 0):
for item in items: for item in items:
if step_number == 0: if step_number == 0:
if ( # Check item won't be stepped outside of the grid
item.xcoord + step_size[0] * config.sim_config.model_end < 0 end_coord = item.coord + step_size * config.sim_config.model_end
or item.xcoord + step_size[0] * config.sim_config.model_end > self.nx self.within_bounds(end_coord)
or item.ycoord + step_size[1] * config.sim_config.model_end < 0 else:
or item.ycoord + step_size[1] * config.sim_config.model_end > self.ny item.coord = item.coordorigin + step_number * step_size
or item.zcoord + step_size[2] * config.sim_config.model_end < 0
or item.zcoord + step_size[2] * config.sim_config.model_end > self.nz
):
raise ValueError
item.coord = item.coordorigin + step_number * step_size
def update_simple_source_positions( def update_simple_source_positions(self, step: int = 0) -> None:
self, step_size: npt.NDArray[np.int32], step: int = 0
) -> None:
"""Update the positions of sources in the grid. """Update the positions of sources in the grid.
Move hertzian dipole and magnetic dipole sources. Transmission Move hertzian dipole and magnetic dipole sources. Transmission
line sources and voltage sources will not be moved. line sources and voltage sources will not be moved.
Args: Args:
step_size: Number of grid cells to move the sources each
step.
step: Number of steps to move the sources by. step: Number of steps to move the sources by.
Raises: Raises:
@@ -471,18 +472,16 @@ class FDTDGrid:
""" """
try: try:
self._update_positions( self._update_positions(
itertools.chain(self.hertziandipoles, self.magneticdipoles), step_size, step itertools.chain(self.hertziandipoles, self.magneticdipoles), self.srcsteps, step
) )
except ValueError as e: except ValueError as e:
logger.exception("Source(s) will be stepped to a position outside the domain.") logger.exception("Source(s) will be stepped to a position outside the domain.")
raise ValueError from e raise ValueError from e
def update_receiver_positions(self, step_size: npt.NDArray[np.int32], step: int = 0) -> None: def update_receiver_positions(self, step: int = 0) -> None:
"""Update the positions of receivers in the grid. """Update the positions of receivers in the grid.
Args: Args:
step_size: Number of grid cells to move the receivers each
step.
step: Number of steps to move the receivers by. step: Number of steps to move the receivers by.
Raises: Raises:
@@ -490,7 +489,7 @@ class FDTDGrid:
outside of the grid. outside of the grid.
""" """
try: try:
self._update_positions(self.rxs, step_size, step) self._update_positions(self.rxs, self.rxsteps, step)
except ValueError as e: except ValueError as e:
logger.exception("Receiver(s) will be stepped to a position outside the domain.") logger.exception("Receiver(s) will be stepped to a position outside the domain.")
raise ValueError from e raise ValueError from e

查看文件

@@ -24,7 +24,7 @@ from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from mpi4py import MPI from mpi4py import MPI
from numpy import empty, ndarray from numpy import ndarray
from gprMax import config from gprMax import config
from gprMax.cython.pml_build import pml_sum_er_mr from gprMax.cython.pml_build import pml_sum_er_mr
@@ -688,6 +688,88 @@ class MPIGrid(FDTDGrid):
super().build() super().build()
def update_sources_and_recievers(self):
"""Update position of sources and receivers.
If any sources or receivers have stepped outside of the local
grid, they will be moved to the correct MPI rank.
"""
super().update_sources_and_recievers()
# Check it is possible for sources and receivers to have moved
model_num = config.sim_config.current_model
if (all(self.srcsteps == 0) and all(self.rxsteps == 0)) or model_num == 0:
return
# Get items that are outside the local bounds of the grid
items_to_send = list(
itertools.filterfalse(
lambda x: self.within_bounds(x.coord),
itertools.chain(
self.voltagesources,
self.hertziandipoles,
self.magneticdipoles,
self.transmissionlines,
self.discreteplanewaves,
self.rxs,
),
)
)
# Map items being sent to the global coordinate space
for item in items_to_send:
item.coord = self.local_to_global_coordinate(item.coord)
send_count_by_rank = np.zeros(self.comm.size, dtype=np.int32)
# Send items to correct rank
for rank, items in itertools.groupby(
items_to_send, lambda x: self.get_rank_from_coordinate(x.coord)
):
self.comm.isend(list(items), rank)
send_count_by_rank[rank] += 1
# Communicate the number of messages sent to each rank
if self.is_coordinator():
self.comm.Reduce(MPI.IN_PLACE, [send_count_by_rank, MPI.INT32_T], op=MPI.SUM)
else:
self.comm.Reduce([send_count_by_rank, MPI.INT32_T], None, op=MPI.SUM)
# Get number of messages this rank will receive
messages_to_receive = np.zeros(1, dtype=np.int32)
if self.is_coordinator():
self.comm.Scatter([send_count_by_rank, MPI.INT32_T], [messages_to_receive, MPI.INT32_T])
else:
self.comm.Scatter(None, [messages_to_receive, MPI.INT32_T])
# Receive new items for this rank
for _ in range(messages_to_receive[0]):
new_items = self.comm.recv(None, MPI.ANY_SOURCE)
for item in new_items:
item.coord = self.global_to_local_coordinate(item.coord)
if isinstance(item, Rx):
self.add_receiver(item)
else:
self.add_source(item)
# If this rank sent any items, remove them from our source and
# receiver lists
if len(items_to_send) > 0:
# Map items sent back to the local coordinate space
for item in items_to_send:
item.coord = self.global_to_local_coordinate(item.coord)
filter_items = lambda items: list(
filter(lambda item: self.within_bounds(item.coord), items)
)
self.voltagesources = filter_items(self.voltagesources)
self.hertziandipoles = filter_items(self.hertziandipoles)
self.magneticdipoles = filter_items(self.magneticdipoles)
self.transmissionlines = filter_items(self.transmissionlines)
self.discreteplanewaves = filter_items(self.discreteplanewaves)
self.rxs = filter_items(self.rxs)
def has_neighbour(self, dim: Dim, dir: Dir) -> bool: def has_neighbour(self, dim: Dim, dir: Dir) -> bool:
"""Test if the current rank has a specified neighbour. """Test if the current rank has a specified neighbour.

查看文件

@@ -333,8 +333,6 @@ class Model:
def build(self): def build(self):
"""Builds the Yee cells for a model.""" """Builds the Yee cells for a model."""
G = self.G
# Monitor memory usage # Monitor memory usage
self.p = psutil.Process() self.p = psutil.Process()
@@ -348,10 +346,7 @@ class Model:
f"Output directory: {config.get_model_config().output_file_path.parent.resolve()}\n" f"Output directory: {config.get_model_config().output_file_path.parent.resolve()}\n"
) )
# Adjust position of simple sources and receivers if required self.G.update_sources_and_recievers()
model_num = config.sim_config.current_model
G.update_simple_source_positions(self.srcsteps, step=model_num)
G.update_receiver_positions(self.rxsteps, step=model_num)
self._output_geometry() self._output_geometry()