Fix bug creating MPI Cart comm ending at a grid boundary

这个提交包含在:
Nathan Mannall
2025-03-06 10:53:30 +00:00
父节点 5669002757
当前提交 4cd9f01f71
共有 2 个文件被更改,包括 18 次插入12 次删除

查看文件

@@ -139,13 +139,16 @@ class MPIGrid(FDTDGrid):
return self.rank == self.COORDINATOR_RANK
def create_sub_communicator(
self, start: npt.NDArray[np.int32], stop: npt.NDArray[np.int32]
self, local_start: npt.NDArray[np.int32], local_stop: npt.NDArray[np.int32]
) -> Optional[MPI.Cartcomm]:
if self.local_bounds_overlap_grid(start, stop):
if self.local_bounds_overlap_grid(local_start, local_stop):
comm = self.comm.Split()
assert isinstance(comm, MPI.Intracomm)
start_grid_coord = self.get_grid_coord_from_local_coordinate(start)
stop_grid_coord = self.get_grid_coord_from_local_coordinate(stop) + 1
start_grid_coord = self.get_grid_coord_from_local_coordinate(local_start)
# Subtract 1 from local_stop as the upper extent is
# exclusive meaning the last coordinate included in the sub
# communicator is actually (local_stop - 1).
stop_grid_coord = self.get_grid_coord_from_local_coordinate(local_stop - 1) + 1
comm = comm.Create_cart((stop_grid_coord - start_grid_coord).tolist())
return comm
else:

查看文件

@@ -320,14 +320,20 @@ class MPIGridView(GridView[MPIGrid]):
comm = grid.comm.Split()
assert isinstance(comm, MPI.Intracomm)
start_grid_coord = grid.get_grid_coord_from_local_coordinate(self.start)
stop_grid_coord = grid.get_grid_coord_from_local_coordinate(self.stop) + 1
self.comm = comm.Create_cart((stop_grid_coord - start_grid_coord).tolist())
# Calculate start, stop and size for the global grid view
self.global_start = self.grid.local_to_global_coordinate(self.start)
self.global_stop = self.grid.local_to_global_coordinate(self.stop)
self.global_size = self.size
# Calculate start for the local grid
self.global_start = self.grid.local_to_global_coordinate(self.start)
# Create new cartesean communicator by finding MPI grid coords
# for the start and end of the grid view.
# Subtract 1 from global_stop as the upper extent is exclusive
# meaning the last coordinate included in the grid view is
# actually (global_stop - 1).
start_grid_coord = grid.get_grid_coord_from_coordinate(self.global_start)
stop_grid_coord = grid.get_grid_coord_from_coordinate(self.global_stop - 1) + 1
self.comm = comm.Create_cart((stop_grid_coord - start_grid_coord).tolist())
self.has_negative_neighbour = self.start < self.grid.negative_halo_offset
@@ -340,9 +346,6 @@ class MPIGridView(GridView[MPIGrid]):
self.start,
)
# Calculate stop for the local grid
self.global_stop = self.grid.local_to_global_coordinate(self.stop)
self.has_positive_neighbour = self.stop > self.grid.size
# Limit stop such that it is at most one step beyond the max