文件
gpr-sidl-inv/8_prediction.py
葛峻恺 699f32f283 program
Signed-off-by: 葛峻恺 <202115006@mail.sdu.edu.cn>
2025-04-07 12:17:39 +00:00

63 行
2.3 KiB
Python

import os
import torch
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Network.MyDataset import *
from utils.train_val_lr import train, validate, adjust_learning_rate
from Network.Model import *
from scipy.signal import butter, lfilter, sosfilt
from config import Network_prediction_Config as cfg
from config import Path_Config as pcfg
from utils.plot import plot_permittivity_constant
# ---------------------------------------------------------------------------
# Run settings pulled from the project configuration modules.
# ---------------------------------------------------------------------------

# Input / output locations (Path_Config).
TEST_FILE = pcfg.PROCESSED_TEST_FILE
MODEL_PATH = pcfg.LATEST_MODEL_PATH
impulse_field_file = pcfg.field_impulse
impulse_sim_file = pcfg.sim_impulse
inversion_time_result_file = pcfg.inversion_time_result_file  # CSV output
inversion_time_result_img = pcfg.inversion_time_result_img    # image output

# Inference and post-processing settings (Network_prediction_Config).
BATCH_SIZE = cfg.BATCH_SIZE
smooth_window_size = cfg.smooth_window_size
# NOTE: the config attribute really is spelled "max_permittvity"; keep this
# spelling unless config.py is renamed in the same change.
max_permittivity = cfg.max_permittvity
initial_params = cfg.initial_params
# Load the pre-trained model. Use the GPU when one is available and fall back
# to the CPU otherwise, instead of hard-requiring CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).to(device)
# map_location lets a checkpoint that was saved from GPU tensors be restored
# onto whichever device was selected above.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Put modules such as dropout / batch-norm (if the model has any) into
# inference mode.
model.eval()
# Load the test dataset in 'apply' mode. TEST_FILE is passed for both of
# MyDataset's first two arguments -- presumably the data and label sources,
# with labels unused in 'apply' mode (confirm against Network/MyDataset.py).
dataset = MyDataset(TEST_FILE, TEST_FILE, impulse_field_file, impulse_sim_file, mode='apply', initial_params=initial_params)
# shuffle=False keeps predictions in dataset order. Pinned host memory only
# speeds up host-to-GPU copies, so request it only when CUDA is present
# (PyTorch warns about pin_memory on CPU-only machines).
test_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=cfg.num_workers,
    pin_memory=torch.cuda.is_available(),
)
# Collect one flattened 1-D output array per sample; np.column_stack later
# turns each of them into one column of the result.
results = []
# Feed inputs to whatever device the model's weights live on.
infer_device = next(model.parameters()).device
# Inference without gradient tracking.
with torch.no_grad():
    for batch, _label in test_loader:  # labels are not needed for prediction
        # Move to the model's device and cast to float32 in one step.
        batch = batch.to(infer_device, dtype=torch.float32)
        outputs = model(batch).cpu().numpy()
        # Flatten each sample individually rather than squeeze()-ing the whole
        # batch: squeeze() also drops the batch axis, so with BATCH_SIZE > 1
        # (or a short/single-sample final batch) the collected arrays would
        # have inconsistent shapes. For batch size 1 this is identical to the
        # previous squeeze() result.
        results.extend(outputs.reshape(outputs.shape[0], -1))
# Stack the collected outputs column-wise (each 1-D array becomes one column).
if not results:
    raise RuntimeError(f"No predictions were produced from {TEST_FILE}; nothing to save or plot.")
final_result = np.column_stack(results)

# Moving-average smoothing of each row across neighbouring columns.
# mode='valid' keeps only full windows, so the result loses
# smooth_window_size - 1 columns. A window wider than the row would make
# np.convolve swap its arguments and return meaningless values, and 0 would
# divide by zero, so both are rejected up front.
n_columns = final_result.shape[1]
if not 1 <= smooth_window_size <= n_columns:
    raise ValueError(
        f"smooth_window_size must be between 1 and the number of output columns "
        f"({n_columns}), got {smooth_window_size}"
    )
kernel = np.ones(smooth_window_size) / smooth_window_size  # built once, not per row
final_result = np.apply_along_axis(lambda row: np.convolve(row, kernel, mode='valid'), axis=1, arr=final_result)

# Scale by max_permittivity once and reuse the same array for the CSV and
# the plot (presumably the network predicts permittivity normalised by this
# maximum -- confirm against training).
scaled_result = final_result * max_permittivity
pd.DataFrame(scaled_result).to_csv(inversion_time_result_file, index=False)
plot_permittivity_constant(scaled_result, inversion_time_result_img, line_length=cfg.distance, time_length=cfg.time_window)