你已经派生过 gprMax
镜像自地址
https://gitee.com/sunhf/gprMax.git
已同步 2025-08-07 04:56:51 +08:00
Move CPUUpdates to seperate file
这个提交包含在:
@@ -20,7 +20,8 @@ import gprMax.config as config
|
||||
|
||||
from .grid import CUDAGrid, FDTDGrid, OpenCLGrid
|
||||
from .subgrids.updates import create_updates as create_subgrid_updates
|
||||
from .updates.updates import CPUUpdates, CUDAUpdates, OpenCLUpdates
|
||||
from .updates.cpu_updates import CPUUpdates
|
||||
from .updates.updates import CUDAUpdates, OpenCLUpdates
|
||||
|
||||
|
||||
def create_G():
|
||||
|
@@ -18,7 +18,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from ..updates.updates import CPUUpdates
|
||||
from ..updates.cpu_updates import CPUUpdates
|
||||
from .precursor_nodes import PrecursorNodes, PrecursorNodesFiltered
|
||||
from .subgrid_hsg import SubGridHSG
|
||||
|
||||
|
191
gprMax/updates/cpu_updates.py
普通文件
191
gprMax/updates/cpu_updates.py
普通文件
@@ -0,0 +1,191 @@
|
||||
from importlib import import_module
|
||||
|
||||
from gprMax import config
|
||||
|
||||
from ..cython.fields_updates_normal import update_electric as update_electric_cpu
|
||||
from ..cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
|
||||
from ..fields_outputs import store_outputs as store_outputs_cpu
|
||||
from ..utilities.utilities import timer
|
||||
|
||||
|
||||
class CPUUpdates:
|
||||
"""Defines update functions for CPU-based solver."""
|
||||
|
||||
def __init__(self, G):
|
||||
"""
|
||||
Args:
|
||||
G: FDTDGrid class describing a grid in a model.
|
||||
"""
|
||||
|
||||
self.grid = G
|
||||
|
||||
def store_outputs(self):
|
||||
"""Stores field component values for every receiver and transmission line."""
|
||||
store_outputs_cpu(self.grid)
|
||||
|
||||
def store_snapshots(self, iteration):
|
||||
"""Stores any snapshots.
|
||||
|
||||
Args:
|
||||
iteration: int for iteration number.
|
||||
"""
|
||||
for snap in self.grid.snapshots:
|
||||
if snap.time == iteration + 1:
|
||||
snap.store(self.grid)
|
||||
|
||||
def update_magnetic(self):
|
||||
"""Updates magnetic field components."""
|
||||
update_magnetic_cpu(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
self.grid.updatecoeffsH,
|
||||
self.grid.ID,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
)
|
||||
|
||||
def update_magnetic_pml(self):
|
||||
"""Updates magnetic field components with the PML correction."""
|
||||
for pml in self.grid.pmls["slabs"]:
|
||||
pml.update_magnetic()
|
||||
|
||||
def update_magnetic_sources(self):
|
||||
"""Updates magnetic field components from sources."""
|
||||
for source in self.grid.transmissionlines + self.grid.magneticdipoles:
|
||||
source.update_magnetic(
|
||||
self.grid.iteration,
|
||||
self.grid.updatecoeffsH,
|
||||
self.grid.ID,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
self.grid,
|
||||
)
|
||||
|
||||
def update_electric_a(self):
|
||||
"""Updates electric field components."""
|
||||
# All materials are non-dispersive so do standard update.
|
||||
if config.get_model_config().materials["maxpoles"] == 0:
|
||||
update_electric_cpu(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
self.grid.updatecoeffsE,
|
||||
self.grid.ID,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
)
|
||||
|
||||
# If there are any dispersive materials do 1st part of dispersive update
|
||||
# (it is split into two parts as it requires present and updated electric field values).
|
||||
else:
|
||||
self.dispersive_update_a(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
config.get_model_config().materials["maxpoles"],
|
||||
self.grid.updatecoeffsE,
|
||||
self.grid.updatecoeffsdispersive,
|
||||
self.grid.ID,
|
||||
self.grid.Tx,
|
||||
self.grid.Ty,
|
||||
self.grid.Tz,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
)
|
||||
|
||||
def update_electric_pml(self):
|
||||
"""Updates electric field components with the PML correction."""
|
||||
for pml in self.grid.pmls["slabs"]:
|
||||
pml.update_electric()
|
||||
|
||||
def update_electric_sources(self):
|
||||
"""Updates electric field components from sources -
|
||||
update any Hertzian dipole sources last.
|
||||
"""
|
||||
for source in self.grid.voltagesources + self.grid.transmissionlines + self.grid.hertziandipoles:
|
||||
source.update_electric(
|
||||
self.grid.iteration,
|
||||
self.grid.updatecoeffsE,
|
||||
self.grid.ID,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid,
|
||||
)
|
||||
self.grid.iteration += 1
|
||||
|
||||
def update_electric_b(self):
|
||||
"""If there are any dispersive materials do 2nd part of dispersive
|
||||
update - it is split into two parts as it requires present and
|
||||
updated electric field values. Therefore it can only be completely
|
||||
updated after the electric field has been updated by the PML and
|
||||
source updates.
|
||||
"""
|
||||
if config.get_model_config().materials["maxpoles"] > 0:
|
||||
self.dispersive_update_b(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
config.get_model_config().materials["maxpoles"],
|
||||
self.grid.updatecoeffsdispersive,
|
||||
self.grid.ID,
|
||||
self.grid.Tx,
|
||||
self.grid.Ty,
|
||||
self.grid.Tz,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
)
|
||||
|
||||
def set_dispersive_updates(self):
|
||||
"""Sets dispersive update functions."""
|
||||
|
||||
poles = "multi" if config.get_model_config().materials["maxpoles"] > 1 else "1"
|
||||
precision = "float" if config.sim_config.general["precision"] == "single" else "double"
|
||||
dispersion = (
|
||||
"complex"
|
||||
if config.get_model_config().materials["dispersivedtype"] == config.sim_config.dtypes["complex"]
|
||||
else "real"
|
||||
)
|
||||
|
||||
update_f = "update_electric_dispersive_{}pole_{}_{}_{}"
|
||||
disp_a = update_f.format(poles, "A", precision, dispersion)
|
||||
disp_b = update_f.format(poles, "B", precision, dispersion)
|
||||
|
||||
disp_a_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_a)
|
||||
disp_b_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_b)
|
||||
|
||||
self.dispersive_update_a = disp_a_f
|
||||
self.dispersive_update_b = disp_b_f
|
||||
|
||||
def time_start(self):
|
||||
"""Starts timer used to calculate solving time for model."""
|
||||
self.timestart = timer()
|
||||
|
||||
def calculate_solve_time(self):
|
||||
"""Calculates solving time for model."""
|
||||
return timer() - self.timestart
|
||||
|
||||
def finalise(self):
|
||||
pass
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
@@ -26,200 +26,14 @@ from jinja2 import Environment, PackageLoader
|
||||
import gprMax.config as config
|
||||
|
||||
from ..cuda_opencl import knl_fields_updates, knl_snapshots, knl_source_updates, knl_store_outputs
|
||||
from ..cython.fields_updates_normal import update_electric as update_electric_cpu
|
||||
from ..cython.fields_updates_normal import update_magnetic as update_magnetic_cpu
|
||||
from ..fields_outputs import store_outputs as store_outputs_cpu
|
||||
from ..receivers import dtoh_rx_array, htod_rx_arrays
|
||||
from ..snapshots import Snapshot, dtoh_snapshot_array, htod_snapshot_array
|
||||
from ..sources import htod_src_arrays
|
||||
from ..utilities.utilities import round32, timer
|
||||
from ..utilities.utilities import round32
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CPUUpdates:
|
||||
"""Defines update functions for CPU-based solver."""
|
||||
|
||||
def __init__(self, G):
|
||||
"""
|
||||
Args:
|
||||
G: FDTDGrid class describing a grid in a model.
|
||||
"""
|
||||
|
||||
self.grid = G
|
||||
|
||||
def store_outputs(self):
|
||||
"""Stores field component values for every receiver and transmission line."""
|
||||
store_outputs_cpu(self.grid)
|
||||
|
||||
def store_snapshots(self, iteration):
|
||||
"""Stores any snapshots.
|
||||
|
||||
Args:
|
||||
iteration: int for iteration number.
|
||||
"""
|
||||
for snap in self.grid.snapshots:
|
||||
if snap.time == iteration + 1:
|
||||
snap.store(self.grid)
|
||||
|
||||
def update_magnetic(self):
|
||||
"""Updates magnetic field components."""
|
||||
update_magnetic_cpu(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
self.grid.updatecoeffsH,
|
||||
self.grid.ID,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
)
|
||||
|
||||
def update_magnetic_pml(self):
|
||||
"""Updates magnetic field components with the PML correction."""
|
||||
for pml in self.grid.pmls["slabs"]:
|
||||
pml.update_magnetic()
|
||||
|
||||
def update_magnetic_sources(self):
|
||||
"""Updates magnetic field components from sources."""
|
||||
for source in self.grid.transmissionlines + self.grid.magneticdipoles:
|
||||
source.update_magnetic(
|
||||
self.grid.iteration,
|
||||
self.grid.updatecoeffsH,
|
||||
self.grid.ID,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
self.grid,
|
||||
)
|
||||
|
||||
def update_electric_a(self):
|
||||
"""Updates electric field components."""
|
||||
# All materials are non-dispersive so do standard update.
|
||||
if config.get_model_config().materials["maxpoles"] == 0:
|
||||
update_electric_cpu(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
self.grid.updatecoeffsE,
|
||||
self.grid.ID,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
)
|
||||
|
||||
# If there are any dispersive materials do 1st part of dispersive update
|
||||
# (it is split into two parts as it requires present and updated electric field values).
|
||||
else:
|
||||
self.dispersive_update_a(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
config.get_model_config().materials["maxpoles"],
|
||||
self.grid.updatecoeffsE,
|
||||
self.grid.updatecoeffsdispersive,
|
||||
self.grid.ID,
|
||||
self.grid.Tx,
|
||||
self.grid.Ty,
|
||||
self.grid.Tz,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid.Hx,
|
||||
self.grid.Hy,
|
||||
self.grid.Hz,
|
||||
)
|
||||
|
||||
def update_electric_pml(self):
|
||||
"""Updates electric field components with the PML correction."""
|
||||
for pml in self.grid.pmls["slabs"]:
|
||||
pml.update_electric()
|
||||
|
||||
def update_electric_sources(self):
|
||||
"""Updates electric field components from sources -
|
||||
update any Hertzian dipole sources last.
|
||||
"""
|
||||
for source in self.grid.voltagesources + self.grid.transmissionlines + self.grid.hertziandipoles:
|
||||
source.update_electric(
|
||||
self.grid.iteration,
|
||||
self.grid.updatecoeffsE,
|
||||
self.grid.ID,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
self.grid,
|
||||
)
|
||||
self.grid.iteration += 1
|
||||
|
||||
def update_electric_b(self):
|
||||
"""If there are any dispersive materials do 2nd part of dispersive
|
||||
update - it is split into two parts as it requires present and
|
||||
updated electric field values. Therefore it can only be completely
|
||||
updated after the electric field has been updated by the PML and
|
||||
source updates.
|
||||
"""
|
||||
if config.get_model_config().materials["maxpoles"] > 0:
|
||||
self.dispersive_update_b(
|
||||
self.grid.nx,
|
||||
self.grid.ny,
|
||||
self.grid.nz,
|
||||
config.get_model_config().ompthreads,
|
||||
config.get_model_config().materials["maxpoles"],
|
||||
self.grid.updatecoeffsdispersive,
|
||||
self.grid.ID,
|
||||
self.grid.Tx,
|
||||
self.grid.Ty,
|
||||
self.grid.Tz,
|
||||
self.grid.Ex,
|
||||
self.grid.Ey,
|
||||
self.grid.Ez,
|
||||
)
|
||||
|
||||
def set_dispersive_updates(self):
|
||||
"""Sets dispersive update functions."""
|
||||
|
||||
poles = "multi" if config.get_model_config().materials["maxpoles"] > 1 else "1"
|
||||
precision = "float" if config.sim_config.general["precision"] == "single" else "double"
|
||||
dispersion = (
|
||||
"complex"
|
||||
if config.get_model_config().materials["dispersivedtype"] == config.sim_config.dtypes["complex"]
|
||||
else "real"
|
||||
)
|
||||
|
||||
update_f = "update_electric_dispersive_{}pole_{}_{}_{}"
|
||||
disp_a = update_f.format(poles, "A", precision, dispersion)
|
||||
disp_b = update_f.format(poles, "B", precision, dispersion)
|
||||
|
||||
disp_a_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_a)
|
||||
disp_b_f = getattr(import_module("gprMax.cython.fields_updates_dispersive"), disp_b)
|
||||
|
||||
self.dispersive_update_a = disp_a_f
|
||||
self.dispersive_update_b = disp_b_f
|
||||
|
||||
def time_start(self):
|
||||
"""Starts timer used to calculate solving time for model."""
|
||||
self.timestart = timer()
|
||||
|
||||
def calculate_solve_time(self):
|
||||
"""Calculates solving time for model."""
|
||||
return timer() - self.timestart
|
||||
|
||||
def finalise(self):
|
||||
pass
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
|
||||
class CUDAUpdates:
|
||||
"""Defines update functions for GPU-based (CUDA) solver."""
|
||||
|
||||
|
@@ -8,7 +8,7 @@ from gprMax.grid import FDTDGrid
|
||||
from gprMax.materials import create_built_in_materials
|
||||
from gprMax.model_build_run import GridBuilder
|
||||
from gprMax.pml import CFS
|
||||
from gprMax.updates.updates import CPUUpdates
|
||||
from gprMax.updates.cpu_updates import CPUUpdates
|
||||
|
||||
|
||||
def build_grid(nx, ny, nz, dl=0.001, dt=3e-9):
|
||||
@@ -50,7 +50,7 @@ def config_mock(monkeypatch):
|
||||
monkeypatch.setattr(config, "get_model_config", _mock_model_config)
|
||||
|
||||
|
||||
def test_update_magnetic_cpu(config_mock):
|
||||
def test_update_magnetic(config_mock):
|
||||
grid = build_grid(100, 100, 100)
|
||||
|
||||
expected_value = np.zeros((101, 101, 101))
|
||||
@@ -72,7 +72,7 @@ def test_update_magnetic_cpu(config_mock):
|
||||
assert np.equal(pml.EPhi2, 0).all()
|
||||
|
||||
|
||||
def test_update_magnetic_pml_cpu(config_mock):
|
||||
def test_update_magnetic_pml(config_mock):
|
||||
grid = build_grid(100, 100, 100)
|
||||
|
||||
grid_expected_value = np.zeros((101, 101, 101))
|
||||
@@ -94,7 +94,7 @@ def test_update_magnetic_pml_cpu(config_mock):
|
||||
assert np.equal(pml.EPhi2, 0).all()
|
||||
|
||||
|
||||
def test_update_electric_pml_cpu(config_mock):
|
||||
def test_update_electric_pml(config_mock):
|
||||
grid = build_grid(100, 100, 100)
|
||||
|
||||
grid_expected_value = np.zeros((101, 101, 101))
|
在新工单中引用
屏蔽一个用户