你已经派生过 gprMax
镜像自地址
https://gitee.com/sunhf/gprMax.git
已同步 2025-08-07 04:56:51 +08:00
Refactor existing model tests to use pytest
这个提交包含在:
@@ -19,4 +19,9 @@ numpy-stl
|
||||
terminaltables
|
||||
tqdm
|
||||
wheel
|
||||
pytest
|
||||
pytest-benchmark
|
||||
pytest-benchmark[histogram]
|
||||
pytest-mpi
|
||||
pytest-regressions
|
||||
git+https://github.com/craig-warren/PyEVTK.git
|
||||
|
@@ -0,0 +1,8 @@
|
||||
#title: Hertzian dipole in free-space
|
||||
#domain: 0.100 0.100 0.100
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussianprime 1 1e9 myWave
|
||||
#hertzian_dipole: z 0.050 0.050 0.050 myWave
|
||||
#rx: 0.070 0.070 0.070
|
@@ -0,0 +1,8 @@
|
||||
#title: 2D test Ex, Hy, Hz components
|
||||
#domain: 0.001 0.100 0.100
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussiandot 1 1e9 myWave
|
||||
#hertzian_dipole: x 0 0.050 0.050 myWave
|
||||
#rx: 0 0.070 0.070
|
二进制文件未显示。
@@ -0,0 +1,8 @@
|
||||
#title: 2D test Ey, Hx, Hz components
|
||||
#domain: 0.100 0.001 0.100
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussiandot 1 1e9 myWave
|
||||
#hertzian_dipole: y 0.050 0 0.050 myWave
|
||||
#rx: 0.070 0 0.070
|
二进制文件未显示。
@@ -0,0 +1,8 @@
|
||||
#title: 2D test Ez, Hx, Hy components
|
||||
#domain: 0.100 0.100 0.001
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussiandot 1 1e9 myWave
|
||||
#hertzian_dipole: z 0.050 0.050 0 myWave
|
||||
#rx: 0.070 0.070 0
|
二进制文件未显示。
@@ -0,0 +1,13 @@
|
||||
#title: A-scan from a metal cylinder buried in a dielectric half-space
|
||||
#domain: 0.240 0.210 0.002
|
||||
#dx_dy_dz: 0.002 0.002 0.002
|
||||
#time_window: 3e-9
|
||||
|
||||
#material: 6 0 1 0 half_space
|
||||
|
||||
#waveform: ricker 1 1.5e9 my_ricker
|
||||
#hertzian_dipole: z 0.100 0.170 0 my_ricker
|
||||
#rx: 0.140 0.170 0
|
||||
|
||||
#box: 0 0 0 0.240 0.170 0.002 half_space
|
||||
#cylinder: 0.120 0.080 0 0.120 0.080 0.002 0.010 pec
|
二进制文件未显示。
@@ -0,0 +1,12 @@
|
||||
#title: Hertzian dipole in water
|
||||
#domain: 0.100 0.100 0.100
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussiandot 1 1e9 myWave
|
||||
#hertzian_dipole: z 0.050 0.050 0.050 myWave
|
||||
#rx: 0.070 0.070 0.070
|
||||
|
||||
#material: 4.9 0 1 0 myWater
|
||||
#add_dispersion_debye: 1 75.2 9.231e-12 myWater
|
||||
#box: 0 0 0 0.100 0.100 0.100 myWater
|
二进制文件未显示。
@@ -0,0 +1,8 @@
|
||||
#title: Hertzian dipole in free-space
|
||||
#domain: 0.100 0.100 0.100
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussiandot 1 1e9 myWave
|
||||
#hertzian_dipole: z 0.050 0.050 0.050 myWave
|
||||
#rx: 0.070 0.070 0.070
|
二进制文件未显示。
@@ -0,0 +1,11 @@
|
||||
#title: Hertzian dipole over a half-space
|
||||
#domain: 0.100 0.100 0.100
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussiandot 1 1e9 myWave
|
||||
#hertzian_dipole: z 0.050 0.050 0.050 myWave
|
||||
#rx: 0.070 0.070 0.070
|
||||
|
||||
#material: 8 0 1 0 half_space
|
||||
#box: 0 0 0 0.100 0.100 0.050 half_space
|
二进制文件未显示。
@@ -0,0 +1,8 @@
|
||||
#title: Magnetic dipole in free-space
|
||||
#domain: 0.100 0.100 0.100
|
||||
#dx_dy_dz: 0.001 0.001 0.001
|
||||
#time_window: 3e-9
|
||||
|
||||
#waveform: gaussiandot 1 1e9 myWave
|
||||
#magnetic_dipole: z 0.050 0.050 0.050 myWave
|
||||
#rx: 0.070 0.070 0.070
|
二进制文件未显示。
135
tests/test_models.py
普通文件
135
tests/test_models.py
普通文件
@@ -0,0 +1,135 @@
|
||||
# Copyright (C) 2015-2023: The University of Edinburgh, United Kingdom
|
||||
# Authors: Craig Warren, Antonis Giannopoulos, and John Hartley
|
||||
#
|
||||
# This file is part of gprMax.
|
||||
#
|
||||
# gprMax is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# gprMax is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gprMax
|
||||
from testing.analytical_solutions import hertzian_dipole_fs
|
||||
from tests.utilities.data import get_data_from_h5_file, calculate_diffs
|
||||
from tests.utilities.plotting import plot_dataset_comparison, plot_diffs
|
||||
|
||||
from gprMax.utilities.logging import logging_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging_config(name=__name__)
|
||||
|
||||
if sys.platform == "linux":
|
||||
plt.switch_backend("agg")
|
||||
|
||||
|
||||
"""Compare field outputs
|
||||
|
||||
Usage:
|
||||
cd gprMax
|
||||
pytest tests/test_models.py
|
||||
"""
|
||||
|
||||
# Specify directory containing basic models to test
|
||||
BASIC_MODELS_DIRECTORY = Path(__file__).parent / "data" / "models_basic"
|
||||
|
||||
# List of available basic test models
|
||||
BASIC_MODELS = [
|
||||
"2D_ExHyHz",
|
||||
"2D_EyHxHz",
|
||||
"2D_EzHxHy",
|
||||
"cylinder_Ascan_2D",
|
||||
"hertzian_dipole_fs",
|
||||
"hertzian_dipole_hs",
|
||||
"hertzian_dipole_dispersive",
|
||||
"magnetic_dipole_fs",
|
||||
]
|
||||
|
||||
# Specify directory containing analytical models to test
|
||||
ANALYTICAL_MODELS_DIRECTORY = Path(__file__).parent / "data" / "models_analytical"
|
||||
|
||||
# List of available analytical models
|
||||
ANALYTICAL_MODELS = ["hertzian_dipole_fs_analytical"]
|
||||
|
||||
FIELD_COMPONENTS_BASE_PATH = "/rxs/rx1/"
|
||||
|
||||
|
||||
def run_test(model_name, input_base, data_directory, analytical_func=None, gpu=None, opencl=None):
|
||||
input_filepath = input_base.with_suffix(".in")
|
||||
reference_filepath = Path(f"{input_base}_ref.h5")
|
||||
|
||||
output_base = data_directory / model_name
|
||||
output_filepath = output_base.with_suffix(".h5")
|
||||
|
||||
# Run model
|
||||
gprMax.run(inputfile=input_filepath, outputfile=output_filepath, gpu=gpu, opencl=opencl)
|
||||
|
||||
test_time, test_data = get_data_from_h5_file(output_filepath)
|
||||
|
||||
if analytical_func is not None:
|
||||
ref_time = test_time
|
||||
ref_data = analytical_func(output_filepath)
|
||||
else:
|
||||
ref_time, ref_data = get_data_from_h5_file(reference_filepath)
|
||||
|
||||
fig1 = plot_dataset_comparison(test_time, test_data, ref_time, ref_data, model_name)
|
||||
fig1.savefig(output_base.with_suffix(".png"), dpi=150, format="png", bbox_inches="tight", pad_inches=0.1)
|
||||
|
||||
# Required to correctly calculate diffs
|
||||
assert test_time.shape == ref_time.shape
|
||||
assert np.all(test_time == ref_time)
|
||||
assert test_data.shape == ref_data.shape
|
||||
|
||||
data_diffs = calculate_diffs(test_data, ref_data)
|
||||
max_diff = round(np.max(data_diffs), 2)
|
||||
|
||||
fig2 = plot_diffs(test_time, data_diffs)
|
||||
fig2.savefig(Path(f"{output_base}_diffs.png"), dpi=150, format="png", bbox_inches="tight", pad_inches=0.1)
|
||||
|
||||
logger.info(f"Output data folder: {data_directory}")
|
||||
|
||||
assert max_diff <= 0
|
||||
|
||||
|
||||
def calc_hertzian_dipole_fs_analytical_solution(filepath):
|
||||
with h5py.File(filepath, "r") as file:
|
||||
# Tx/Rx position to feed to analytical solution
|
||||
rx_pos = file[FIELD_COMPONENTS_BASE_PATH].attrs["Position"]
|
||||
tx_pos = file["/srcs/src1/"].attrs["Position"]
|
||||
rx_pos_relative = ((rx_pos[0] - tx_pos[0]), (rx_pos[1] - tx_pos[1]), (rx_pos[2] - tx_pos[2]))
|
||||
|
||||
# Analytical solution of a dipole in free space
|
||||
data = hertzian_dipole_fs(
|
||||
file.attrs["Iterations"], file.attrs["dt"], file.attrs["dx_dy_dz"], rx_pos_relative
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", BASIC_MODELS)
|
||||
def test_basic_models(model, datadir):
|
||||
|
||||
base_filepath = Path(BASIC_MODELS_DIRECTORY, model, model)
|
||||
run_test(model, base_filepath, datadir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ANALYTICAL_MODELS)
|
||||
def test_analyitical_models(model, datadir):
|
||||
|
||||
base_filepath = Path(ANALYTICAL_MODELS_DIRECTORY, model)
|
||||
run_test(model, base_filepath, datadir, analytical_func=calc_hertzian_dipole_fs_analytical_solution)
|
53
tests/utilities/data.py
普通文件
53
tests/utilities/data.py
普通文件
@@ -0,0 +1,53 @@
|
||||
import logging
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from gprMax.utilities.logging import logging_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging_config(name=__name__)
|
||||
|
||||
|
||||
FIELD_COMPONENTS_BASE_PATH = "/rxs/rx1/"
|
||||
|
||||
|
||||
def get_data_from_h5_file(h5_filepath):
|
||||
with h5py.File(h5_filepath, "r") as h5_file:
|
||||
# Get available field output component names and datatype
|
||||
field_components = list(h5_file[FIELD_COMPONENTS_BASE_PATH].keys())
|
||||
dtype = h5_file[FIELD_COMPONENTS_BASE_PATH + field_components[0]].dtype
|
||||
shape = h5_file[FIELD_COMPONENTS_BASE_PATH + str(field_components[0])].shape
|
||||
|
||||
# Arrays for storing field data
|
||||
if len(shape) == 1:
|
||||
data = np.zeros((h5_file.attrs["Iterations"], len(field_components)), dtype=dtype)
|
||||
else: # Merged B-scan data
|
||||
data = np.zeros((h5_file.attrs["Iterations"], len(field_components), shape[1]), dtype=dtype)
|
||||
for index, field_component in enumerate(field_components):
|
||||
data[:, index] = h5_file[FIELD_COMPONENTS_BASE_PATH + str(field_component)]
|
||||
if np.any(np.isnan(data[:, index])):
|
||||
logger.exception("Data contains NaNs")
|
||||
raise ValueError
|
||||
|
||||
max_time = (h5_file.attrs["Iterations"] - 1) * h5_file.attrs["dt"] / 1e-9
|
||||
time = np.linspace(0, max_time, num=h5_file.attrs["Iterations"])
|
||||
|
||||
return time, data
|
||||
|
||||
|
||||
def calculate_diffs(test_data, ref_data):
|
||||
diffs = np.zeros(test_data.shape, dtype=np.float64)
|
||||
for i in range(test_data.shape[1]):
|
||||
maxi = np.amax(np.abs(ref_data[:, i]))
|
||||
diffs[:, i] = np.divide(
|
||||
np.abs(ref_data[:, i] - test_data[:, i]), maxi, out=np.zeros_like(ref_data[:, i]), where=maxi != 0
|
||||
) # Replace any division by zero with zero
|
||||
|
||||
# Calculate power (ignore warning from taking a log of any zero values)
|
||||
with np.errstate(divide="ignore"):
|
||||
diffs[:, i] = 20 * np.log10(diffs[:, i])
|
||||
# Replace any NaNs or Infs from zero division
|
||||
diffs[:, i][np.invert(np.isfinite(diffs[:, i]))] = 0
|
||||
|
||||
return diffs
|
85
tests/utilities/plotting.py
普通文件
85
tests/utilities/plotting.py
普通文件
@@ -0,0 +1,85 @@
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _plot_data(subplots, time, data, label=None, colour="r", line_style="-"):
|
||||
for i in range(data.shape[1]):
|
||||
subplots[i].plot(time, data[:, i], colour, lw=2, ls=line_style, label=label)
|
||||
|
||||
|
||||
def plot_dataset_comparison(test_time, test_data, ref_time, ref_data, model_name):
|
||||
fig, ((ex, hx), (ey, hy), (ez, hz)) = plt.subplots(
|
||||
nrows=3,
|
||||
ncols=2,
|
||||
sharex=False,
|
||||
sharey="col",
|
||||
subplot_kw=dict(xlabel="Time [ns]"),
|
||||
figsize=(20, 10),
|
||||
facecolor="w",
|
||||
edgecolor="w",
|
||||
)
|
||||
|
||||
subplots = [ex, ey, ez, hx, hy, hz]
|
||||
_plot_data(subplots, test_time, test_data, model_name)
|
||||
_plot_data(subplots, ref_time, ref_data, f"{model_name} (Ref)", "g", "--")
|
||||
|
||||
ylabels = [
|
||||
"$E_x$, field strength [V/m]",
|
||||
"$H_x$, field strength [A/m]",
|
||||
"$E_y$, field strength [V/m]",
|
||||
"$H_y$, field strength [A/m]",
|
||||
"$E_z$, field strength [V/m]",
|
||||
"$H_z$, field strength [A/m]",
|
||||
]
|
||||
|
||||
x_max = max(np.max(test_time), np.max(ref_time))
|
||||
for i, ax in enumerate(fig.axes):
|
||||
ax.set_ylabel(ylabels[i])
|
||||
ax.set_xlim(0, x_max)
|
||||
ax.grid()
|
||||
ax.legend()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_diffs(time, diffs, plot_min=-160):
|
||||
"""Plots ...
|
||||
|
||||
Args:
|
||||
time:
|
||||
diffs:
|
||||
plot_min: minimum value of difference to plot (dB). Default: -160
|
||||
|
||||
Returns:
|
||||
plt: matplotlib plot object.
|
||||
"""
|
||||
fig, ((ex, hx), (ey, hy), (ez, hz)) = plt.subplots(
|
||||
nrows=3,
|
||||
ncols=2,
|
||||
sharex=False,
|
||||
sharey="col",
|
||||
subplot_kw=dict(xlabel="Time [ns]"),
|
||||
figsize=(20, 10),
|
||||
facecolor="w",
|
||||
edgecolor="w",
|
||||
)
|
||||
_plot_data([ex, ey, ez, hx, hy, hz], time, diffs)
|
||||
|
||||
ylabels = [
|
||||
"$E_x$, difference [dB]",
|
||||
"$H_x$, difference [dB]",
|
||||
"$E_y$, difference [dB]",
|
||||
"$H_y$, difference [dB]",
|
||||
"$E_z$, difference [dB]",
|
||||
"$H_z$, difference [dB]",
|
||||
]
|
||||
|
||||
x_max = np.max(time)
|
||||
y_max = np.max(diffs)
|
||||
for i, ax in enumerate(fig.axes):
|
||||
ax.set_ylabel(ylabels[i])
|
||||
ax.set_xlim(0, x_max)
|
||||
ax.set_ylim(plot_min, y_max)
|
||||
ax.grid()
|
||||
|
||||
return fig
|
在新工单中引用
屏蔽一个用户