Fix Bscan models when running with MPI

这个提交包含在:
Nathan Mannall
2025-03-21 16:12:32 +00:00
父节点 b185f5506c
当前提交 3da372cc82
共有 3 个文件被更改,包括 105 次插入29 次删除

查看文件

@@ -21,7 +21,7 @@ import itertools
import logging
import sys
from collections import OrderedDict
from typing import Any, Iterable, List, Literal, Tuple, Union
from typing import Any, Iterable, List, Tuple, Union
import numpy as np
import numpy.typing as npt
@@ -424,6 +424,16 @@ class FDTDGrid:
logger.info("")
logger.info(f"Materials [{self.name}]:\n{materialstable.table}\n")
def update_sources_and_recievers(self):
"""Update position of sources and receivers."""
# Adjust position of simple sources and receivers if required
model_num = config.sim_config.current_model
if any(self.srcsteps != 0):
self.update_simple_source_positions(model_num)
if any(self.rxsteps != 0):
self.update_receiver_positions(model_num)
def _update_positions(
self, items: Iterable[Union[Source, Rx]], step_size: npt.NDArray[np.int32], step_number: int
) -> None:
@@ -438,31 +448,22 @@ class FDTDGrid:
ValueError: Raised if any of the items would be stepped
outside of the grid.
"""
if step_size[0] != 0 or step_size[1] != 0 or step_size[2] != 0:
if any(step_size > 0):
for item in items:
if step_number == 0:
if (
item.xcoord + step_size[0] * config.sim_config.model_end < 0
or item.xcoord + step_size[0] * config.sim_config.model_end > self.nx
or item.ycoord + step_size[1] * config.sim_config.model_end < 0
or item.ycoord + step_size[1] * config.sim_config.model_end > self.ny
or item.zcoord + step_size[2] * config.sim_config.model_end < 0
or item.zcoord + step_size[2] * config.sim_config.model_end > self.nz
):
raise ValueError
# Check item won't be stepped outside of the grid
end_coord = item.coord + step_size * config.sim_config.model_end
self.within_bounds(end_coord)
else:
item.coord = item.coordorigin + step_number * step_size
def update_simple_source_positions(
self, step_size: npt.NDArray[np.int32], step: int = 0
) -> None:
def update_simple_source_positions(self, step: int = 0) -> None:
"""Update the positions of sources in the grid.
Move hertzian dipole and magnetic dipole sources. Transmission
line sources and voltage sources will not be moved.
Args:
step_size: Number of grid cells to move the sources each
step.
step: Number of steps to move the sources by.
Raises:
@@ -471,18 +472,16 @@ class FDTDGrid:
"""
try:
self._update_positions(
itertools.chain(self.hertziandipoles, self.magneticdipoles), step_size, step
itertools.chain(self.hertziandipoles, self.magneticdipoles), self.srcsteps, step
)
except ValueError as e:
logger.exception("Source(s) will be stepped to a position outside the domain.")
raise ValueError from e
def update_receiver_positions(self, step_size: npt.NDArray[np.int32], step: int = 0) -> None:
def update_receiver_positions(self, step: int = 0) -> None:
"""Update the positions of receivers in the grid.
Args:
step_size: Number of grid cells to move the receivers each
step.
step: Number of steps to move the receivers by.
Raises:
@@ -490,7 +489,7 @@ class FDTDGrid:
outside of the grid.
"""
try:
self._update_positions(self.rxs, step_size, step)
self._update_positions(self.rxs, self.rxsteps, step)
except ValueError as e:
logger.exception("Receiver(s) will be stepped to a position outside the domain.")
raise ValueError from e

查看文件

@@ -24,7 +24,7 @@ from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np
import numpy.typing as npt
from mpi4py import MPI
from numpy import empty, ndarray
from numpy import ndarray
from gprMax import config
from gprMax.cython.pml_build import pml_sum_er_mr
@@ -688,6 +688,88 @@ class MPIGrid(FDTDGrid):
super().build()
def update_sources_and_recievers(self):
"""Update position of sources and receivers.
If any sources or receivers have stepped outside of the local
grid, they will be moved to the correct MPI rank.
"""
super().update_sources_and_recievers()
# Check it is possible for sources and receivers to have moved
model_num = config.sim_config.current_model
if (all(self.srcsteps == 0) and all(self.rxsteps == 0)) or model_num == 0:
return
# Get items that are outside the local bounds of the grid
items_to_send = list(
itertools.filterfalse(
lambda x: self.within_bounds(x.coord),
itertools.chain(
self.voltagesources,
self.hertziandipoles,
self.magneticdipoles,
self.transmissionlines,
self.discreteplanewaves,
self.rxs,
),
)
)
# Map items being sent to the global coordinate space
for item in items_to_send:
item.coord = self.local_to_global_coordinate(item.coord)
send_count_by_rank = np.zeros(self.comm.size, dtype=np.int32)
# Send items to correct rank
for rank, items in itertools.groupby(
items_to_send, lambda x: self.get_rank_from_coordinate(x.coord)
):
self.comm.isend(list(items), rank)
send_count_by_rank[rank] += 1
# Communicate the number of messages sent to each rank
if self.is_coordinator():
self.comm.Reduce(MPI.IN_PLACE, [send_count_by_rank, MPI.INT32_T], op=MPI.SUM)
else:
self.comm.Reduce([send_count_by_rank, MPI.INT32_T], None, op=MPI.SUM)
# Get number of messages this rank will receive
messages_to_receive = np.zeros(1, dtype=np.int32)
if self.is_coordinator():
self.comm.Scatter([send_count_by_rank, MPI.INT32_T], [messages_to_receive, MPI.INT32_T])
else:
self.comm.Scatter(None, [messages_to_receive, MPI.INT32_T])
# Receive new items for this rank
for _ in range(messages_to_receive[0]):
new_items = self.comm.recv(None, MPI.ANY_SOURCE)
for item in new_items:
item.coord = self.global_to_local_coordinate(item.coord)
if isinstance(item, Rx):
self.add_receiver(item)
else:
self.add_source(item)
# If this rank sent any items, remove them from our source and
# receiver lists
if len(items_to_send) > 0:
# Map items sent back to the local coordinate space
for item in items_to_send:
item.coord = self.global_to_local_coordinate(item.coord)
filter_items = lambda items: list(
filter(lambda item: self.within_bounds(item.coord), items)
)
self.voltagesources = filter_items(self.voltagesources)
self.hertziandipoles = filter_items(self.hertziandipoles)
self.magneticdipoles = filter_items(self.magneticdipoles)
self.transmissionlines = filter_items(self.transmissionlines)
self.discreteplanewaves = filter_items(self.discreteplanewaves)
self.rxs = filter_items(self.rxs)
def has_neighbour(self, dim: Dim, dir: Dir) -> bool:
"""Test if the current rank has a specified neighbour.

查看文件

@@ -333,8 +333,6 @@ class Model:
def build(self):
"""Builds the Yee cells for a model."""
G = self.G
# Monitor memory usage
self.p = psutil.Process()
@@ -348,10 +346,7 @@ class Model:
f"Output directory: {config.get_model_config().output_file_path.parent.resolve()}\n"
)
# Adjust position of simple sources and receivers if required
model_num = config.sim_config.current_model
G.update_simple_source_positions(self.srcsteps, step=model_num)
G.update_receiver_positions(self.rxsteps, step=model_num)
self.G.update_sources_and_recievers()
self._output_geometry()