Add has_neighbour helper method

这个提交包含在:
nmannall
2024-08-05 15:53:26 +01:00
父节点 33670b3bd1
当前提交 d4276e44f9

查看文件

@@ -375,6 +375,9 @@ class MPISnapshot(Snapshot):
self.size = np.ceil((self.stop - self.start) / self.step).astype(np.intc)
return super().initialise_snapfields()
def has_neighbour(self, dimension: Dim, direction: Dir):
return self.neighbours[dimension][direction] != -1
def store(self, G):
"""Store (in memory) electric and magnetic field values for snapshot.
@@ -439,12 +442,12 @@ class MPISnapshot(Snapshot):
blocking_requests: List[MPI.Request] = []
requests: List[MPI.Request] = []
if self.neighbours[Dim.X][Dir.NEG] != -1:
if self.has_neighbour(Dim.X, Dir.NEG):
requests += [
self.comm.Isend(Hxslice[0, :, :], self.neighbours[Dim.X][Dir.NEG], self.H_TAG),
self.comm.Isend(Ezslice[0, :, :], self.neighbours[Dim.X][Dir.NEG], self.EZ_TAG),
]
if self.neighbours[Dim.X][Dir.POS] != -1:
if self.has_neighbour(Dim.X, Dir.POS):
blocking_requests.append(
self.comm.Irecv(Ezxhalo, self.neighbours[Dim.X][Dir.POS], self.EZ_TAG),
)
@@ -452,7 +455,7 @@ class MPISnapshot(Snapshot):
self.comm.Irecv(Hxhalo, self.neighbours[Dim.X][Dir.POS], self.H_TAG),
self.comm.Irecv(Eyxhalo, self.neighbours[Dim.X][Dir.POS], self.EY_TAG),
]
if self.neighbours[Dim.Y][Dir.NEG] != -1:
if self.has_neighbour(Dim.Y, Dir.NEG):
requests += [
self.comm.Isend(
np.ascontiguousarray(Hyslice[:, 0, :]),
@@ -465,7 +468,7 @@ class MPISnapshot(Snapshot):
self.EX_TAG,
),
]
if self.neighbours[Dim.Y][Dir.POS] != -1:
if self.has_neighbour(Dim.Y, Dir.POS):
blocking_requests.append(
self.comm.Irecv(Exyhalo, self.neighbours[Dim.Y][Dir.POS], self.EX_TAG),
)
@@ -473,7 +476,7 @@ class MPISnapshot(Snapshot):
self.comm.Irecv(Hyhalo, self.neighbours[Dim.Y][Dir.POS], self.H_TAG),
self.comm.Irecv(Ezyhalo, self.neighbours[Dim.Y][Dir.POS], self.EZ_TAG),
]
if self.neighbours[Dim.Z][Dir.NEG] != -1:
if self.has_neighbour(Dim.Z, Dir.NEG):
requests += [
self.comm.Isend(
np.ascontiguousarray(Hzslice[:, :, 0]),
@@ -486,7 +489,7 @@ class MPISnapshot(Snapshot):
self.EY_TAG,
),
]
if self.neighbours[Dim.Z][Dir.POS] != -1:
if self.has_neighbour(Dim.Z, Dir.POS):
blocking_requests.append(
self.comm.Irecv(Eyzhalo, self.neighbours[Dim.Z][Dir.POS], self.EY_TAG),
)
@@ -500,18 +503,18 @@ class MPISnapshot(Snapshot):
logger.debug(f"Initial halo exchanges complete")
if self.neighbours[Dim.X][Dir.POS] != -1:
if self.has_neighbour(Dim.X, Dir.POS):
Ezslice = np.concatenate((Ezslice, Ezxhalo), axis=Dim.X)
if self.neighbours[Dim.Y][Dir.POS] != -1:
if self.has_neighbour(Dim.Y, Dir.POS):
Exslice = np.concatenate((Exslice, Exyhalo), axis=Dim.Y)
if self.neighbours[Dim.Z][Dir.POS] != -1:
if self.has_neighbour(Dim.Z, Dir.POS):
Eyslice = np.concatenate((Eyslice, Eyzhalo), axis=Dim.Z)
if self.neighbours[Dim.X][Dir.NEG] != -1:
if self.has_neighbour(Dim.X, Dir.NEG):
requests.append(
self.comm.Isend(Eyslice[0, :, :], self.neighbours[Dim.X][Dir.NEG], self.EY_TAG),
)
if self.neighbours[Dim.Y][Dir.NEG] != -1:
if self.has_neighbour(Dim.Y, Dir.NEG):
requests.append(
self.comm.Isend(
np.ascontiguousarray(Ezslice[:, 0, :]),
@@ -519,7 +522,7 @@ class MPISnapshot(Snapshot):
self.EZ_TAG,
),
)
if self.neighbours[Dim.Z][Dir.NEG] != -1:
if self.has_neighbour(Dim.Z, Dir.NEG):
requests.append(
self.comm.Isend(
np.ascontiguousarray(Exslice[:, :, 0]),
@@ -533,13 +536,13 @@ class MPISnapshot(Snapshot):
logger.debug(f"All halo exchanges complete")
if self.neighbours[Dim.X][Dir.POS] != -1:
if self.has_neighbour(Dim.X, Dir.POS):
Eyslice = np.concatenate((Eyslice, Eyxhalo), axis=Dim.X)
Hxslice = np.concatenate((Hxslice, Hxhalo), axis=Dim.X)
if self.neighbours[Dim.Y][Dir.POS] != -1:
if self.has_neighbour(Dim.Y, Dir.POS):
Ezslice = np.concatenate((Ezslice, Ezyhalo), axis=Dim.Y)
Hyslice = np.concatenate((Hyslice, Hyhalo), axis=Dim.Y)
if self.neighbours[Dim.Z][Dir.POS] != -1:
if self.has_neighbour(Dim.Z, Dir.POS):
Exslice = np.concatenate((Exslice, Exzhalo), axis=Dim.Z)
Hzslice = np.concatenate((Hzslice, Hzhalo), axis=Dim.Z)