"""Training entry point for ``Model``.

Randomly splits ``MyDataset`` into train/validation subsets and trains with
Adam on an MSE loss. On every epoch the script:

* saves the weights to ``pcfg.BEST_MODEL_PATH`` whenever the validation loss
  drops below the best value seen so far (starting from ``cfg.val_loss_min``);
* saves the weights to ``pcfg.LATEST_MODEL_PATH`` every ``cfg.save_period``
  epochs and after the final epoch;
* appends ``epoch, train loss, val loss, learning rate`` to the CSV at
  ``pcfg.train_val_loss`` and writes the same values to TensorBoard (``Log``).
"""
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
import pandas as pd
from torchsummary import summary

from Network.MyDataset import MyDataset
from utils.train_val_lr import train, validate, adjust_learning_rate
from Network.Model import Model
from config import Network_train_Config as cfg
from config import Path_Config as pcfg

# File paths
data_file = pcfg.dataset_path
label_file = pcfg.labelset_path
impulse_field_file = pcfg.field_impulse
impulse_sim_file = pcfg.sim_impulse


def _build_loaders(batch_size):
    """Split MyDataset by ``cfg.dataset_Proportion`` and wrap both parts in DataLoaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    dataset = MyDataset(data_file, label_file, impulse_field_file, impulse_sim_file,
                        mode='train', check=False)
    train_size = int(len(dataset) * cfg.dataset_Proportion)
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=cfg.num_workers, pin_memory=True)
    # Validation gains nothing from shuffling; a fixed order keeps the pass deterministic.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=cfg.num_workers, pin_memory=True)
    return train_loader, val_loader


def main():
    """Run the full training loop configured by ``cfg`` and ``pcfg``."""
    batch_size = cfg.BATCH_SIZE
    lr = cfg.LR  # Initial learning rate
    epochs = cfg.EPOCHS
    val_loss_min = cfg.val_loss_min  # Initial high value for validation loss tracking

    # Model initialization (requires a CUDA device because of .cuda()).
    model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.MSELoss(reduction='mean')

    # Start a fresh CSV containing only the header; each epoch appends one row.
    df = pd.DataFrame(columns=['epoch', 'train Loss', 'val Loss', 'learning rate'])
    df.to_csv(pcfg.train_val_loss, index=False)

    train_loader, val_loader = _build_loaders(batch_size)

    # TensorBoard writer
    writer = SummaryWriter(r'Log')
    print('Dataset loaded successfully.')

    try:
        for epoch in range(epochs):
            epoch_start_time = time.time()
            print(f'Starting epoch {epoch}')

            # The returned rate is passed back in on the next epoch; the actual
            # schedule is defined in utils.train_val_lr.adjust_learning_rate.
            lr = adjust_learning_rate(optimizer, epoch, lr)

            # Training and validation
            train_loss = train(train_loader, model, loss_func, optimizer, epoch)
            val_loss = validate(val_loader, model, loss_func)
            print(f'Validation Loss: {val_loss:.6f}')

            # Save best model
            if val_loss < val_loss_min:
                torch.save(model.state_dict(), pcfg.BEST_MODEL_PATH)
                val_loss_min = val_loss

            # Periodically save model; also after the last epoch so the final
            # weights are kept even when epochs - 1 is not a multiple of save_period.
            if epoch % cfg.save_period == 0 or epoch == epochs - 1:
                torch.save(model.state_dict(), pcfg.LATEST_MODEL_PATH)

            # Log training details
            log_data = pd.DataFrame([[epoch, train_loss, val_loss, lr]])
            log_data.to_csv(pcfg.train_val_loss, mode='a', header=False, index=False)
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('LearningRate', lr, epoch)

            print(f'Epoch {epoch} completed in {time.time() - epoch_start_time:.2f} seconds.')
    finally:
        # Flush pending events and release the event file even if training fails.
        writer.close()


if __name__ == '__main__':
    main()