From f99ce1d024e8050ba6e274504f77c21c7136a178 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=91=9B=E5=B3=BB=E6=81=BA?= <202115006@mail.sdu.edu.cn>
Date: Tue, 2 Sep 2025 12:06:37 +0000
Subject: [PATCH] update utils/train_val_lr.py.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 葛峻恺 <202115006@mail.sdu.edu.cn>
---
 utils/train_val_lr.py | 64 +++++++++++++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 20 deletions(-)

diff --git a/utils/train_val_lr.py b/utils/train_val_lr.py
index c7ec6dc..3f7ead1 100644
--- a/utils/train_val_lr.py
+++ b/utils/train_val_lr.py
@@ -8,56 +8,80 @@ from torchvision.utils import make_grid, save_image
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 from config import Network_train_Config as cfg
 
+
 # Training function
 def train(train_loader, model, loss_func, optimizer, epoch):
     model.train()
 
-    total_loss = 0
+    total_loss_sum = 0.0
+    primary_loss_sum = 0.0
+    cycle_loss_sum = 0.0
     batch_count = 0
 
     for data, label in train_loader:
-        data = data.cuda()
-        label = label.cuda()
+        data = data.cuda(non_blocking=True).float()
+        label = label.cuda(non_blocking=True).float()
 
-        output = model(data.type(torch.cuda.FloatTensor))
-        loss = loss_func(output.float(), label.float())
+        # Forward pass with cycle; both primary loss and cycle loss are handled in the model
+        pred, recon = model.forward_with_cycle(data)
+        losses = model.compute_loss(data, label, pred, recon)  # returns dict with {'total','primary','cycle'}
+        loss = losses['total']
 
-        total_loss += loss.item()
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
         loss.backward()
         optimizer.step()
+
+        total_loss_sum += loss.item()
+        primary_loss_sum += losses['primary'].item()
+        cycle_loss_sum += losses['cycle'].item()
         batch_count += 1
 
-    avg_loss = total_loss / batch_count
-    print(f"Epoch {epoch}: Training Loss = {avg_loss:.6f}")
-    return avg_loss
+    avg_total = total_loss_sum / max(batch_count, 1)
+    avg_primary = primary_loss_sum / max(batch_count, 1)
+    avg_cycle = cycle_loss_sum / max(batch_count, 1)
+
+    # Print the three loss components (total, primary, cycle) for monitoring balance
+    print(f"Epoch {epoch}: Training -> total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
+    return avg_total  # Keep compatibility with original script: return total loss
 
+
 # Validation function
 def validate(val_loader, model, loss_func):
     model.eval()
 
-    total_loss = 0
+    total_loss_sum = 0.0
+    primary_loss_sum = 0.0
+    cycle_loss_sum = 0.0
     batch_count = 0
 
     with torch.no_grad():
         for data, label in val_loader:
-            data = data.cuda()
-            label = label.cuda()
+            data = data.cuda(non_blocking=True).float()
+            label = label.cuda(non_blocking=True).float()
 
-            output = model(data.type(torch.cuda.FloatTensor))
-            loss = loss_func(output.float(), label.float())
+            pred, recon = model.forward_with_cycle(data)
+            losses = model.compute_loss(data, label, pred, recon)
 
-            total_loss += loss.item()
+            total_loss_sum += losses['total'].item()
+            primary_loss_sum += losses['primary'].item()
+            cycle_loss_sum += losses['cycle'].item()
             batch_count += 1
 
-    avg_loss = total_loss / batch_count
-    return avg_loss
+    avg_total = total_loss_sum / max(batch_count, 1)
+    avg_primary = primary_loss_sum / max(batch_count, 1)
+    avg_cycle = cycle_loss_sum / max(batch_count, 1)
+
+    # Print validation losses
+    print(f"Validation : total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
+    return avg_total
 
+
 # Learning rate scheduler
 def adjust_learning_rate(optimizer, epoch, start_lr):
-    """Exponentially decays learning rate based on epoch and configured decay rate."""
-    lr = start_lr * (cfg.lr_decrease_rate ** epoch)
+    """Exponentially decay the learning rate based on epoch and the configured decay rate."""
+    lr = start_lr * (cfg.lr_decrease_rate ** epoch)  # exponential decay: lr_e = lr_0 * rate**epoch
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
     return lr
 
+
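Reviewer note (outside the patch): the updated train() and validate() loops
depend on two model methods this diff does not show,
model.forward_with_cycle(data) -> (pred, recon) and
model.compute_loss(data, label, pred, recon) -> dict with keys
'total', 'primary', 'cycle'. Below is a minimal sketch of one way that
contract could be implemented, assuming a simple encoder/decoder pair and a
hypothetical cycle_weight balancing term; the MSE losses and the weighting
are illustrative assumptions, not the repository's actual code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CycleModel(nn.Module):
        """Sketch of the model-side API the updated loops assume."""

        def __init__(self, encoder: nn.Module, decoder: nn.Module, cycle_weight: float = 0.1):
            super().__init__()
            self.encoder = encoder          # input -> prediction
            self.decoder = decoder          # prediction -> reconstructed input (the cycle)
            self.cycle_weight = cycle_weight

        def forward_with_cycle(self, data: torch.Tensor):
            pred = self.encoder(data)       # primary prediction
            recon = self.decoder(pred)      # cycle reconstruction of the input
            return pred, recon

        def compute_loss(self, data, label, pred, recon):
            primary = F.mse_loss(pred, label)   # supervised term against the label
            cycle = F.mse_loss(recon, data)     # cycle-consistency term against the input
            total = primary + self.cycle_weight * cycle
            return {'total': total, 'primary': primary, 'cycle': cycle}

    # Example wiring (layer shapes illustrative):
    # model = CycleModel(nn.Linear(32, 16), nn.Linear(16, 32), cycle_weight=0.1).cuda()

Returning the per-component tensors in a single dict is what lets the loops
log the primary and cycle averages without a second forward pass: only
losses['total'] is passed to backward(), while .item() converts each
component to a plain float for accumulation.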