"""Training entry point (script ``7_network_train.py``).

Builds ``Network.Model.Model``, splits the dataset returned by
``Network.MyDataset.MyDataset`` into training and validation subsets, and runs
the epoch loop.  Per-epoch losses and the learning rate are appended to the CSV
file at ``pcfg.train_val_loss``; model weights are written to
``pcfg.BEST_MODEL_PATH`` (lowest validation loss so far) and
``pcfg.LATEST_MODEL_PATH`` (periodic snapshot).
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
import time
import pandas as pd
from torchsummary import summary
from Network.MyDataset import MyDataset
from utils.train_val_lr import train, validate, adjust_learning_rate
from Network.Model import Model
from config import Network_train_Config as cfg
from config import Path_Config as pcfg

# File paths
data_file = pcfg.dataset_path
label_file = pcfg.labelset_path
impulse_field_file = pcfg.field_impulse
impulse_sim_file = pcfg.sim_impulse


def main():
    """Run the full training procedure configured by ``cfg`` and ``pcfg``."""
    batch_size = cfg.BATCH_SIZE
    lr = cfg.LR  # Initial learning rate
    epochs = cfg.EPOCHS
    # Best validation loss seen so far; presumably configured as a large
    # initial value so the first epoch always counts as an improvement.
    val_loss_min = cfg.val_loss_min

    # Model initialization (requires a CUDA device).
    model = Model(
        inplanes=2,
        outplanes=1,
        layers=cfg.network_layers,
        # Falls back to 1.0 when the config does not define ``cycle_weight``.
        cycle_weight=getattr(cfg, "cycle_weight", 1.0),
        cycle_loss='mse'
    ).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.MSELoss(reduction='mean')

    # Create (or overwrite) the CSV progress log with its header row only.
    header = pd.DataFrame(columns=['epoch', 'train Loss', 'val Loss', 'learning rate'])
    header.to_csv(pcfg.train_val_loss, index=False)

    # Load the dataset and split it into train / validation subsets.
    dataset = MyDataset(
        data_file,
        label_file,
        impulse_field_file,
        impulse_sim_file,
        mode='train',
        check=False,
    )
    train_size = int(len(dataset) * cfg.dataset_Proportion)
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    # Validation needs no shuffling: a fixed order keeps the per-epoch
    # validation loss comparable, which the best-model selection relies on.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    # TensorBoard writer.
    # NOTE(review): this script never writes anything to the writer itself;
    # it is kept for parity with the original and closed when training ends.
    writer = SummaryWriter(r'Log')

    print('Dataset loaded successfully.')

    try:
        for epoch in range(epochs):
            epoch_start_time = time.time()
            print(f'Starting epoch {epoch}')

            # Adjust learning rate; the helper's return value becomes the
            # rate that is passed in on the next epoch and logged to the CSV.
            lr = adjust_learning_rate(optimizer, epoch, lr)

            # Training and validation
            train_loss = train(train_loader, model, loss_func, optimizer, epoch)
            val_loss = validate(val_loader, model, loss_func)

            print(f'Validation Loss: {val_loss:.6f}')

            # Save best model
            if val_loss < val_loss_min:
                torch.save(model.state_dict(), pcfg.BEST_MODEL_PATH)
                val_loss_min = val_loss

            # Periodically save model; also save after the final epoch so the
            # "latest" checkpoint is never older than the finished run.
            is_last_epoch = epoch == epochs - 1
            if epoch % cfg.save_period == 0 or is_last_epoch:
                torch.save(model.state_dict(), pcfg.LATEST_MODEL_PATH)

            # Log training details (appended below the header written above).
            log_data = pd.DataFrame([[epoch, train_loss, val_loss, lr]])
            log_data.to_csv(pcfg.train_val_loss, mode='a', header=False, index=False)

            print(f'Epoch {epoch} completed in {time.time() - epoch_start_time:.2f} seconds.')
    finally:
        # Flush and release the TensorBoard event file even if training
        # is interrupted or raises.
        writer.close()


if __name__ == '__main__':
    main()