gpr-sidl-inv/utils/train_val_lr.py
葛峻恺 8f4f8347de program
Signed-off-by: 葛峻恺 <202115006@mail.sdu.edu.cn>
2025-04-07 12:18:37 +00:00


import os
import sys
import torch
from torchsummary import summary
from torchvision.utils import make_grid, save_image
# Add parent directory to path for config import
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config import Network_train_Config as cfg
# Training function
def train(train_loader, model, loss_func, optimizer, epoch):
    model.train()
    total_loss = 0
    batch_count = 0
    for data, label in train_loader:
        data = data.cuda()
        label = label.cuda()
        output = model(data.type(torch.cuda.FloatTensor))
        loss = loss_func(output.float(), label.float())
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_count += 1
    avg_loss = total_loss / batch_count
    print(f"Epoch {epoch}: Training Loss = {avg_loss:.6f}")
    return avg_loss

# Validation function
def validate(val_loader, model, loss_func):
    model.eval()
    total_loss = 0
    batch_count = 0
    with torch.no_grad():
        for data, label in val_loader:
            data = data.cuda()
            label = label.cuda()
            output = model(data.type(torch.cuda.FloatTensor))
            loss = loss_func(output.float(), label.float())
            total_loss += loss.item()
            batch_count += 1
    avg_loss = total_loss / batch_count
    return avg_loss

# Learning rate scheduler
def adjust_learning_rate(optimizer, epoch, start_lr):
    """Exponentially decays learning rate based on epoch and configured decay rate."""
    lr = start_lr * (cfg.lr_decrease_rate ** epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
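
A minimal usage sketch of how these helpers fit together in a training script; the model, datasets, loss, and hyperparameter values below are hypothetical placeholders for illustration, not names from this repository:

if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical stand-ins for the project's real model and data loaders
    model = nn.Sequential(nn.Flatten(), nn.Linear(64, 1)).cuda()
    train_set = TensorDataset(torch.randn(32, 1, 8, 8), torch.randn(32, 1))
    val_set = TensorDataset(torch.randn(8, 1, 8, 8), torch.randn(8, 1))
    train_loader = DataLoader(train_set, batch_size=8)
    val_loader = DataLoader(val_set, batch_size=8)

    loss_func = nn.MSELoss()
    start_lr = 1e-3  # assumed starting learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=start_lr)

    for epoch in range(5):
        # Decay the learning rate: lr = start_lr * cfg.lr_decrease_rate ** epoch
        adjust_learning_rate(optimizer, epoch, start_lr)
        train(train_loader, model, loss_func, optimizer, epoch)
        val_loss = validate(val_loader, model, loss_func)
        print(f"Epoch {epoch}: Validation Loss = {val_loss:.6f}")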