你已经派生过 gpr-sidl-inv
镜像自地址
https://gitee.com/sduem/gpr-sidl-inv.git
已同步 2025-08-03 10:56:50 +08:00
63
utils/train_val_lr.py
普通文件
63
utils/train_val_lr.py
普通文件
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from torchsummary import summary
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
# Add parent directory to path for config import
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
from config import Network_train_Config as cfg
|
||||
|
||||
# Training function
|
||||
def train(train_loader, model, loss_func, optimizer, epoch):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
batch_count = 0
|
||||
|
||||
for data, label in train_loader:
|
||||
data = data.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
output = model(data.type(torch.cuda.FloatTensor))
|
||||
loss = loss_func(output.float(), label.float())
|
||||
|
||||
total_loss += loss.item()
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
batch_count += 1
|
||||
|
||||
avg_loss = total_loss / batch_count
|
||||
print(f"Epoch {epoch}: Training Loss = {avg_loss:.6f}")
|
||||
return avg_loss
|
||||
|
||||
# Validation function
|
||||
def validate(val_loader, model, loss_func):
|
||||
model.eval()
|
||||
total_loss = 0
|
||||
batch_count = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for data, label in val_loader:
|
||||
data = data.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
output = model(data.type(torch.cuda.FloatTensor))
|
||||
loss = loss_func(output.float(), label.float())
|
||||
|
||||
total_loss += loss.item()
|
||||
batch_count += 1
|
||||
|
||||
avg_loss = total_loss / batch_count
|
||||
return avg_loss
|
||||
|
||||
# Learning rate scheduler
|
||||
def adjust_learning_rate(optimizer, epoch, start_lr):
|
||||
"""Exponentially decays learning rate based on epoch and configured decay rate."""
|
||||
lr = start_lr * (cfg.lr_decrease_rate ** epoch)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
return lr
|
||||
|
||||
|
||||
|
在新工单中引用
屏蔽一个用户