# Gitee repository page header (kept for provenance):
# You have already forked gpr-sidl-inv.
# Mirrored from: https://gitee.com/sduem/gpr-sidl-inv.git
# Synced: 2025-08-02 18:36:51 +08:00
# File size: 80 lines, 3.0 KiB (Python)
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from torchvision.utils import make_grid, save_image
|
|
import time
|
|
import pandas as pd
|
|
from torchsummary import summary
|
|
from Network.MyDataset import MyDataset
|
|
from utils.train_val_lr import train, validate, adjust_learning_rate
|
|
from Network.Model import Model
|
|
from config import Network_train_Config as cfg
|
|
from config import Path_Config as pcfg
|
|
|
|
# Input file locations, all taken from Path_Config (pcfg).
# NOTE(review): "field"/"sim" presumably distinguish measured vs. simulated
# impulse data -- inferred from the attribute names only; confirm in config.
data_file, label_file = pcfg.dataset_path, pcfg.labelset_path
impulse_field_file, impulse_sim_file = pcfg.field_impulse, pcfg.sim_impulse

def _build_loaders(dataset):
    """Split ``dataset`` into train/validation subsets and wrap each in a DataLoader.

    ``int(len(dataset) * cfg.dataset_Proportion)`` samples, chosen at random,
    go to training; the remaining samples go to validation.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    train_size = int(len(dataset) * cfg.dataset_Proportion)
    val_size = len(dataset) - train_size
    # NOTE(review): random_split is unseeded, so the split changes from run to
    # run; pass generator=torch.Generator().manual_seed(...) if it must be stable.
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    loader_kwargs = dict(batch_size=cfg.BATCH_SIZE, num_workers=cfg.num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    # The validation set is not shuffled. Shuffling only adds overhead, and
    # depending on how validate() aggregates per-batch losses it can make the
    # reported loss vary for identical weights.
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader


def main():
    """Train ``Model`` on ``MyDataset``, checkpointing and logging every epoch.

    Each epoch does the following:
    - adjusts the learning rate;
    - runs ``train`` and then ``validate``;
    - saves the weights to ``pcfg.BEST_MODEL_PATH`` whenever the validation
      loss beats the best value so far (initially ``cfg.val_loss_min``);
    - saves the weights to ``pcfg.LATEST_MODEL_PATH`` every
      ``cfg.save_period`` epochs and after the final epoch;
    - appends one row to the CSV at ``pcfg.train_val_loss``, with the same
      values also written to TensorBoard under ``Log``.
    """
    lr = cfg.LR  # initial learning rate; replaced each epoch by adjust_learning_rate
    epochs = cfg.EPOCHS
    val_loss_min = cfg.val_loss_min  # best validation loss seen so far

    # Model, optimiser and loss.
    # NOTE(review): .cuda() makes a CUDA device mandatory.
    model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.MSELoss(reduction='mean')

    # Start a fresh CSV log holding only the header row; rows are appended per epoch.
    log_columns = ['epoch', 'train Loss', 'val Loss', 'learning rate']
    pd.DataFrame(columns=log_columns).to_csv(pcfg.train_val_loss, index=False)

    dataset = MyDataset(data_file, label_file, impulse_field_file, impulse_sim_file,
                        mode='train', check=False)
    train_loader, val_loader = _build_loaders(dataset)

    # The context manager flushes and closes the event file even if training
    # is interrupted. The original writer was never used or closed.
    with SummaryWriter(r'Log') as writer:
        print('Dataset loaded successfully.')

        for epoch in range(epochs):
            epoch_start = time.perf_counter()  # monotonic clock for durations
            print(f'Starting epoch {epoch}')

            # The returned rate becomes the current one and is logged below.
            lr = adjust_learning_rate(optimizer, epoch, lr)

            train_loss = train(train_loader, model, loss_func, optimizer, epoch)
            val_loss = validate(val_loader, model, loss_func)
            print(f'Validation Loss: {val_loss:.6f}')

            # Keep the weights with the lowest validation loss seen so far.
            if val_loss < val_loss_min:
                torch.save(model.state_dict(), pcfg.BEST_MODEL_PATH)
                val_loss_min = val_loss

            # Save the current weights periodically. Also save after the last
            # epoch, so the final state is kept even when EPOCHS - 1 is not a
            # multiple of save_period.
            if epoch % cfg.save_period == 0 or epoch == epochs - 1:
                torch.save(model.state_dict(), pcfg.LATEST_MODEL_PATH)

            pd.DataFrame([[epoch, train_loss, val_loss, lr]]).to_csv(
                pcfg.train_val_loss, mode='a', header=False, index=False)
            # NOTE(review): assumes train()/validate() return scalars (float or
            # 0-d tensor), as the .6f formatting of val_loss above suggests.
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('LearningRate', lr, epoch)

            print(f'Epoch {epoch} completed in {time.perf_counter() - epoch_start:.2f} seconds.')


# Guard required so DataLoader worker processes (num_workers > 0) that
# re-import this module do not start training again.
if __name__ == '__main__':
    main()