import os
import sys

import torch
from torchsummary import summary
from torchvision.utils import make_grid, save_image

# Add parent directory to path for config import
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config import Network_train_Config as cfg


# Training function
# Note: loss_func is unused (losses come from model.compute_loss) but is kept
# for signature compatibility with the original script.
def train(train_loader, model, loss_func, optimizer, epoch):
    model.train()
    total_loss_sum = 0.0
    primary_loss_sum = 0.0
    cycle_loss_sum = 0.0
    batch_count = 0
    for data, label in train_loader:
        data = data.cuda(non_blocking=True).float()
        label = label.cuda(non_blocking=True).float()
        # Forward pass with cycle; both primary loss and cycle loss are handled in the model
        pred, recon = model.forward_with_cycle(data)
        losses = model.compute_loss(data, label, pred, recon)  # returns dict with {'total', 'primary', 'cycle'}
        loss = losses['total']
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        total_loss_sum += loss.item()
        primary_loss_sum += losses['primary'].item()
        cycle_loss_sum += losses['cycle'].item()
        batch_count += 1
    avg_total = total_loss_sum / max(batch_count, 1)
    avg_primary = primary_loss_sum / max(batch_count, 1)
    avg_cycle = cycle_loss_sum / max(batch_count, 1)
    # Print the three loss components (total, primary, cycle) for monitoring balance
    print(f"Epoch {epoch}: Training -> total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
    return avg_total  # Keep compatibility with original script: return total loss


# Validation function
# Note: loss_func is unused here as well; see train() above.
def validate(val_loader, model, loss_func):
    model.eval()
    total_loss_sum = 0.0
    primary_loss_sum = 0.0
    cycle_loss_sum = 0.0
    batch_count = 0
    with torch.no_grad():
        for data, label in val_loader:
            data = data.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).float()
            pred, recon = model.forward_with_cycle(data)
            losses = model.compute_loss(data, label, pred, recon)
            total_loss_sum += losses['total'].item()
            primary_loss_sum += losses['primary'].item()
            cycle_loss_sum += losses['cycle'].item()
            batch_count += 1
    avg_total = total_loss_sum / max(batch_count, 1)
    avg_primary = primary_loss_sum / max(batch_count, 1)
    avg_cycle = cycle_loss_sum / max(batch_count, 1)
    # Print validation losses
    print(f"Validation : total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
    return avg_total


# Learning rate scheduler
def adjust_learning_rate(optimizer, epoch, start_lr):
    """Exponentially decay the learning rate: lr = start_lr * decay_rate ** epoch."""
    # The original computed start_lr * cfg.lr_decrease_rate, which ignores the
    # epoch and keeps the rate constant; raising the decay rate to the epoch
    # implements the exponential schedule the docstring describes.
    lr = start_lr * (cfg.lr_decrease_rate ** epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
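

# --- Illustrative driver (not part of the original script) ---
# A minimal sketch of how train(), validate(), and adjust_learning_rate() might
# be wired together. The stub model, the random tensors, and the config fields
# cfg.start_lr and cfg.epochs are assumptions for illustration only; substitute
# the project's real model class, dataset, and config names. Requires CUDA,
# since train()/validate() call .cuda() on each batch.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    class _StubCycleModel(torch.nn.Module):
        """Hypothetical stand-in exposing the interface train()/validate() expect."""

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Conv2d(1, 1, 3, padding=1)  # primary mapping
            self.inv = torch.nn.Conv2d(1, 1, 3, padding=1)  # cycle (inverse) mapping

        def forward_with_cycle(self, x):
            pred = self.net(x)
            recon = self.inv(pred)  # map the prediction back toward the input
            return pred, recon

        def compute_loss(self, data, label, pred, recon):
            primary = torch.nn.functional.mse_loss(pred, label)
            cycle = torch.nn.functional.mse_loss(recon, data)
            return {'total': primary + cycle, 'primary': primary, 'cycle': cycle}

    # Random stand-in data; replace with the project's real dataset.
    x = torch.randn(16, 1, 64, 64)
    y = torch.randn(16, 1, 64, 64)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)
    val_loader = DataLoader(TensorDataset(x, y), batch_size=4)

    model = _StubCycleModel().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.start_lr)

    for epoch in range(cfg.epochs):
        adjust_learning_rate(optimizer, epoch, cfg.start_lr)
        train(train_loader, model, None, optimizer, epoch)  # loss_func unused
        validate(val_loader, model, None)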