deepextractor.training.train_fn
Module Contents
- deepextractor.training.train_fn.train_fn(loader, model, model_name, optimizer, loss_fn, scaler, device)
Train the model for one epoch and return average losses.
- deepextractor.training.train_fn.train_fn_td(loader, model, optimizer, loss_fn, scaler, device, *, residual_channels: bool = False, residual_weight: float = 1.0, residual_mode: str = 'true', use_amp: bool = False)
Train the two-detector time-domain separation model for one epoch.
- Expects the DataLoader to yield:
data – (B, 2, T) H1+L1 strain (standard-scaled inputs)
targets – (B, 4, T) [bg_H1, bg_L1, sig_H1, sig_L1] (whitened)
- Parameters:
residual_channels – If True, model outputs 6 channels; the extra 2 are a residual term that enforces input reconstruction.
residual_weight – Weight applied to the residual loss term.
residual_mode – How the residual loss is computed:
'true': residual target = data - (tgt_bg + tgt_sig)
'sum': loss on (pred_bg + pred_sig + pred_res) vs data
'sum_detach': same as 'sum', but the bg+sig gradients are detached
use_amp – Enable mixed-precision autocast. Disabled by default because training runs on Snellius found AMP unstable with whitened targets.
- Returns:
(avg_total, avg_bg, avg_sig) if residual_channels=False
(avg_total, avg_bg, avg_sig, avg_res) if residual_channels=True
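The three residual_mode options can be compared with a minimal sketch. This is not the library's implementation: the function name residual_loss, the use of MSE, and the assumed output channel order [bg_H1, bg_L1, sig_H1, sig_L1, res_H1, res_L1] are illustrative assumptions. Only the tensor shapes and the mode definitions are taken from the description above.

```python
import torch
import torch.nn.functional as F


def residual_loss(pred, data, targets, residual_mode="true"):
    """Hypothetical residual term for a 6-channel output (illustration only).

    pred:    (B, 6, T), ASSUMED order [bg_H1, bg_L1, sig_H1, sig_L1, res_H1, res_L1]
    data:    (B, 2, T) network input
    targets: (B, 4, T) [bg_H1, bg_L1, sig_H1, sig_L1]
    """
    pred_bg, pred_sig, pred_res = pred[:, 0:2], pred[:, 2:4], pred[:, 4:6]
    tgt_bg, tgt_sig = targets[:, 0:2], targets[:, 2:4]

    if residual_mode == "true":
        # Residual target is whatever the input contains beyond bg + sig.
        return F.mse_loss(pred_res, data - (tgt_bg + tgt_sig))
    if residual_mode == "sum":
        # All three predicted parts together should reconstruct the input.
        return F.mse_loss(pred_bg + pred_sig + pred_res, data)
    if residual_mode == "sum_detach":
        # Same forward value as "sum", but only the residual channels
        # receive gradients from this term.
        return F.mse_loss(pred_bg.detach() + pred_sig.detach() + pred_res, data)
    raise ValueError(f"unknown residual_mode: {residual_mode!r}")


# Random tensors standing in for one batch and one model output.
torch.manual_seed(0)
B, T = 3, 16
data = torch.randn(B, 2, T)
targets = torch.randn(B, 4, T)
pred = torch.randn(B, 6, T, requires_grad=True)

losses = {m: residual_loss(pred, data, targets, m)
          for m in ("true", "sum", "sum_detach")}
```

In this sketch 'sum' and 'sum_detach' give the same loss value and differ only in which output channels the residual term can update. How train_fn_td combines this term with the background and signal losses (beyond scaling it by residual_weight) is not specified here.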
- deepextractor.training.train_fn.eval_fn_td(loader, model, device, *, residual_channels: bool = False, residual_weight: float = 1.0) -> tuple
Evaluate the two-detector time-domain separation model on a validation set.
Mirrors the signature of train_fn_td() but runs under torch.no_grad and does not update weights. The model is restored to train mode afterwards.
- Returns:
(avg_total, avg_bg, avg_sig) if residual_channels=False
(avg_total, avg_bg, avg_sig, avg_res) if residual_channels=True
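The evaluation pattern described above (no gradients, no weight updates, train mode restored at the end) can be sketched as follows. The helper names evaluate and unpack_losses, the MSE loss, and the toy Conv1d model are hypothetical stand-ins, not part of deepextractor; a plain list of (data, targets) batches is used in place of a DataLoader.

```python
import torch
from torch import nn


def evaluate(loader, model, device):
    """Hypothetical helper showing only the no-grad / mode-restore pattern."""
    model.eval()                      # disable dropout, use running BN stats
    total, n_batches = 0.0, 0
    with torch.no_grad():             # no autograd graph is built
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            pred = model(data)
            total += nn.functional.mse_loss(pred, targets).item()
            n_batches += 1
    model.train()                     # restore train mode afterwards
    return total / max(n_batches, 1)


def unpack_losses(result):
    """Name the entries of a 3- or 4-tuple of average losses."""
    keys = ("total", "bg", "sig", "res")
    return dict(zip(keys, result))    # zip stops at the shorter input


# Toy stand-ins: (B, 2, T) inputs mapped to (B, 4, T) outputs.
model = nn.Conv1d(2, 4, kernel_size=3, padding=1)
loader = [(torch.randn(2, 2, 32), torch.randn(2, 4, 32)) for _ in range(2)]
avg = evaluate(loader, model, "cpu")

named3 = unpack_losses((0.3, 0.1, 0.2))
named4 = unpack_losses((0.35, 0.1, 0.2, 0.05))
```

Because the length of the returned tuple depends on residual_channels, a small helper such as unpack_losses lets calling code handle both cases without branching on the flag.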