import os
import sys

import torch
from torchsummary import summary
from torchvision.utils import make_grid, save_image

# Add parent directory to path for config import
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config import Network_train_Config as cfg


def _model_device(model):
    """Return the device that input batches should be moved to for ``model``.

    Uses the device of the model's first parameter. If the model exposes no
    parameters, falls back to the current CUDA device when CUDA is available
    and to the CPU otherwise.
    """
    parameters = getattr(model, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _batch_loss(model, loss_func, data, label, device):
    """Run one forward pass on ``device`` and return the loss tensor."""
    # Inputs are cast to float32 before the forward pass; output and label
    # are both cast to float for the loss.
    data = data.to(device=device, dtype=torch.float32)
    label = label.to(device)
    output = model(data)
    return loss_func(output.float(), label.float())


# Training function
def train(train_loader, model, loss_func, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: Iterable yielding ``(data, label)`` tensor pairs.
        model: Module to optimise; it is switched to training mode.
        loss_func: Callable taking ``(output, label)`` and returning a
            scalar loss tensor.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once
            per batch.
        epoch: Epoch number, used only in the printed progress message.

    Returns:
        The mean of the per-batch loss values.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    model.train()
    # Follow the model's device instead of hard-coding CUDA so the same
    # loop also works on CPU-only hosts and non-default GPUs.
    device = _model_device(model)

    total_loss = 0
    batch_count = 0
    for data, label in train_loader:
        loss = _batch_loss(model, loss_func, data, label, device)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_count += 1

    avg_loss = total_loss / batch_count
    print(f"Epoch {epoch}: Training Loss = {avg_loss:.6f}")
    return avg_loss


# Validation function
def validate(val_loader, model, loss_func):
    """Compute the mean loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and the loop runs under
    ``torch.no_grad()``.

    Args:
        val_loader: Iterable yielding ``(data, label)`` tensor pairs.
        model: Module to evaluate.
        loss_func: Callable taking ``(output, label)`` and returning a
            scalar loss tensor.

    Returns:
        The mean of the per-batch loss values.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no batches.
    """
    model.eval()
    device = _model_device(model)

    total_loss = 0
    batch_count = 0
    with torch.no_grad():
        for data, label in val_loader:
            loss = _batch_loss(model, loss_func, data, label, device)
            total_loss += loss.item()
            batch_count += 1

    avg_loss = total_loss / batch_count
    return avg_loss


# Learning rate scheduler
def adjust_learning_rate(optimizer, epoch, start_lr):
    """Exponentially decay the learning rate and apply it to ``optimizer``.

    The new rate is ``start_lr * cfg.lr_decrease_rate ** epoch`` and is
    written to every entry of ``optimizer.param_groups``.

    Args:
        optimizer: Optimizer whose parameter groups are updated in place.
        epoch: Exponent applied to the configured decay rate.
        start_lr: Learning rate corresponding to ``epoch == 0``.

    Returns:
        The learning rate that was set.
    """
    lr = start_lr * (cfg.lr_decrease_rate ** epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr