Source code for deepextractor.training.train_fn

import torch
import torch.nn as nn
from tqdm import tqdm

from deepextractor.utils.checkpoints import load_checkpoint, load_optimizer, save_checkpoint
from deepextractor.utils.io import check_accuracy, get_loaders
from deepextractor.utils.visualization import save_predictions_as_plots


[docs] def train_fn(loader, model, model_name, optimizer, loss_fn, scaler, device): """Train the model for one epoch and return average losses.""" loop = tqdm(loader, desc="Training on batch") epoch_loss = 0 epoch_noise_loss = 0 epoch_constraint_loss = 0 for batch_idx, (data, targets) in enumerate(loop): data = data.to(device=device) targets = targets.float().to(device=device) autocast_device = "cuda" if device.startswith("cuda") else "cpu" with torch.amp.autocast(autocast_device): predictions = model(data) if model_name == "UNET1D_diff": noise_pred = predictions[:, 0:1, :] residual_pred = predictions[:, 1:2, :] reconstructed = noise_pred + residual_pred constraint_loss = loss_fn(reconstructed, data) noise_loss = loss_fn(noise_pred, targets) loss = constraint_loss + noise_loss epoch_noise_loss += noise_loss.item() epoch_constraint_loss += constraint_loss.item() else: loss = loss_fn(predictions, targets) optimizer.zero_grad() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() epoch_loss += loss.item() if model_name == "UNET1D_diff": loop.set_postfix( total_loss=loss.item(), constraint_loss=constraint_loss.item(), noise_loss=noise_loss.item(), ) else: loop.set_postfix(loss=loss.item()) avg_loss = epoch_loss / len(loader) avg_noise_loss = epoch_noise_loss / len(loader) if model_name == "UNET1D_diff" else 0 avg_constraint_loss = ( epoch_constraint_loss / len(loader) if model_name == "UNET1D_diff" else 0 ) return avg_loss, avg_noise_loss, avg_constraint_loss