deepextractor.training.train_fn

Module Contents

deepextractor.training.train_fn.train_fn(loader, model, model_name, optimizer, loss_fn, scaler, device)

Train the model for one epoch and return average losses.
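A one-epoch loop of this kind usually follows the standard PyTorch mixed-precision pattern: run the forward pass under autocast, backpropagate through the `scaler`, then average the per-batch losses. The sketch below is a generic illustration, not the DeepExtractor implementation. It assumes `loader` yields `(data, target)` pairs and that `loss_fn` returns a scalar tensor. It ignores `model_name`, whose role is not documented here, and it ties autocast to whether the scaler is enabled (a design choice of this sketch). The name `train_one_epoch` is hypothetical.

```python
import torch
from torch import nn


def train_one_epoch(loader, model, optimizer, loss_fn, scaler, device):
    """Hypothetical generic epoch loop; returns the mean batch loss."""
    device = torch.device(device)
    model.train()
    total, n_batches = 0.0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)
        # Autocast is active only when the scaler is enabled (i.e. AMP is on).
        with torch.autocast(device_type=device.type, enabled=scaler.is_enabled()):
            loss = loss_fn(model(data), target)
        scaler.scale(loss).backward()  # scaled backward (pass-through if disabled)
        scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
        scaler.update()                # adjusts the scale factor for the next batch
        total += loss.item()
        n_batches += 1
    return total / max(n_batches, 1)
```

A disabled `GradScaler` makes `scale()`, `step()` and `update()` behave like a plain `loss.backward(); optimizer.step()`, so the same loop serves both full- and mixed-precision runs.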

deepextractor.training.train_fn.train_fn_td(loader, model, optimizer, loss_fn, scaler, device, *, residual_channels: bool = False, residual_weight: float = 1.0, residual_mode: str = 'true', use_amp: bool = False)

Train the two-detector time-domain separation model for one epoch.

Expects the DataLoader to yield:

  • data – (B, 2, T) H1+L1 strain (standard-scaled inputs)

  • targets – (B, 4, T) [bg_H1, bg_L1, sig_H1, sig_L1] (whitened)
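A DataLoader with this batch layout can be mocked with random tensors, which is handy for checking shape handling before plugging in real data. This sketch assumes the target channels are ordered exactly as listed, so background and signal split at channel 2; `N` and `T` are hypothetical sizes, and the random values merely stand in for scaled strain.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

N, T = 32, 1024  # hypothetical number of segments and samples per segment
dataset = TensorDataset(
    torch.randn(N, 2, T),  # data: [H1, L1] strain
    torch.randn(N, 4, T),  # targets: [bg_H1, bg_L1, sig_H1, sig_L1]
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

data, targets = next(iter(loader))
tgt_bg, tgt_sig = targets[:, :2], targets[:, 2:]  # split at channel 2
print(tuple(data.shape), tuple(tgt_bg.shape), tuple(tgt_sig.shape))
# (8, 2, 1024) (8, 2, 1024) (8, 2, 1024)
```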

Parameters:
  • residual_channels – If True, the model outputs 6 channels; the extra 2 channels are a residual term that enforces input reconstruction.

  • residual_weight – Weight applied to the residual loss term.

  • residual_mode – How the residual loss is computed:

      • “true”: residual target = data - (tgt_bg + tgt_sig)

      • “sum”: loss on (pred_bg + pred_sig + pred_res) vs. data

      • “sum_detach”: same as “sum”, but the bg and sig gradients are detached

  • use_amp – Enable mixed-precision autocast. Disabled by default because training runs on Snellius found AMP unstable with whitened targets.
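The three residual_mode options can be expressed as a small loss helper. This is a sketch, not the library code: it assumes the 6-channel output is ordered `[bg_H1, bg_L1, sig_H1, sig_L1, res_H1, res_L1]` and that `loss_fn` is a mean-reduced criterion such as `nn.MSELoss()`. The name `residual_loss` is hypothetical.

```python
import torch
from torch import nn


def residual_loss(pred, data, targets, loss_fn, mode="true"):
    """Hypothetical residual term for a (B, 6, T) prediction.

    pred:    (B, 6, T), assumed order [bg(2), sig(2), res(2)]
    data:    (B, 2, T) H1+L1 input
    targets: (B, 4, T) [bg_H1, bg_L1, sig_H1, sig_L1]
    """
    pred_bg, pred_sig, pred_res = pred[:, 0:2], pred[:, 2:4], pred[:, 4:6]
    tgt_bg, tgt_sig = targets[:, 0:2], targets[:, 2:4]
    if mode == "true":
        # Residual head learns whatever the ground-truth components leave unexplained.
        return loss_fn(pred_res, data - (tgt_bg + tgt_sig))
    if mode == "sum":
        # All three heads together must reconstruct the input.
        return loss_fn(pred_bg + pred_sig + pred_res, data)
    if mode == "sum_detach":
        # Same reconstruction, but only the residual head gets gradient from it.
        return loss_fn((pred_bg + pred_sig).detach() + pred_res, data)
    raise ValueError(f"unknown residual_mode: {mode!r}")
```

Presumably the returned value is multiplied by residual_weight before being added to the background and signal losses; that combination is an assumption, not something stated above.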

Returns:

  • (avg_total, avg_bg, avg_sig) if residual_channels=False

  • (avg_total, avg_bg, avg_sig, avg_res) if residual_channels=True
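An epoch of the two-detector loop might be structured roughly as below. Everything here is an assumption for illustration: the total loss is taken to be `loss_bg + loss_sig` (plus `residual_weight * loss_res` when `residual_channels=True`), only the “true” residual mode is implemented, `use_amp` only toggles `torch.autocast`, and the averages are plain means over batches. `train_td_sketch` is a hypothetical name, not the library function.

```python
import torch
from torch import nn


def train_td_sketch(loader, model, optimizer, loss_fn, scaler, device, *,
                    residual_channels=False, residual_weight=1.0, use_amp=False):
    """Hypothetical one-epoch loop; returns averaged loss components."""
    device = torch.device(device)
    model.train()
    sums, n = [0.0, 0.0, 0.0, 0.0], 0  # total, bg, sig, res
    for data, targets in loader:
        data, targets = data.to(device), targets.to(device)
        tgt_bg, tgt_sig = targets[:, :2], targets[:, 2:4]
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            pred = model(data)  # (B, 4, T), or (B, 6, T) with residual channels
            loss_bg = loss_fn(pred[:, :2], tgt_bg)
            loss_sig = loss_fn(pred[:, 2:4], tgt_sig)
            loss = loss_bg + loss_sig
            loss_res = torch.zeros((), device=device)
            if residual_channels:
                # "true" mode: residual head targets what bg + sig leave behind.
                loss_res = loss_fn(pred[:, 4:6], data - (tgt_bg + tgt_sig))
                loss = loss + residual_weight * loss_res
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        for i, term in enumerate((loss, loss_bg, loss_sig, loss_res)):
            sums[i] += term.item()
        n += 1
    avgs = tuple(s / max(n, 1) for s in sums)
    return avgs if residual_channels else avgs[:3]
```

Returning a 3-tuple or 4-tuple from the same accumulator keeps the bookkeeping identical in both configurations; only the final slice differs.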

deepextractor.training.train_fn.eval_fn_td(loader, model, device, *, residual_channels: bool = False, residual_weight: float = 1.0) → tuple

Evaluate the two-detector time-domain separation model on a validation set.

Uses the same residual_channels and residual_weight options and the same return format as train_fn_td(), but runs under torch.no_grad() and does not update weights. The model is restored to train mode afterwards.

Returns:

  • (avg_total, avg_bg, avg_sig) if residual_channels=False

  • (avg_total, avg_bg, avg_sig, avg_res) if residual_channels=True
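A validation pass with the documented behaviour (no gradients, no weight updates, train mode restored afterwards) might look like the sketch below. Because eval_fn_td takes no loss_fn argument, this sketch hard-codes mean-squared error as an assumption, and it implements only the “true” residual target. `eval_td_sketch` is hypothetical. The `try`/`finally` restores train mode even if a batch raises.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def eval_td_sketch(loader, model, device, *, residual_channels=False, residual_weight=1.0):
    """Hypothetical validation pass; returns averaged loss components."""
    device = torch.device(device)
    model.eval()  # disable dropout, use running batch-norm statistics
    sums, n = [0.0, 0.0, 0.0, 0.0], 0  # total, bg, sig, res
    try:
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            tgt_bg, tgt_sig = targets[:, :2], targets[:, 2:4]
            pred = model(data)
            loss_bg = F.mse_loss(pred[:, :2], tgt_bg)  # assumed criterion
            loss_sig = F.mse_loss(pred[:, 2:4], tgt_sig)
            loss_res = torch.zeros((), device=device)
            if residual_channels:
                loss_res = F.mse_loss(pred[:, 4:6], data - (tgt_bg + tgt_sig))
            total = loss_bg + loss_sig + residual_weight * loss_res
            for i, term in enumerate((total, loss_bg, loss_sig, loss_res)):
                sums[i] += term.item()
            n += 1
    finally:
        model.train()  # documented behaviour: back to train mode afterwards
    avgs = tuple(s / max(n, 1) for s in sums)
    return avgs if residual_channels else avgs[:3]
```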