你已经派生过 gpr-sidl-inv
镜像自地址
https://gitee.com/sduem/gpr-sidl-inv.git
已同步 2025-08-03 18:56:51 +08:00
62
8_prediction.py
普通文件
62
8_prediction.py
普通文件
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from Network.MyDataset import *
|
||||
from utils.train_val_lr import train, validate, adjust_learning_rate
|
||||
from Network.Model import *
|
||||
from scipy.signal import butter, lfilter, sosfilt
|
||||
from config import Network_prediction_Config as cfg
|
||||
from config import Path_Config as pcfg
|
||||
from utils.plot import plot_permittivity_constant
|
||||
|
||||
# Set parameters
|
||||
BATCH_SIZE = cfg.BATCH_SIZE
|
||||
TEST_FILE = pcfg.PROCESSED_TEST_FILE
|
||||
inversion_time_result_file=pcfg.inversion_time_result_file
|
||||
inversion_time_result_img= pcfg.inversion_time_result_img
|
||||
MODEL_PATH = pcfg.LATEST_MODEL_PATH
|
||||
impulse_field_file = pcfg.field_impulse
|
||||
impulse_sim_file = pcfg.sim_impulse
|
||||
smooth_window_size = cfg.smooth_window_size
|
||||
max_permittivity=cfg.max_permittvity
|
||||
initial_params = cfg.initial_params
|
||||
|
||||
|
||||
# Load the pre-trained model
|
||||
model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda()
|
||||
model.load_state_dict(torch.load(MODEL_PATH))
|
||||
model.eval()
|
||||
|
||||
# Load dataset
|
||||
dataset = MyDataset(TEST_FILE, TEST_FILE, impulse_field_file, impulse_sim_file, mode='apply', initial_params=initial_params)
|
||||
test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)
|
||||
|
||||
# Initialize results container
|
||||
results = []
|
||||
|
||||
# Inference without gradient tracking
|
||||
with torch.no_grad():
|
||||
for idx, (data_data, label) in enumerate(test_loader):
|
||||
data_data = data_data.cuda()
|
||||
|
||||
# Model inference
|
||||
data_data_out = model(data_data.type(torch.cuda.FloatTensor))
|
||||
|
||||
# Move output to CPU and squeeze dimensions
|
||||
data_data_out = data_data_out.cpu().numpy().squeeze()
|
||||
results.append(data_data_out)
|
||||
|
||||
# Concatenate and save final results
|
||||
final_result = np.column_stack(results)
|
||||
|
||||
# Apply smoothing filter (moving average)
|
||||
final_result = np.apply_along_axis(lambda x: np.convolve(x, np.ones(smooth_window_size)/smooth_window_size, mode='valid'), axis=1, arr=final_result)
|
||||
pd.DataFrame(final_result* max_permittivity).to_csv(inversion_time_result_file, index=False)
|
||||
# Plot results
|
||||
plot_permittivity_constant(final_result * max_permittivity, inversion_time_result_img, line_length=cfg.distance, time_length=cfg.time_window)
|
||||
|
||||
|
||||
|
在新工单中引用
屏蔽一个用户