Mirrored from https://gitee.com/sduem/gpr-sidl-inv.git
Synced 2025-09-19 01:03:50 +08:00
@@ -8,56 +8,80 @@ from torchvision.utils import make_grid, save_image
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 from config import Network_train_Config as cfg


 # Training function
 def train(train_loader, model, loss_func, optimizer, epoch):
     model.train()
-    total_loss = 0
+    total_loss_sum = 0.0
+    primary_loss_sum = 0.0
+    cycle_loss_sum = 0.0
     batch_count = 0

     for data, label in train_loader:
-        data = data.cuda()
-        label = label.cuda()
+        data = data.cuda(non_blocking=True).float()
+        label = label.cuda(non_blocking=True).float()

-        output = model(data.type(torch.cuda.FloatTensor))
-        loss = loss_func(output.float(), label.float())
+        # Forward pass with cycle; both primary loss and cycle loss are handled in the model
+        pred, recon = model.forward_with_cycle(data)
+        losses = model.compute_loss(data, label, pred, recon)  # returns dict with {'total','primary','cycle'}
+        loss = losses['total']

-        total_loss += loss.item()
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
         loss.backward()
         optimizer.step()

+        total_loss_sum += loss.item()
+        primary_loss_sum += losses['primary'].item()
+        cycle_loss_sum += losses['cycle'].item()
         batch_count += 1

-    avg_loss = total_loss / batch_count
-    print(f"Epoch {epoch}: Training Loss = {avg_loss:.6f}")
-    return avg_loss
+    avg_total = total_loss_sum / max(batch_count, 1)
+    avg_primary = primary_loss_sum / max(batch_count, 1)
+    avg_cycle = cycle_loss_sum / max(batch_count, 1)
+
+    # Print the three loss components (total, primary, cycle) for monitoring balance
+    print(f"Epoch {epoch}: Training -> total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
+    return avg_total  # Keep compatibility with original script: return total loss


 # Validation function
 def validate(val_loader, model, loss_func):
     model.eval()
-    total_loss = 0
+    total_loss_sum = 0.0
+    primary_loss_sum = 0.0
+    cycle_loss_sum = 0.0
     batch_count = 0

     with torch.no_grad():
         for data, label in val_loader:
-            data = data.cuda()
-            label = label.cuda()
+            data = data.cuda(non_blocking=True).float()
+            label = label.cuda(non_blocking=True).float()

-            output = model(data.type(torch.cuda.FloatTensor))
-            loss = loss_func(output.float(), label.float())
+            pred, recon = model.forward_with_cycle(data)
+            losses = model.compute_loss(data, label, pred, recon)

-            total_loss += loss.item()
+            total_loss_sum += losses['total'].item()
+            primary_loss_sum += losses['primary'].item()
+            cycle_loss_sum += losses['cycle'].item()
             batch_count += 1

-    avg_loss = total_loss / batch_count
-    return avg_loss
+    avg_total = total_loss_sum / max(batch_count, 1)
+    avg_primary = primary_loss_sum / max(batch_count, 1)
+    avg_cycle = cycle_loss_sum / max(batch_count, 1)
+
+    # Print validation losses
+    print(f"Validation : total={avg_total:.6f} | primary={avg_primary:.6f} | cycle={avg_cycle:.6f}")
+    return avg_total


 # Learning rate scheduler
 def adjust_learning_rate(optimizer, epoch, start_lr):
-    """Exponentially decays learning rate based on epoch and configured decay rate."""
-    lr = start_lr * (cfg.lr_decrease_rate ** epoch)
+    """Exponentially decay learning rate based on epoch and configured decay rate."""
+    lr = start_lr * cfg.lr_decrease_rate
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
     return lr
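
Note: the updated hunk relies on two model-side methods that are not part of this diff: forward_with_cycle, which returns the primary prediction together with a cycle reconstruction of the input, and compute_loss, which returns a dict with 'total', 'primary' and 'cycle' entries. A minimal sketch of that interface, assuming a hypothetical CycleWrapper module with an MSE primary term and an L1 cycle term weighted by a made-up cycle_weight coefficient (none of these internals are taken from the repository):

import torch.nn as nn
import torch.nn.functional as F

class CycleWrapper(nn.Module):
    """Hypothetical wrapper illustrating the interface used by train()/validate()."""

    def __init__(self, forward_net: nn.Module, inverse_net: nn.Module, cycle_weight: float = 0.1):
        super().__init__()
        self.forward_net = forward_net    # data -> prediction
        self.inverse_net = inverse_net    # prediction -> reconstructed data
        self.cycle_weight = cycle_weight  # assumed weighting of the cycle term

    def forward_with_cycle(self, data):
        pred = self.forward_net(data)     # primary prediction
        recon = self.inverse_net(pred)    # cycle reconstruction of the input
        return pred, recon

    def compute_loss(self, data, label, pred, recon):
        primary = F.mse_loss(pred, label)             # supervised term
        cycle = F.l1_loss(recon, data)                # data-consistency term
        total = primary + self.cycle_weight * cycle
        return {'total': total, 'primary': primary, 'cycle': cycle}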
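
For context, a hypothetical driver loop showing how train, validate and adjust_learning_rate from this hunk fit together. It reuses the CycleWrapper sketch above; train_loader, val_loader, forward_net, inverse_net, cfg.start_lr and cfg.num_epochs are placeholders, not names taken from this diff:

import torch
import torch.nn as nn

# Placeholder construction; the real project builds these from its own config/dataset code.
model = CycleWrapper(forward_net, inverse_net, cycle_weight=0.1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.start_lr)   # cfg.start_lr is assumed
loss_func = nn.MSELoss()  # kept only to match the train()/validate() signatures

for epoch in range(cfg.num_epochs):                                 # cfg.num_epochs is assumed
    lr = adjust_learning_rate(optimizer, epoch, cfg.start_lr)
    train_loss = train(train_loader, model, loss_func, optimizer, epoch)
    val_loss = validate(val_loader, model, loss_func)
    print(f"Epoch {epoch}: lr={lr:.2e} | train={train_loss:.6f} | val={val_loss:.6f}")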