gpr-sidl-inv/7_network_train.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
import time
import pandas as pd
from torchsummary import summary
from Network.MyDataset import MyDataset
from utils.train_val_lr import train, validate, adjust_learning_rate
from Network.Model import Model
from config import Network_train_Config as cfg
from config import Path_Config as pcfg
# File paths
data_file = pcfg.dataset_path
label_file = pcfg.labelset_path
impulse_field_file = pcfg.field_impulse
impulse_sim_file = pcfg.sim_impulse
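# Note: config.py is not shown on this page. Judging from the attributes referenced
# in this script, Network_train_Config is assumed to provide BATCH_SIZE, LR, EPOCHS,
# val_loss_min, network_layers, dataset_Proportion, num_workers and save_period, and
# Path_Config is assumed to provide train_val_loss, BEST_MODEL_PATH and
# LATEST_MODEL_PATH in addition to the four dataset/impulse paths above.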
if __name__ == '__main__':
    BATCH_SIZE = cfg.BATCH_SIZE
    LR = cfg.LR  # Initial learning rate
    EPOCHS = cfg.EPOCHS
    val_loss_min = cfg.val_loss_min  # Initial high value for validation loss tracking

    # Model initialization
    model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_func = nn.MSELoss(reduction='mean')

    # Create a CSV file to log training progress
    df = pd.DataFrame(columns=['epoch', 'train Loss', 'val Loss', 'learning rate'])
    df.to_csv(pcfg.train_val_loss, index=False)

    # Load dataset and split it into training/validation subsets
    dataset = MyDataset(data_file, label_file, impulse_field_file, impulse_sim_file, mode='train', check=False)
    train_size = int(len(dataset) * cfg.dataset_Proportion)
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cfg.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cfg.num_workers, pin_memory=True)

    # TensorBoard writer
    writer = SummaryWriter(r'Log')
    print('Dataset loaded successfully.')

    for epoch in range(EPOCHS):
        epoch_start_time = time.time()
        print(f'Starting epoch {epoch}')

        # Adjust learning rate
        LR = adjust_learning_rate(optimizer, epoch, LR)

        # Training and validation
        train_loss = train(train_loader, model, loss_func, optimizer, epoch)
        val_loss = validate(val_loader, model, loss_func)
        print(f'Validation Loss: {val_loss:.6f}')

        # Save best model
        if val_loss < val_loss_min:
            torch.save(model.state_dict(), pcfg.BEST_MODEL_PATH)
            val_loss_min = val_loss

        # Periodically save model
        if epoch % cfg.save_period == 0:
            torch.save(model.state_dict(), pcfg.LATEST_MODEL_PATH)

        # Log training details
        log_data = pd.DataFrame([[epoch, train_loss, val_loss, LR]])
        log_data.to_csv(pcfg.train_val_loss, mode='a', header=False, index=False)
        print(f'Epoch {epoch} completed in {time.time() - epoch_start_time:.2f} seconds.')
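
The helpers imported from utils.train_val_lr are not part of this page. Below is a minimal sketch of what they are assumed to do, matching the call sites above: a step-decay learning-rate schedule and standard PyTorch train/validation loops that return the mean batch loss. The decay interval and factor (decay_every, decay_rate) and the assumption that the loaders yield (input, target) pairs are illustrative placeholders, not values taken from the repository.

import torch

def adjust_learning_rate(optimizer, epoch, lr, decay_rate=0.5, decay_every=50):
    # Assumed step-decay policy: scale the learning rate every `decay_every` epochs.
    if epoch > 0 and epoch % decay_every == 0:
        lr = lr * decay_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr

def train(loader, model, loss_func, optimizer, epoch):
    # One training epoch; `epoch` is kept only to match the call site.
    model.train()
    running_loss = 0.0
    for data, label in loader:  # assumes the dataset yields (input, target) pairs
        data, label = data.cuda(non_blocking=True), label.cuda(non_blocking=True)
        output = model(data)
        loss = loss_func(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)

@torch.no_grad()
def validate(loader, model, loss_func):
    # Evaluate on the validation split; returns the mean batch loss.
    model.eval()
    running_loss = 0.0
    for data, label in loader:
        data, label = data.cuda(non_blocking=True), label.cuda(non_blocking=True)
        running_loss += loss_func(model(data), label).item()
    return running_loss / len(loader)

With this reading, the returned LR is whatever rate was applied for the current epoch, which is what the script writes to the CSV log alongside the train and validation losses.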