文件
gpr-sidl-inv/utils/train_val_lr.py
葛峻恺 f99ce1d024 update utils/train_val_lr.py.
Signed-off-by: 葛峻恺 <202115006@mail.sdu.edu.cn>
2025-09-02 12:06:37 +00:00

88 行
3.0 KiB
Python

import os
import sys
import torch
from torchsummary import summary
from torchvision.utils import make_grid, save_image
# Add parent directory to path for config import
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config import Network_train_Config as cfg
# Training function
def train(train_loader, model, loss_func, optimizer, epoch):
    """Run one training epoch and return the mean total loss per batch.

    Args:
        train_loader: Iterable yielding ``(data, label)`` tensor batches.
        model: Network exposing ``forward_with_cycle(data)`` (returns
            ``(pred, recon)``) and ``compute_loss(data, label, pred, recon)``
            (returns a dict with ``'total'``, ``'primary'`` and ``'cycle'``).
        loss_func: Not used here; the model computes its own losses. Kept in
            the signature so existing callers keep working.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per
            batch.
        epoch: Epoch index, used only in the printed progress line.

    Returns:
        The ``'total'`` loss averaged over batches (``0.0`` if the loader
        yields nothing).
    """
    model.train()

    # Place batches wherever the model's parameters live instead of forcing
    # CUDA, so the loop also runs on CPU-only hosts. If the model exposes no
    # parameters, fall back to the current CUDA device (the old behaviour).
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    running = {'total': 0.0, 'primary': 0.0, 'cycle': 0.0}
    batch_count = 0

    for data, label in train_loader:
        data = data.to(device, non_blocking=True).float()
        label = label.to(device, non_blocking=True).float()

        # Forward pass with cycle; the model itself combines the primary and
        # cycle terms into the 'total' entry that is back-propagated.
        pred, recon = model.forward_with_cycle(data)
        losses = model.compute_loss(data, label, pred, recon)

        optimizer.zero_grad(set_to_none=True)
        losses['total'].backward()
        optimizer.step()

        for key in running:
            running[key] += losses[key].item()
        batch_count += 1

    # max(..., 1) guards against division by zero for an empty loader.
    denom = max(batch_count, 1)
    avg_total = running['total'] / denom
    avg_primary = running['primary'] / denom
    avg_cycle = running['cycle'] / denom

    # Print the three loss components (total, primary, cycle) for monitoring balance
    print(f"Epoch {epoch}: Training -> total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
    return avg_total  # Keep compatibility with original script: return total loss
# Validation function
def validate(val_loader, model, loss_func):
    """Evaluate the model on a validation loader and return the mean total loss.

    The model is switched to eval mode (and left there) and the whole pass
    runs under ``torch.no_grad()``.

    Args:
        val_loader: Iterable yielding ``(data, label)`` tensor batches.
        model: Network exposing ``forward_with_cycle(data)`` (returns
            ``(pred, recon)``) and ``compute_loss(data, label, pred, recon)``
            (returns a dict with ``'total'``, ``'primary'`` and ``'cycle'``).
        loss_func: Not used here; the model computes its own losses. Kept in
            the signature so existing callers keep working.

    Returns:
        The ``'total'`` loss averaged over batches (``0.0`` if the loader
        yields nothing).
    """
    model.eval()

    # Place batches wherever the model's parameters live instead of forcing
    # CUDA, so validation also runs on CPU-only hosts. If the model exposes no
    # parameters, fall back to the current CUDA device (the old behaviour).
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    running = {'total': 0.0, 'primary': 0.0, 'cycle': 0.0}
    batch_count = 0

    with torch.no_grad():
        for data, label in val_loader:
            data = data.to(device, non_blocking=True).float()
            label = label.to(device, non_blocking=True).float()

            pred, recon = model.forward_with_cycle(data)
            losses = model.compute_loss(data, label, pred, recon)

            for key in running:
                running[key] += losses[key].item()
            batch_count += 1

    # max(..., 1) guards against division by zero for an empty loader.
    denom = max(batch_count, 1)
    avg_total = running['total'] / denom
    avg_primary = running['primary'] / denom
    avg_cycle = running['cycle'] / denom

    # Print validation losses
    print(f"Validation : total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
    return avg_total
# Learning rate scheduler
def adjust_learning_rate(optimizer, epoch, start_lr):
    """Scale ``start_lr`` by ``cfg.lr_decrease_rate`` and apply it to the optimizer.

    The product ``start_lr * cfg.lr_decrease_rate`` is written to the ``'lr'``
    entry of every parameter group and also returned.

    ``epoch`` is accepted but not read by this function, so a single call
    performs exactly one multiplicative step.
    NOTE(review): presumably the caller feeds the returned value back in as
    ``start_lr`` on the next epoch to obtain an exponential schedule - confirm
    against the training script.
    """
    new_lr = start_lr * cfg.lr_decrease_rate
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr