"""Run a trained network over the processed test file and export the result.

The script loads model weights from ``MODEL_PATH``, runs inference over the
test dataset on the GPU, smooths the stacked predictions with a moving
average, scales them by ``max_permittivity``, and writes both a CSV file and
a plot.
"""

# NOTE: import order is kept exactly as in the original file; the wildcard
# imports below may shadow names if they are reordered.
import os
import torch
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Network.MyDataset import *
from utils.train_val_lr import train, validate, adjust_learning_rate
from Network.Model import *
from scipy.signal import butter, lfilter, sosfilt
from config import Network_prediction_Config as cfg
from config import Path_Config as pcfg
from utils.plot import plot_permittivity_constant

# Set parameters
BATCH_SIZE = cfg.BATCH_SIZE
TEST_FILE = pcfg.PROCESSED_TEST_FILE
inversion_time_result_file = pcfg.inversion_time_result_file
inversion_time_result_img = pcfg.inversion_time_result_img
MODEL_PATH = pcfg.LATEST_MODEL_PATH
impulse_field_file = pcfg.field_impulse
impulse_sim_file = pcfg.sim_impulse
smooth_window_size = cfg.smooth_window_size
# The attribute spelling ("permittvity") follows the config module.
max_permittivity = cfg.max_permittvity
initial_params = cfg.initial_params

# Load the pre-trained model
model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda()
model.load_state_dict(torch.load(MODEL_PATH))
model.eval()

# Load dataset. The test file is passed for both of the first two positional
# arguments; the loader's second item is ignored in the loop below.
dataset = MyDataset(
    TEST_FILE,
    TEST_FILE,
    impulse_field_file,
    impulse_sim_file,
    mode='apply',
    initial_params=initial_params,
)
test_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=cfg.num_workers,
    pin_memory=True,
)

# Initialize results container
results = []

# Inference without gradient tracking
with torch.no_grad():
    for inputs, _ in test_loader:
        # Move the batch to the GPU and cast to float32 before the forward
        # pass (replaces the legacy ``.type(torch.cuda.FloatTensor)`` cast).
        inputs = inputs.cuda().float()
        prediction = model(inputs)

        # Move output to CPU and drop all size-1 dimensions.
        # NOTE(review): squeeze() also drops the batch axis when a batch holds
        # a single sample, and np.column_stack below treats 1-D inputs as
        # columns but stacks 2-D inputs side by side -- confirm the intended
        # layout for the configured batch size.
        results.append(prediction.cpu().numpy().squeeze())

# Concatenate per-batch outputs column-wise
final_result = np.column_stack(results)

# Apply smoothing filter (moving average) along axis 1. The kernel is built
# once instead of once per row; mode='valid' keeps only fully overlapped
# positions, so axis 1 becomes shorter after smoothing.
smoothing_kernel = np.ones(smooth_window_size) / smooth_window_size
final_result = np.apply_along_axis(
    lambda row: np.convolve(row, smoothing_kernel, mode='valid'),
    axis=1,
    arr=final_result,
)

# Scale once and reuse for both the CSV export and the plot.
scaled_result = final_result * max_permittivity

# Save final results
pd.DataFrame(scaled_result).to_csv(inversion_time_result_file, index=False)

# Plot results
plot_permittivity_constant(
    scaled_result,
    inversion_time_result_img,
    line_length=cfg.distance,
    time_length=cfg.time_window,
)