你已经派生过 gprMax
镜像自地址
https://gitee.com/sunhf/gprMax.git
已同步 2025-08-07 15:10:13 +08:00
Add regression tests for basic and analytical models
这个提交包含在:
@@ -17,6 +17,7 @@
|
||||
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -70,6 +71,24 @@ ANALYTICAL_MODELS = ["hertzian_dipole_fs_analytical"]
|
||||
FIELD_COMPONENTS_BASE_PATH = "/rxs/rx1/"
|
||||
|
||||
|
||||
def create_ascan_comparison_plots(test_time, test_data, ref_time, ref_data, model_name, output_base):
|
||||
fig1 = plot_dataset_comparison(test_time, test_data, ref_time, ref_data, model_name)
|
||||
fig1.savefig(output_base.with_suffix(".png"), dpi=150, format="png", bbox_inches="tight", pad_inches=0.1)
|
||||
|
||||
# Required to correctly calculate diffs
|
||||
assert test_time.shape == ref_time.shape
|
||||
assert np.all(test_time == ref_time)
|
||||
assert test_data.shape == ref_data.shape
|
||||
|
||||
data_diffs = calculate_diffs(test_data, ref_data)
|
||||
|
||||
fig2 = plot_diffs(test_time, data_diffs)
|
||||
fig2.savefig(Path(f"{output_base}_diffs.png"), dpi=150, format="png", bbox_inches="tight", pad_inches=0.1)
|
||||
|
||||
logger.info(f"Output data folder: {output_base.parent}")
|
||||
|
||||
|
||||
|
||||
def run_test(model_name, input_base, data_directory, analytical_func=None, gpu=None, opencl=None):
|
||||
input_filepath = input_base.with_suffix(".in")
|
||||
reference_filepath = Path(f"{input_base}_ref.h5")
|
||||
@@ -88,25 +107,39 @@ def run_test(model_name, input_base, data_directory, analytical_func=None, gpu=N
|
||||
else:
|
||||
ref_time, ref_data = get_data_from_h5_file(reference_filepath)
|
||||
|
||||
fig1 = plot_dataset_comparison(test_time, test_data, ref_time, ref_data, model_name)
|
||||
fig1.savefig(output_base.with_suffix(".png"), dpi=150, format="png", bbox_inches="tight", pad_inches=0.1)
|
||||
|
||||
# Required to correctly calculate diffs
|
||||
assert test_time.shape == ref_time.shape
|
||||
assert np.all(test_time == ref_time)
|
||||
assert test_data.shape == ref_data.shape
|
||||
|
||||
create_ascan_comparison_plots(test_time, test_data, ref_time, ref_data, model_name, output_base)
|
||||
|
||||
data_diffs = calculate_diffs(test_data, ref_data)
|
||||
max_diff = round(np.max(data_diffs), 2)
|
||||
|
||||
fig2 = plot_diffs(test_time, data_diffs)
|
||||
fig2.savefig(Path(f"{output_base}_diffs.png"), dpi=150, format="png", bbox_inches="tight", pad_inches=0.1)
|
||||
|
||||
logger.info(f"Output data folder: {data_directory}")
|
||||
|
||||
assert max_diff <= 0
|
||||
|
||||
|
||||
def run_regression_test(request, ndarrays_regression, model_name, input_base, data_directory, gpu=None, opencl=None):
|
||||
input_filepath = input_base.with_suffix(".in")
|
||||
|
||||
output_dir = data_directory / request.node.name
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
output_base = output_dir / model_name
|
||||
output_filepath = output_base.with_suffix(".h5")
|
||||
reference_filepath = output_base.with_suffix(".npz")
|
||||
|
||||
# Run model
|
||||
gprMax.run(inputfile=input_filepath, outputfile=output_filepath, gpu=gpu, opencl=opencl)
|
||||
|
||||
test_time, test_data = get_data_from_h5_file(output_filepath)
|
||||
|
||||
# May not exist if first time running the regression test
|
||||
if os.path.exists(reference_filepath):
|
||||
reference_file = np.load(reference_filepath)
|
||||
|
||||
ref_time = reference_file["time"]
|
||||
ref_data = reference_file["data"]
|
||||
|
||||
create_ascan_comparison_plots(test_time, test_data, ref_time, ref_data, model_name, output_base)
|
||||
|
||||
ndarrays_regression.check({"time": test_time, "data": test_data}, basename=os.path.relpath(output_base, data_directory))
|
||||
|
||||
|
||||
def calc_hertzian_dipole_fs_analytical_solution(filepath):
|
||||
with h5py.File(filepath, "r") as file:
|
||||
# Tx/Rx position to feed to analytical solution
|
||||
@@ -129,7 +162,21 @@ def test_basic_models(model, datadir):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ANALYTICAL_MODELS)
|
||||
def test_analyitical_models(model, datadir):
|
||||
def test_analyitical_models(datadir, model):
|
||||
|
||||
base_filepath = Path(ANALYTICAL_MODELS_DIRECTORY, model)
|
||||
run_test(model, base_filepath, datadir, analytical_func=calc_hertzian_dipole_fs_analytical_solution)
|
||||
run_test(model, base_filepath, datadir, analytical_func=calc_hertzian_dipole_fs_analytical_solution)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", BASIC_MODELS)
|
||||
def test_basic_models_regression(request, ndarrays_regression, datadir, model):
|
||||
|
||||
base_filepath = Path(BASIC_MODELS_DIRECTORY, model, model)
|
||||
run_regression_test(request, ndarrays_regression, model, base_filepath, datadir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ANALYTICAL_MODELS)
|
||||
def test_analytical_models_regression(request, ndarrays_regression, datadir, model):
|
||||
|
||||
base_filepath = Path(ANALYTICAL_MODELS_DIRECTORY, model)
|
||||
run_regression_test(request, ndarrays_regression, model, base_filepath, datadir)
|
二进制文件未显示。
二进制文件未显示。
二进制文件未显示。
二进制文件未显示。
二进制文件未显示。
二进制文件未显示。
二进制文件未显示。
二进制文件未显示。
在新工单中引用
屏蔽一个用户