diff --git a/8_prediction.py b/8_prediction.py new file mode 100644 index 0000000..3f435e9 --- /dev/null +++ b/8_prediction.py @@ -0,0 +1,62 @@ +import os +import torch +from torch.utils.data import DataLoader +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from Network.MyDataset import * +from utils.train_val_lr import train, validate, adjust_learning_rate +from Network.Model import * +from scipy.signal import butter, lfilter, sosfilt +from config import Network_prediction_Config as cfg +from config import Path_Config as pcfg +from utils.plot import plot_permittivity_constant + +# Set parameters +BATCH_SIZE = cfg.BATCH_SIZE +TEST_FILE = pcfg.PROCESSED_TEST_FILE +inversion_time_result_file=pcfg.inversion_time_result_file +inversion_time_result_img= pcfg.inversion_time_result_img +MODEL_PATH = pcfg.BEST_MODEL_PATH +impulse_field_file = pcfg.field_impulse +impulse_sim_file = pcfg.sim_impulse +smooth_window_size = cfg.smooth_window_size +max_permittivity=cfg.max_permittvity +initial_params = cfg.initial_params + + +# Load the pre-trained model +model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda() +model.load_state_dict(torch.load(MODEL_PATH)) +model.eval() + +# Load dataset +dataset = MyDataset(TEST_FILE, TEST_FILE, impulse_field_file, impulse_sim_file, mode='apply', initial_params=initial_params) +test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=cfg.num_workers, pin_memory=True) + +# Initialize results container +results = [] + +# Inference without gradient tracking +with torch.no_grad(): + for idx, (data_data, label) in enumerate(test_loader): + data_data = data_data.cuda() + + # Model inference + data_data_out = model(data_data.type(torch.cuda.FloatTensor)) + + # Move output to CPU and squeeze dimensions + data_data_out = data_data_out.cpu().numpy().squeeze() + results.append(data_data_out) + +# Concatenate and save final results +final_result = np.column_stack(results) + +# Apply smoothing filter (moving average) +final_result = np.apply_along_axis(lambda x: np.convolve(x, np.ones(smooth_window_size)/smooth_window_size, mode='valid'), axis=1, arr=final_result) +pd.DataFrame(final_result* max_permittivity).to_csv(inversion_time_result_file, index=False) +# Plot results +plot_permittivity_constant(final_result * max_permittivity, inversion_time_result_img, line_length=cfg.distance, time_length=cfg.time_window) + + +