import os
import sys
import torch
from torchsummary import summary
from torchvision.utils import make_grid, save_image
import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage

# Add parent directory to path for config import
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config import Network_train_Config as cfg

# Number of equal intervals between the first (0) and last tick on each axis.
# With the default extents (100 m, 200 ns) these reproduce the previously
# hand-written ticks [0, 20, ..., 100] and [0, 50, ..., 200].
_X_TICK_INTERVALS = 5
_Y_TICK_INTERVALS = 4


def _require_2d(data):
    """Raise ValueError unless *data* is two-dimensional."""
    if np.ndim(data) != 2:
        raise ValueError(
            f"data must be a 2D array, got {np.ndim(data)} dimension(s)"
        )


def _even_ticks(length, n_intervals):
    """Return ``n_intervals + 1`` evenly spaced tick positions from 0 to *length*."""
    return np.linspace(0, length, n_intervals + 1)


def _new_figure(figsize):
    """Apply the shared font style and create a figure with a single axes."""
    # NOTE: this updates matplotlib's global rcParams, so the font setting
    # persists for figures created afterwards (kept for backward compatibility).
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 20})
    return plt.subplots(figsize=figsize)


def _label_axes(ax, line_length, y_length, ylabel):
    """Set axis labels and ticks that span the full image extent."""
    ax.set_xlabel('Distance (m)', fontsize=20)
    ax.set_xticks(_even_ticks(line_length, _X_TICK_INTERVALS))
    ax.set_ylabel(ylabel, fontsize=20)
    ax.set_yticks(_even_ticks(y_length, _Y_TICK_INTERVALS))


def _style_axes(ax):
    """Draw ticks inward, use width-1 ticks and spines, and hide the grid."""
    ax.tick_params(axis='both', direction='in', width=1)
    for spine in ax.spines.values():
        spine.set_linewidth(1)
    ax.grid(False)


def _finish_figure(fig, path):
    """Tighten the layout, save *fig* to *path* at 300 dpi, show it, then release it."""
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.show()
    # pyplot keeps every figure alive until it is closed, so calling these
    # functions repeatedly under a non-interactive backend would leak memory.
    # In interactive mode plt.show() does not block, and closing here would
    # immediately destroy the window it just opened, so leave it open there.
    if not plt.isinteractive():
        plt.close(fig)


def plot_BSCAN_data(data, path, line_length=100, time_length=200, ratio=1):
    """
    Plot a B-scan as a grayscale image with a symmetric colormap range.

    The colormap spans [-ratio * max|data|, ratio * max|data|]. With ratio < 1,
    values whose magnitude exceeds ratio * max|data| are clipped to the ends of
    the colormap, which increases contrast for lower-amplitude values.

    :param data: 2D array (N, M); rows are plotted along the time axis and
        columns along the survey-line (distance) axis.
    :param path: Path the image is saved to.
    :param line_length: Survey line length in meters (default: 100 m).
    :param time_length: Time range in nanoseconds (default: 200 ns).
    :param ratio: Scaling factor for the colormap range (default: 1).
    :raises ValueError: If *data* is not two-dimensional.
    """
    _require_2d(data)

    # nanmax ignores NaN samples so a single missing value does not turn the
    # whole colour range into NaN.
    max_abs = np.nanmax(np.abs(data))
    vmin, vmax = -ratio * max_abs, ratio * max_abs

    fig, ax = _new_figure((10, 4))
    ax.imshow(data, aspect='auto', cmap='gray',
              extent=[0, line_length, time_length, 0], vmin=vmin, vmax=vmax)
    _label_axes(ax, line_length, time_length, 'Time (ns)')
    _style_axes(ax)
    _finish_figure(fig, path)


def _plot_permittivity_map(data, path, line_length, time_length, vmin, vmax):
    """Shared implementation of the two permittivity plots (rainbow_r map + colorbar)."""
    _require_2d(data)

    fig, ax = _new_figure((12, 4))
    im = ax.imshow(data, aspect='auto', cmap='rainbow_r',
                   extent=[0, line_length, time_length, 0], vmin=vmin, vmax=vmax)
    _label_axes(ax, line_length, time_length, 'Time (ns)')

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Permittivity', fontsize=20)

    _style_axes(ax)
    _finish_figure(fig, path)


def plot_permittivity_constant(data, path, line_length=100, time_length=200,
                               vmin=9, vmax=26):
    """
    Plot the inverted permittivity constant as a 2D colormap.

    :param data: 2D array (N, M); rows are plotted along the vertical
        (time) axis and columns along the survey-line (distance) axis.
    :param path: Path the image is saved to.
    :param line_length: Survey line length in meters (default: 100 m).
    :param time_length: Time range in nanoseconds (default: 200 ns).
    :param vmin: Permittivity mapped to the low end of the colormap (default: 9).
    :param vmax: Permittivity mapped to the high end of the colormap (default: 26).
    :raises ValueError: If *data* is not two-dimensional.
    """
    _plot_permittivity_map(data, path, line_length, time_length, vmin, vmax)


def plot_depth_permittivity_constant(data, path, line_length=100, time_length=200,
                                     vmin=8, vmax=30):
    """
    Plot the inverted permittivity constant as a 2D colormap (depth variant).

    Y ticks are spaced evenly over ``[0, time_length]``; passing
    ``time_length=8`` gives the ticks 0, 2, 4, 6, 8.

    :param data: 2D array (N, M); rows are plotted along the vertical axis
        and columns along the survey-line (distance) axis.
    :param path: Path the image is saved to.
    :param line_length: Survey line length in meters (default: 100 m).
    :param time_length: Extent of the vertical axis (default: 200).
    :param vmin: Permittivity mapped to the low end of the colormap (default: 8).
    :param vmax: Permittivity mapped to the high end of the colormap (default: 30).
    :raises ValueError: If *data* is not two-dimensional.
    """
    # NOTE(review): the function name and the former fixed y ticks (0..8)
    # suggest the vertical axis may be depth in meters, yet the label reads
    # 'Time (ns)'. Label kept unchanged for compatibility; confirm intent.
    _plot_permittivity_map(data, path, line_length, time_length, vmin, vmax)