Move CPUUpdates to seperate file

这个提交包含在:
nmannall
2024-02-09 12:54:11 +00:00
父节点 39c7253f3c
当前提交 0ff41843a8
共有 5 个文件被更改,包括 199 次插入193 次删除

查看文件

@@ -20,7 +20,8 @@ import gprMax.config as config
from .grid import CUDAGrid, FDTDGrid, OpenCLGrid
from .subgrids.updates import create_updates as create_subgrid_updates
from .updates.updates import CPUUpdates, CUDAUpdates, OpenCLUpdates
from .updates.cpu_updates import CPUUpdates
from .updates.updates import CUDAUpdates, OpenCLUpdates
def create_G():

查看文件

@@ -18,7 +18,7 @@
import logging
from ..updates.updates import CPUUpdates
from ..updates.cpu_updates import CPUUpdates
from .precursor_nodes import PrecursorNodes, PrecursorNodesFiltered
from .subgrid_hsg import SubGridHSG

查看文件

@@ -0,0 +1,191 @@
from importlib import import_module
from gprMax import config
from ..cython.fields_updates_normal import update_electric as update_electric_cpu
from ..cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
from ..fields_outputs import store_outputs as store_outputs_cpu
from ..utilities.utilities import timer
class CPUUpdates:
"""Defines update functions for CPU-based solver."""
def __init__(self, G):
"""
Args:
G: FDTDGrid class describing a grid in a model.
"""
self.grid = G
def store_outputs(self):
"""Stores field component values for every receiver and transmission line."""
store_outputs_cpu(self.grid)
def store_snapshots(self, iteration):
"""Stores any snapshots.
Args:
iteration: int for iteration number.
"""
for snap in self.grid.snapshots:
if snap.time == iteration + 1:
snap.store(self.grid)
def update_magnetic(self):
"""Updates magnetic field components."""
update_magnetic_cpu(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
self.grid.updatecoeffsH,
self.grid.ID,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
)
def update_magnetic_pml(self):
"""Updates magnetic field components with the PML correction."""
for pml in self.grid.pmls["slabs"]:
pml.update_magnetic()
def update_magnetic_sources(self):
"""Updates magnetic field components from sources."""
for source in self.grid.transmissionlines + self.grid.magneticdipoles:
source.update_magnetic(
self.grid.iteration,
self.grid.updatecoeffsH,
self.grid.ID,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
self.grid,
)
def update_electric_a(self):
"""Updates electric field components."""
# All materials are non-dispersive so do standard update.
if config.get_model_config().materials["maxpoles"] == 0:
update_electric_cpu(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
self.grid.updatecoeffsE,
self.grid.ID,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
)
# If there are any dispersive materials do 1st part of dispersive update
# (it is split into two parts as it requires present and updated electric field values).
else:
self.dispersive_update_a(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
config.get_model_config().materials["maxpoles"],
self.grid.updatecoeffsE,
self.grid.updatecoeffsdispersive,
self.grid.ID,
self.grid.Tx,
self.grid.Ty,
self.grid.Tz,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
)
def update_electric_pml(self):
"""Updates electric field components with the PML correction."""
for pml in self.grid.pmls["slabs"]:
pml.update_electric()
def update_electric_sources(self):
"""Updates electric field components from sources -
update any Hertzian dipole sources last.
"""
for source in self.grid.voltagesources + self.grid.transmissionlines + self.grid.hertziandipoles:
source.update_electric(
self.grid.iteration,
self.grid.updatecoeffsE,
self.grid.ID,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid,
)
self.grid.iteration += 1
def update_electric_b(self):
"""If there are any dispersive materials do 2nd part of dispersive
update - it is split into two parts as it requires present and
updated electric field values. Therefore it can only be completely
updated after the electric field has been updated by the PML and
source updates.
"""
if config.get_model_config().materials["maxpoles"] > 0:
self.dispersive_update_b(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
config.get_model_config().materials["maxpoles"],
self.grid.updatecoeffsdispersive,
self.grid.ID,
self.grid.Tx,
self.grid.Ty,
self.grid.Tz,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
)
def set_dispersive_updates(self):
"""Sets dispersive update functions."""
poles = "multi" if config.get_model_config().materials["maxpoles"] > 1 else "1"
precision = "float" if config.sim_config.general["precision"] == "single" else "double"
dispersion = (
"complex"
if config.get_model_config().materials["dispersivedtype"] == config.sim_config.dtypes["complex"]
else "real"
)
update_f = "update_electric_dispersive_{}pole_{}_{}_{}"
disp_a = update_f.format(poles, "A", precision, dispersion)
disp_b = update_f.format(poles, "B", precision, dispersion)
disp_a_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_a)
disp_b_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_b)
self.dispersive_update_a = disp_a_f
self.dispersive_update_b = disp_b_f
def time_start(self):
"""Starts timer used to calculate solving time for model."""
self.timestart = timer()
def calculate_solve_time(self):
"""Calculates solving time for model."""
return timer() - self.timestart
def finalise(self):
pass
def cleanup(self):
pass

查看文件

@@ -26,200 +26,14 @@ from jinja2 import Environment, PackageLoader
import gprMax.config as config
from ..cuda_opencl import knl_fields_updates, knl_snapshots, knl_source_updates, knl_store_outputs
from ..cython.fields_updates_normal import update_electric as update_electric_cpu
from ..cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
from ..fields_outputs import store_outputs as store_outputs_cpu
from ..receivers import dtoh_rx_array, htod_rx_arrays
from ..snapshots import Snapshot, dtoh_snapshot_array, htod_snapshot_array
from ..sources import htod_src_arrays
from ..utilities.utilities import round32, timer
from ..utilities.utilities import round32
logger = logging.getLogger(__name__)
class CPUUpdates:
"""Defines update functions for CPU-based solver."""
def __init__(self, G):
"""
Args:
G: FDTDGrid class describing a grid in a model.
"""
self.grid = G
def store_outputs(self):
"""Stores field component values for every receiver and transmission line."""
store_outputs_cpu(self.grid)
def store_snapshots(self, iteration):
"""Stores any snapshots.
Args:
iteration: int for iteration number.
"""
for snap in self.grid.snapshots:
if snap.time == iteration + 1:
snap.store(self.grid)
def update_magnetic(self):
"""Updates magnetic field components."""
update_magnetic_cpu(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
self.grid.updatecoeffsH,
self.grid.ID,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
)
def update_magnetic_pml(self):
"""Updates magnetic field components with the PML correction."""
for pml in self.grid.pmls["slabs"]:
pml.update_magnetic()
def update_magnetic_sources(self):
"""Updates magnetic field components from sources."""
for source in self.grid.transmissionlines + self.grid.magneticdipoles:
source.update_magnetic(
self.grid.iteration,
self.grid.updatecoeffsH,
self.grid.ID,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
self.grid,
)
def update_electric_a(self):
"""Updates electric field components."""
# All materials are non-dispersive so do standard update.
if config.get_model_config().materials["maxpoles"] == 0:
update_electric_cpu(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
self.grid.updatecoeffsE,
self.grid.ID,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
)
# If there are any dispersive materials do 1st part of dispersive update
# (it is split into two parts as it requires present and updated electric field values).
else:
self.dispersive_update_a(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
config.get_model_config().materials["maxpoles"],
self.grid.updatecoeffsE,
self.grid.updatecoeffsdispersive,
self.grid.ID,
self.grid.Tx,
self.grid.Ty,
self.grid.Tz,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid.Hx,
self.grid.Hy,
self.grid.Hz,
)
def update_electric_pml(self):
"""Updates electric field components with the PML correction."""
for pml in self.grid.pmls["slabs"]:
pml.update_electric()
def update_electric_sources(self):
"""Updates electric field components from sources -
update any Hertzian dipole sources last.
"""
for source in self.grid.voltagesources + self.grid.transmissionlines + self.grid.hertziandipoles:
source.update_electric(
self.grid.iteration,
self.grid.updatecoeffsE,
self.grid.ID,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
self.grid,
)
self.grid.iteration += 1
def update_electric_b(self):
"""If there are any dispersive materials do 2nd part of dispersive
update - it is split into two parts as it requires present and
updated electric field values. Therefore it can only be completely
updated after the electric field has been updated by the PML and
source updates.
"""
if config.get_model_config().materials["maxpoles"] > 0:
self.dispersive_update_b(
self.grid.nx,
self.grid.ny,
self.grid.nz,
config.get_model_config().ompthreads,
config.get_model_config().materials["maxpoles"],
self.grid.updatecoeffsdispersive,
self.grid.ID,
self.grid.Tx,
self.grid.Ty,
self.grid.Tz,
self.grid.Ex,
self.grid.Ey,
self.grid.Ez,
)
def set_dispersive_updates(self):
"""Sets dispersive update functions."""
poles = "multi" if config.get_model_config().materials["maxpoles"] > 1 else "1"
precision = "float" if config.sim_config.general["precision"] == "single" else "double"
dispersion = (
"complex"
if config.get_model_config().materials["dispersivedtype"] == config.sim_config.dtypes["complex"]
else "real"
)
update_f = "update_electric_dispersive_{}pole_{}_{}_{}"
disp_a = update_f.format(poles, "A", precision, dispersion)
disp_b = update_f.format(poles, "B", precision, dispersion)
disp_a_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_a)
disp_b_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_b)
self.dispersive_update_a = disp_a_f
self.dispersive_update_b = disp_b_f
def time_start(self):
"""Starts timer used to calculate solving time for model."""
self.timestart = timer()
def calculate_solve_time(self):
"""Calculates solving time for model."""
return timer() - self.timestart
def finalise(self):
pass
def cleanup(self):
pass
class CUDAUpdates:
"""Defines update functions for GPU-based (CUDA) solver."""

查看文件

@@ -8,7 +8,7 @@ from gprMax.grid import FDTDGrid
from gprMax.materials import create_built_in_materials
from gprMax.model_build_run import GridBuilder
from gprMax.pml import CFS
from gprMax.updates.updates import CPUUpdates
from gprMax.updates.cpu_updates import CPUUpdates
def build_grid(nx, ny, nz, dl=0.001, dt=3e-9):
@@ -50,7 +50,7 @@ def config_mock(monkeypatch):
monkeypatch.setattr(config, "get_model_config", _mock_model_config)
def test_update_magnetic_cpu(config_mock):
def test_update_magnetic(config_mock):
grid = build_grid(100, 100, 100)
expected_value = np.zeros((101, 101, 101))
@@ -72,7 +72,7 @@ def test_update_magnetic_cpu(config_mock):
assert np.equal(pml.EPhi2, 0).all()
def test_update_magnetic_pml_cpu(config_mock):
def test_update_magnetic_pml(config_mock):
grid = build_grid(100, 100, 100)
grid_expected_value = np.zeros((101, 101, 101))
@@ -94,7 +94,7 @@ def test_update_magnetic_pml_cpu(config_mock):
assert np.equal(pml.EPhi2, 0).all()
def test_update_electric_pml_cpu(config_mock):
def test_update_electric_pml(config_mock):
grid = build_grid(100, 100, 100)
grid_expected_value = np.zeros((101, 101, 101))