你已经派生过 gpr-sidl-inv
镜像自地址
https://gitee.com/sduem/gpr-sidl-inv.git
已同步 2025-08-03 10:56:50 +08:00
79
7_network_train.py
普通文件
79
7_network_train.py
普通文件
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision.utils import make_grid, save_image
|
||||
import time
|
||||
import pandas as pd
|
||||
from torchsummary import summary
|
||||
from Network.MyDataset import MyDataset
|
||||
from utils.train_val_lr import train, validate, adjust_learning_rate
|
||||
from Network.Model import Model
|
||||
from config import Network_train_Config as cfg
|
||||
from config import Path_Config as pcfg
|
||||
|
||||
# Input locations, all taken from the shared path configuration (pcfg):
# the data set and label set, followed by the field and simulated impulse files.
data_file, label_file = pcfg.dataset_path, pcfg.labelset_path
impulse_field_file, impulse_sim_file = pcfg.field_impulse, pcfg.sim_impulse
|
||||
|
||||
if __name__ == '__main__':
    # Training entry point.
    #
    # Builds the model, optimizer and loss, splits the dataset into training
    # and validation subsets, then runs the epoch loop. Every epoch appends
    # one row (epoch, train loss, val loss, learning rate) to the CSV file at
    # pcfg.train_val_loss, saves the weights to pcfg.BEST_MODEL_PATH whenever
    # the validation loss improves, and saves them to pcfg.LATEST_MODEL_PATH
    # every cfg.save_period epochs.
    BATCH_SIZE = cfg.BATCH_SIZE
    LR = cfg.LR  # Initial learning rate
    EPOCHS = cfg.EPOCHS
    val_loss_min = cfg.val_loss_min  # Initial high value for validation loss tracking

    # Model initialization. `.cuda()` moves the model to the GPU, so a CUDA
    # device is required to run this script.
    model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_func = nn.MSELoss(reduction='mean')

    # Create a CSV file to log training progress. Writing the empty frame
    # (re)creates the file with only the header row; per-epoch rows are
    # appended further down.
    df = pd.DataFrame(columns=['epoch', 'train Loss', 'val Loss', 'learning rate'])
    df.to_csv(pcfg.train_val_loss, index=False)

    # Load dataset and split it into training / validation subsets.
    # cfg.dataset_Proportion is the fraction of samples used for training;
    # the remainder goes to validation.
    dataset = MyDataset(data_file, label_file, impulse_field_file, impulse_sim_file, mode='train', check=False)
    train_size = int(len(dataset) * cfg.dataset_Proportion)
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cfg.num_workers, pin_memory=True)
    # NOTE(review): the validation loader is shuffled as well; kept as-is to
    # preserve existing behavior.
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cfg.num_workers, pin_memory=True)

    # TensorBoard writer
    writer = SummaryWriter(r'Log')

    # The writer owns an event file; make sure it is flushed and closed even
    # if training is interrupted or raises.
    try:
        print('Dataset loaded successfully.')

        for epoch in range(EPOCHS):
            epoch_start_time = time.time()
            print(f'Starting epoch {epoch}')

            # Adjust learning rate. The helper's return value becomes the LR
            # passed to the next call and the value logged for this epoch.
            LR = adjust_learning_rate(optimizer, epoch, LR)

            # Training and validation
            train_loss = train(train_loader, model, loss_func, optimizer, epoch)
            val_loss = validate(val_loader, model, loss_func)

            print(f'Validation Loss: {val_loss:.6f}')

            # Save best model
            if val_loss < val_loss_min:
                torch.save(model.state_dict(), pcfg.BEST_MODEL_PATH)
                val_loss_min = val_loss

            # Periodically save model (this also fires on epoch 0)
            if epoch % cfg.save_period == 0:
                torch.save(model.state_dict(), pcfg.LATEST_MODEL_PATH)

            # Log training details: append one header-less row per epoch to
            # the CSV file created above.
            log_data = pd.DataFrame([[epoch, train_loss, val_loss, LR]])
            log_data.to_csv(pcfg.train_val_loss, mode='a', header=False, index=False)

            print(f'Epoch {epoch} completed in {time.time() - epoch_start_time:.2f} seconds.')
    finally:
        writer.close()
|
||||
|
||||
|
在新工单中引用
屏蔽一个用户