# Source code for deepextractor.training.trainer

"""
CLI entry point for training DeepExtractor models.

Usage::

    deepextractor-train --model DeepExtractor_257 --data-dir data/pycbc_noise/spectrogram_domain_clean_glitch_129/

"""

import argparse
import logging
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from deepextractor.models.architectures import (
    Autoencoder1D,
    Autoencoder2D,
    DnCNN1D,
    ModifiedAutoencoder2D,
    UNET1D,
    UNET2D,
)
from deepextractor.training.train_fn import train_fn
from deepextractor.utils.checkpoints import load_checkpoint, load_optimizer, save_checkpoint
from deepextractor.utils.io import check_accuracy, get_loaders

# Configure root logging once at import so the CLI prints INFO-level progress.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Registry of available model architectures.
#
# Maps the value accepted by ``--model`` to a zero-argument callable that
# returns a freshly constructed network (``main`` invokes the looked-up entry
# with no arguments).  The 1-D entries are built with one input/output
# channel, the 2-D entries with two; ``DnCNN1D`` is registered as the class
# itself and is therefore constructed with its own defaults.
MODEL_REGISTRY = {
    # --- 1-D architectures ---
    "UNET1D": lambda: UNET1D(in_channels=1, out_channels=1),
    "DnCNN1D": DnCNN1D,
    "Autoencoder1D": lambda: Autoencoder1D(in_channels=1, out_channels=1),
    # --- 2-D architectures ---
    "UNET2D": lambda: UNET2D(in_channels=2, out_channels=2),
    "UNET2D_glitch_target": lambda: UNET2D(in_channels=2, out_channels=2),
    "Autoencoder2D": lambda: Autoencoder2D(in_channels=2, out_channels=2),
    "ModifiedAutoencoder2D": lambda: ModifiedAutoencoder2D(in_channels=2, out_channels=2),
    # --- DeepExtractor variants: all share the same two-channel UNET2D ---
    **{
        f"DeepExtractor_{size}": (lambda: UNET2D(in_channels=2, out_channels=2))
        for size in (65, 129, 257)
    },
}
def _build_arg_parser():
    """Create the argument parser for the ``deepextractor-train`` CLI."""
    parser = argparse.ArgumentParser(
        description="Train a DeepExtractor model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        choices=list(MODEL_REGISTRY.keys()),
        default="DeepExtractor_257",
        help="Model architecture to train.",
    )
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=150)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument(
        "--time-domain",
        action="store_true",
        help="Train on time-domain data instead of spectrograms.",
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        required=True,
        help="Directory containing the training .npy arrays.",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=Path,
        default=Path("checkpoints"),
        help="Directory to save model checkpoints.",
    )
    parser.add_argument(
        "--loss-dir",
        type=Path,
        default=Path("losses"),
        help="Directory to save loss arrays.",
    )
    parser.add_argument(
        "--transfer-learn",
        action="store_true",
        help="Resume from an existing checkpoint (transfer learning).",
    )
    parser.add_argument(
        "--bilby-noise",
        action="store_true",
        help="Use bilby noise suffix in checkpoint filenames.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device to use, e.g. 'cuda:0' or 'cpu'. Auto-detected if not set.",
    )
    parser.add_argument(
        "--early-stopping-patience",
        type=int,
        default=9,
        help="Number of epochs without improvement before stopping.",
    )
    return parser


def _select_device(requested):
    """Return ``requested`` if given, else the best available torch device string.

    Auto-detection prefers the highest-indexed CUDA device, then MPS, then CPU.
    """
    if requested is not None:
        return requested
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.device_count() - 1}"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def main():
    """Parse CLI arguments and train the selected model.

    Side effects visible in this function:

    * creates the checkpoint and loss directories if they do not exist;
    * with ``--transfer-learn``, loads ``checkpoint_best.pth.tar`` from the
      model's checkpoint directory (logs an error and returns if that fails);
    * writes ``checkpoint_best_<noise>_<mode>.pth.tar`` whenever the
      validation loss improves;
    * writes the accumulated loss arrays every tenth epoch index and once
      more when training ends.

    Training stops early after ``--early-stopping-patience`` consecutive
    epochs without a validation-loss improvement.
    """
    args = _build_arg_parser().parse_args()

    # --- Device ---
    device = _select_device(args.device)

    # --- Model ---
    model_name = args.model
    model = MODEL_REGISTRY[model_name]().to(device)
    logger.info(f"Training {model_name} on {device}")

    # --- Directories ---
    noise_ext = "bilby_noise" if args.bilby_noise else "pycbc_noise"
    tl_ext = "transfer_learn" if args.transfer_learn else "base"
    model_checkpoint_dir = args.checkpoint_dir / f"{model_name}_checkpoints"
    model_loss_dir = args.loss_dir / f"{model_name}_{noise_ext}_{tl_ext}_losses"
    for d in [args.checkpoint_dir, model_checkpoint_dir, args.loss_dir, model_loss_dir]:
        os.makedirs(d, exist_ok=True)

    # --- Data loaders ---
    data_dir = args.data_dir
    if args.time_domain:
        train_loader, val_loader = get_loaders(
            str(data_dir / "glitch_train_scaled.npy"),
            str(data_dir / "background_train_scaled.npy"),
            str(data_dir / "glitch_val_scaled.npy"),
            str(data_dir / "background_val_scaled.npy"),
            args.batch_size,
            None,
            None,
            args.num_workers,
            True,
            time_domain=True,
        )
    else:
        train_loader, val_loader = get_loaders(
            str(data_dir / "glitch_train_scaled_mag_phase.npy"),
            str(data_dir / "background_train_scaled_mag_phase.npy"),
            str(data_dir / "glitch_val_scaled_mag_phase.npy"),
            str(data_dir / "background_val_scaled_mag_phase.npy"),
            batch_size=args.batch_size,
            train_transform=None,
            val_transform=None,
            num_workers=args.num_workers,
            pin_memory=True,
            time_domain=False,
        )

    # --- Optimizer and scheduler ---
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=4)

    start_epoch = 0
    if args.transfer_learn:
        checkpoint_path = model_checkpoint_dir / "checkpoint_best.pth.tar"
        try:
            logger.info("Loading model checkpoint...")
            # map_location lets a checkpoint saved on another device (e.g. a
            # different GPU index) be restored onto the device selected above.
            checkpoint = torch.load(str(checkpoint_path), map_location=device)
            load_checkpoint(checkpoint, model)
            load_optimizer(checkpoint, optimizer)
            start_epoch = checkpoint.get("epoch", start_epoch)
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return

    # --- Training loop ---
    train_losses, train_noise_losses, train_constraint_losses = [], [], []
    val_losses, val_noise_losses, val_constraint_losses = [], [], []

    scaler_device = "cuda" if str(device).startswith("cuda") else "cpu"
    scaler = torch.amp.GradScaler(scaler_device)

    early_stopping_counter = 0
    best_val_loss = float("inf")
    # Epoch numbers continue from the restored checkpoint, so the progress
    # denominator must be offset by start_epoch as well.
    final_epoch = start_epoch + args.epochs

    for epoch in range(start_epoch, final_epoch):
        logger.info(f"Epoch {epoch + 1}/{final_epoch}")

        train_loss, train_noise_loss, train_constraint_loss = train_fn(
            train_loader, model, model_name, optimizer, loss_fn, scaler, device
        )
        train_losses.append(train_loss)
        train_noise_losses.append(train_noise_loss)
        train_constraint_losses.append(train_constraint_loss)

        val_loss, val_noise_loss, val_constraint_loss = check_accuracy(
            val_loader, model, model_name, device=device
        )
        val_losses.append(val_loss)
        val_noise_losses.append(val_noise_loss)
        val_constraint_losses.append(val_constraint_loss)

        scheduler.step(val_loss)
        current_lr = scheduler.optimizer.param_groups[0]["lr"]
        logger.info(f"Current learning rate: {current_lr}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stopping_counter = 0
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
            }
            ckpt_filename = (
                model_checkpoint_dir / f"checkpoint_best_{noise_ext}_{tl_ext}.pth.tar"
            )
            save_checkpoint(checkpoint, str(ckpt_filename))
            logger.info(f"Validation loss improved to {best_val_loss:.6}. Checkpoint saved.")
        else:
            early_stopping_counter += 1
            logger.info(
                f"No improvement. Early stopping counter: "
                f"{early_stopping_counter}/{args.early_stopping_patience}"
            )
            if early_stopping_counter >= args.early_stopping_patience:
                logger.info(f"Early stopping after {epoch + 1} epochs.")
                break

        # Periodic snapshot so losses survive an interrupted run.
        if epoch % 10 == 0:
            _save_losses(
                model_loss_dir,
                start_epoch,
                epoch,
                train_losses,
                train_noise_losses,
                train_constraint_losses,
                val_losses,
                val_noise_losses,
                val_constraint_losses,
            )

    if not train_losses:
        # The loop body never ran (e.g. --epochs 0), so `epoch` is unbound and
        # there is nothing to write.
        logger.warning("No epochs were run; no losses saved.")
        return

    _save_losses(
        model_loss_dir,
        start_epoch,
        epoch,
        train_losses,
        train_noise_losses,
        train_constraint_losses,
        val_losses,
        val_noise_losses,
        val_constraint_losses,
    )
    logger.info("Training complete. Losses saved.")
def _save_losses(
    loss_dir,
    start_epoch,
    epoch,
    train_losses,
    train_noise_losses,
    train_constraint_losses,
    val_losses,
    val_noise_losses,
    val_constraint_losses,
):
    """Write each loss history to its own ``.npy`` file inside ``loss_dir``.

    Every file name ends in ``<start_epoch>_to_<epoch>.npy``.  ``loss_dir``
    must support the ``/`` operator (e.g. a ``pathlib.Path``); each history is
    converted with ``np.array`` before being saved.
    """
    epoch_span = f"{start_epoch}_to_{epoch}"

    # (file-name stem, values) pairs, in the order they are written.
    # NOTE: only the first stem carries the extra "_epoch" part.
    histories = (
        ("train_losses_epoch", train_losses),
        ("train_noise_losses", train_noise_losses),
        ("train_constraint_losses", train_constraint_losses),
        ("val_losses", val_losses),
        ("val_noise_losses", val_noise_losses),
        ("val_constraint_losses", val_constraint_losses),
    )
    for stem, values in histories:
        target = loss_dir / f"{stem}_{epoch_span}.npy"
        np.save(str(target), np.array(values))


if __name__ == "__main__":
    main()