diff --git a/utils/train_val_lr.py b/utils/train_val_lr.py
index c7ec6dc..3f7ead1 100644
--- a/utils/train_val_lr.py
+++ b/utils/train_val_lr.py
@@ -8,53 +8,77 @@ from torchvision.utils import make_grid, save_image
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 from config import Network_train_Config as cfg
 
+
 # Training function
 def train(train_loader, model, loss_func, optimizer, epoch):
     model.train()
-    total_loss = 0
+    total_loss_sum = 0.0
+    primary_loss_sum = 0.0
+    cycle_loss_sum = 0.0
     batch_count = 0
 
     for data, label in train_loader:
-        data = data.cuda()
-        label = label.cuda()
+        data = data.cuda(non_blocking=True).float()
+        label = label.cuda(non_blocking=True).float()
 
-        output = model(data.type(torch.cuda.FloatTensor))
-        loss = loss_func(output.float(), label.float())
+        # Forward pass with cycle; the primary and cycle losses are both computed by the model
+        pred, recon = model.forward_with_cycle(data)
+        losses = model.compute_loss(data, label, pred, recon)  # returns a dict with {'total', 'primary', 'cycle'}
+        loss = losses['total']
 
-        total_loss += loss.item()
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
         loss.backward()
         optimizer.step()
+
+        total_loss_sum += loss.item()
+        primary_loss_sum += losses['primary'].item()
+        cycle_loss_sum += losses['cycle'].item()
         batch_count += 1
 
-    avg_loss = total_loss / batch_count
-    print(f"Epoch {epoch}: Training Loss = {avg_loss:.6f}")
-    return avg_loss
+    avg_total = total_loss_sum / max(batch_count, 1)
+    avg_primary = primary_loss_sum / max(batch_count, 1)
+    avg_cycle = cycle_loss_sum / max(batch_count, 1)
+
+    # Print the three loss components (total, primary, cycle) to monitor their balance
+    print(f"Epoch {epoch}: Training -> total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
+    return avg_total  # keep compatibility with the original script: return the total loss
+
 
 # Validation function
 def validate(val_loader, model, loss_func):
     model.eval()
-    total_loss = 0
+    total_loss_sum = 0.0
+    primary_loss_sum = 0.0
+    cycle_loss_sum = 0.0
     batch_count = 0
 
     with torch.no_grad():
         for data, label in val_loader:
-            data = data.cuda()
-            label = label.cuda()
+            data = data.cuda(non_blocking=True).float()
+            label = label.cuda(non_blocking=True).float()
 
-            output = model(data.type(torch.cuda.FloatTensor))
-            loss = loss_func(output.float(), label.float())
+            pred, recon = model.forward_with_cycle(data)
+            losses = model.compute_loss(data, label, pred, recon)
 
-            total_loss += loss.item()
+            total_loss_sum += losses['total'].item()
+            primary_loss_sum += losses['primary'].item()
+            cycle_loss_sum += losses['cycle'].item()
             batch_count += 1
 
-    avg_loss = total_loss / batch_count
-    return avg_loss
+    avg_total = total_loss_sum / max(batch_count, 1)
+    avg_primary = primary_loss_sum / max(batch_count, 1)
+    avg_cycle = cycle_loss_sum / max(batch_count, 1)
+
+    # Print validation losses
+    print(f"Validation -> total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
+    return avg_total
+
 
 # Learning rate scheduler
 def adjust_learning_rate(optimizer, epoch, start_lr):
-    """Exponentially decays learning rate based on epoch and configured decay rate."""
+    """Exponentially decay the learning rate based on the epoch and the configured decay rate."""
     lr = start_lr * (cfg.lr_decrease_rate ** epoch)
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
     return lr
+
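
Note: this diff calls `model.forward_with_cycle` and `model.compute_loss` but does not contain the model-side changes, so `loss_func` is now unused in `train`/`validate` and is presumably kept only for call-site compatibility. `compute_loss` must return a dict with the keys `'total'`, `'primary'`, and `'cycle'`, where `'total'` is the tensor that gradients flow through. A minimal sketch of one way such an interface could look; the names `backbone`, `decoder`, `cycle_weight`, and the choice of MSE/L1 criteria are illustrative assumptions, not taken from the actual model:

```python
import torch.nn as nn

class CycleModel(nn.Module):
    """Hypothetical model exposing the interface this diff relies on."""

    def __init__(self, backbone: nn.Module, decoder: nn.Module, cycle_weight: float = 0.1):
        super().__init__()
        self.backbone = backbone          # maps input -> prediction
        self.decoder = decoder            # maps prediction back to input space
        self.cycle_weight = cycle_weight  # weight of the cycle term in the total loss
        self.primary_criterion = nn.MSELoss()
        self.cycle_criterion = nn.L1Loss()

    def forward_with_cycle(self, x):
        pred = self.backbone(x)    # primary prediction
        recon = self.decoder(pred) # cycle reconstruction of the input
        return pred, recon

    def compute_loss(self, data, label, pred, recon):
        primary = self.primary_criterion(pred, label)  # supervised term
        cycle = self.cycle_criterion(recon, data)      # reconstruction term
        total = primary + self.cycle_weight * cycle
        # train()/validate() expect exactly these three keys
        return {'total': total, 'primary': primary, 'cycle': cycle}
```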
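For reference, `adjust_learning_rate` computes `lr = start_lr * lr_decrease_rate ** epoch`, so with `cfg.lr_decrease_rate < 1` the rate shrinks geometrically per epoch and equals `start_lr` at epoch 0. A hypothetical driver loop showing how the three helpers are typically wired together; `cfg.num_epochs` and `cfg.start_lr` are assumed config fields, not confirmed by this diff:

```python
for epoch in range(cfg.num_epochs):
    # lr = start_lr * lr_decrease_rate ** epoch, applied to every param group
    lr = adjust_learning_rate(optimizer, epoch, cfg.start_lr)
    train(train_loader, model, loss_func, optimizer, epoch)
    validate(val_loader, model, loss_func)
```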