From b5d9f1ebc4df28ed2b74a5a7f40a4a71d36e84ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=91=9B=E5=B3=BB=E6=81=BA?= <202115006@mail.sdu.edu.cn> Date: Tue, 2 Sep 2025 12:02:51 +0000 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=207=5Fnetw?= =?UTF-8?q?ork=5Ftrain.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 7_network_train.py | 79 ---------------------------------------------- 1 file changed, 79 deletions(-) delete mode 100644 7_network_train.py diff --git a/7_network_train.py b/7_network_train.py deleted file mode 100644 index 8ebf28e..0000000 --- a/7_network_train.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import torch.nn as nn -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from torchvision.utils import make_grid, save_image -import time -import pandas as pd -from torchsummary import summary -from Network.MyDataset import MyDataset -from utils.train_val_lr import train, validate, adjust_learning_rate -from Network.Model import Model -from config import Network_train_Config as cfg -from config import Path_Config as pcfg - -# File paths -data_file = pcfg.dataset_path -label_file = pcfg.labelset_path -impulse_field_file = pcfg.field_impulse -impulse_sim_file = pcfg.sim_impulse - -if __name__ == '__main__': - BATCH_SIZE = cfg.BATCH_SIZE - LR = cfg.LR # Initial learning rate - EPOCHS = cfg.EPOCHS - val_loss_min = cfg.val_loss_min # Initial high value for validation loss tracking - - # Model initialization - model = Model(inplanes=2, outplanes=1, layers=cfg.network_layers).cuda() - optimizer = torch.optim.Adam(model.parameters(), lr=LR) - loss_func = nn.MSELoss(reduction='mean') - - # Create a CSV file to log training progress - df = pd.DataFrame(columns=['epoch', 'train Loss', 'val Loss', 'learning rate']) - df.to_csv(pcfg.train_val_loss, index=False) - - # Load dataset - dataset = MyDataset(data_file, label_file, impulse_field_file, impulse_sim_file, mode='train', check=False) - train_size = int(len(dataset) * cfg.dataset_Proportion) - val_size = len(dataset) - train_size - train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size]) - - # Create data loaders - train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cfg.num_workers, pin_memory=True) - val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cfg.num_workers, pin_memory=True) - - # TensorBoard writer - writer = SummaryWriter(r'Log') - - print('Dataset loaded successfully.') - - for epoch in range(EPOCHS): - epoch_start_time = time.time() - print(f'Starting epoch {epoch}') - - # Adjust learning rate - LR = adjust_learning_rate(optimizer, epoch, LR) - - # Training and validation - train_loss = train(train_loader, model, loss_func, optimizer, epoch) - val_loss = validate(val_loader, model, loss_func) - - print(f'Validation Loss: {val_loss:.6f}') - - # Save best model - if val_loss < val_loss_min: - torch.save(model.state_dict(), pcfg.BEST_MODEL_PATH) - val_loss_min = val_loss - - # Periodically save model - if epoch % cfg.save_period == 0: - torch.save(model.state_dict(), pcfg.LATEST_MODEL_PATH) - - # Log training details - log_data = pd.DataFrame([[epoch, train_loss, val_loss, LR]]) - log_data.to_csv(pcfg.train_val_loss, mode='a', header=False, index=False) - - print(f'Epoch {epoch} completed in {time.time() - epoch_start_time:.2f} seconds.') - -