Pass iteration to CUDA and OpenCL update methods

这个提交包含在:
nmannall
2024-05-31 17:18:35 +01:00
父节点 a68b7e0583
当前提交 d8005d35d8
共有 2 个文件被更改,包括 12 次插入12 次删除

查看文件

@@ -354,12 +354,12 @@ class CUDAUpdates(Updates):
self.drv.memcpy_htod(updatecoeffsE, self.grid.updatecoeffsE)
self.drv.memcpy_htod(updatecoeffsH, self.grid.updatecoeffsH)
def store_outputs(self):
def store_outputs(self, iteration):
"""Stores field component values for every receiver."""
if self.grid.rxs:
self.store_outputs_dev(
np.int32(len(self.grid.rxs)),
np.int32(self.grid.iteration),
np.int32(iteration),
self.rxcoords_dev.gpudata,
self.rxs_dev.gpudata,
self.grid.Ex_dev.gpudata,
@@ -442,12 +442,12 @@ class CUDAUpdates(Updates):
for pml in self.grid.pmls["slabs"]:
pml.update_magnetic()
def update_magnetic_sources(self):
def update_magnetic_sources(self, iteration):
"""Updates magnetic field components from sources."""
if self.grid.magneticdipoles:
self.update_magnetic_dipole_dev(
np.int32(len(self.grid.magneticdipoles)),
np.int32(self.grid.iteration),
np.int32(iteration),
config.sim_config.dtypes["float_or_double"](self.grid.dx),
config.sim_config.dtypes["float_or_double"](self.grid.dy),
config.sim_config.dtypes["float_or_double"](self.grid.dz),
@@ -509,14 +509,14 @@ class CUDAUpdates(Updates):
for pml in self.grid.pmls["slabs"]:
pml.update_electric()
def update_electric_sources(self):
def update_electric_sources(self, iteration):
"""Updates electric field components from sources -
update any Hertzian dipole sources last.
"""
if self.grid.voltagesources:
self.update_voltage_source_dev(
np.int32(len(self.grid.voltagesources)),
np.int32(self.grid.iteration),
np.int32(iteration),
config.sim_config.dtypes["float_or_double"](self.grid.dx),
config.sim_config.dtypes["float_or_double"](self.grid.dy),
config.sim_config.dtypes["float_or_double"](self.grid.dz),

查看文件

@@ -363,12 +363,12 @@ class OpenCLUpdates(Updates):
options=config.sim_config.devices["compiler_opts"],
)
def store_outputs(self):
def store_outputs(self, iteration):
"""Stores field component values for every receiver."""
if self.grid.rxs:
self.store_outputs_dev(
np.int32(len(self.grid.rxs)),
np.int32(self.grid.iteration),
np.int32(iteration),
self.rxcoords_dev,
self.rxs_dev,
self.grid.Ex_dev,
@@ -446,12 +446,12 @@ class OpenCLUpdates(Updates):
for pml in self.grid.pmls["slabs"]:
pml.update_magnetic()
def update_magnetic_sources(self):
def update_magnetic_sources(self, iteration):
"""Updates magnetic field components from sources."""
if self.grid.magneticdipoles:
self.update_magnetic_dipole_dev(
np.int32(len(self.grid.magneticdipoles)),
np.int32(self.grid.iteration),
np.int32(iteration),
config.sim_config.dtypes["float_or_double"](self.grid.dx),
config.sim_config.dtypes["float_or_double"](self.grid.dy),
config.sim_config.dtypes["float_or_double"](self.grid.dz),
@@ -507,14 +507,14 @@ class OpenCLUpdates(Updates):
for pml in self.grid.pmls["slabs"]:
pml.update_electric()
def update_electric_sources(self):
def update_electric_sources(self, iteration):
"""Updates electric field components from sources -
update any Hertzian dipole sources last.
"""
if self.grid.voltagesources:
self.update_voltage_source_dev(
np.int32(len(self.grid.voltagesources)),
np.int32(self.grid.iteration),
np.int32(iteration),
config.sim_config.dtypes["float_or_double"](self.grid.dx),
config.sim_config.dtypes["float_or_double"](self.grid.dy),
config.sim_config.dtypes["float_or_double"](self.grid.dz),