你已经派生过 gprMax
镜像自地址
https://gitee.com/sunhf/gprMax.git
已同步 2025-08-07 23:14:03 +08:00
Add reference data to output plot and plot diffs
这个提交包含在:
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from gprMax.utilities.logging import logging_config
|
||||
|
||||
@@ -12,7 +14,7 @@ logging_config(name=__name__)
|
||||
FIELD_COMPONENTS_BASE_PATH = "/rxs/rx1/"
|
||||
|
||||
|
||||
def get_data_from_h5_file(h5_filepath):
|
||||
def get_data_from_h5_file(h5_filepath: str) -> Tuple[npt.NDArray, npt.NDArray]:
|
||||
with h5py.File(h5_filepath, "r") as h5_file:
|
||||
# Get available field output component names and datatype
|
||||
field_components = list(h5_file[FIELD_COMPONENTS_BASE_PATH].keys())
|
||||
@@ -23,7 +25,9 @@ def get_data_from_h5_file(h5_filepath):
|
||||
if len(shape) == 1:
|
||||
data = np.zeros((h5_file.attrs["Iterations"], len(field_components)), dtype=dtype)
|
||||
else: # Merged B-scan data
|
||||
data = np.zeros((h5_file.attrs["Iterations"], len(field_components), shape[1]), dtype=dtype)
|
||||
data = np.zeros(
|
||||
(h5_file.attrs["Iterations"], len(field_components), shape[1]), dtype=dtype
|
||||
)
|
||||
for index, field_component in enumerate(field_components):
|
||||
data[:, index] = h5_file[FIELD_COMPONENTS_BASE_PATH + str(field_component)]
|
||||
if np.any(np.isnan(data[:, index])):
|
||||
@@ -36,18 +40,21 @@ def get_data_from_h5_file(h5_filepath):
|
||||
return time, data
|
||||
|
||||
|
||||
def calculate_diffs(test_data, ref_data):
|
||||
def calculate_diffs(test_data: npt.NDArray, ref_data: npt.NDArray) -> npt.NDArray:
|
||||
diffs = np.zeros(test_data.shape, dtype=np.float64)
|
||||
for i in range(test_data.shape[1]):
|
||||
maxi = np.amax(np.abs(ref_data[:, i]))
|
||||
diffs[:, i] = np.divide(
|
||||
np.abs(ref_data[:, i] - test_data[:, i]), maxi, out=np.zeros_like(ref_data[:, i]), where=maxi != 0
|
||||
np.abs(ref_data[:, i] - test_data[:, i]),
|
||||
maxi,
|
||||
out=np.zeros_like(ref_data[:, i]),
|
||||
where=maxi != 0,
|
||||
) # Replace any division by zero with zero
|
||||
|
||||
# Calculate power (ignore warning from taking a log of any zero values)
|
||||
with np.errstate(divide="ignore"):
|
||||
diffs[:, i] = 20 * np.log10(diffs[:, i])
|
||||
# Replace any NaNs or Infs from zero division
|
||||
diffs[:, i][np.invert(np.isfinite(diffs[:, i]))] = 0
|
||||
# diffs[:, i][np.invert(np.isfinite(diffs[:, i]))] = 0
|
||||
|
||||
return diffs
|
||||
|
@@ -1,6 +1,10 @@
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from reframe_tests.utilities.data import calculate_diffs, get_data_from_h5_file
|
||||
|
||||
|
||||
def _plot_data(subplots, time, data, label=None, colour="r", line_style="-"):
|
||||
for i in range(data.shape[1]):
|
||||
@@ -76,6 +80,10 @@ def plot_diffs(time, diffs, plot_min=-160):
|
||||
|
||||
x_max = np.max(time)
|
||||
y_max = np.max(diffs)
|
||||
|
||||
if not np.isfinite(y_max):
|
||||
y_max = 0
|
||||
|
||||
for i, ax in enumerate(fig.axes):
|
||||
ax.set_ylabel(ylabels[i])
|
||||
ax.set_xlim(0, x_max)
|
||||
@@ -83,3 +91,24 @@ def plot_diffs(time, diffs, plot_min=-160):
|
||||
ax.grid()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("input_file", help="Path to input file")
|
||||
parser.add_argument("reference_file", help="Path to reference file")
|
||||
parser.add_argument("-model-name", "-name", "-n", help="Name of the model", default="model")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
input_time, input_data = get_data_from_h5_file(args.input_file)
|
||||
ref_time, ref_data = get_data_from_h5_file(args.reference_file)
|
||||
|
||||
figure = plot_dataset_comparison(input_time, input_data, ref_time, ref_data, args.model_name)
|
||||
figure.tight_layout(h_pad=3, w_pad=4, pad=2)
|
||||
figure.savefig(f"{args.model_name}.pdf", dpi=300)
|
||||
|
||||
diffs = calculate_diffs(input_data, ref_data)
|
||||
figure = plot_diffs(input_time, diffs)
|
||||
figure.tight_layout(h_pad=3, w_pad=4, pad=2)
|
||||
figure.savefig(f"{args.model_name}_diffs.pdf", dpi=300)
|
||||
|
在新工单中引用
屏蔽一个用户