deepextractor.training.train_fn
===============================

.. py:module:: deepextractor.training.train_fn


Module Contents
---------------

.. py:function:: train_fn(loader, model, model_name, optimizer, loss_fn, scaler, device)

   Train the model for one epoch and return the average losses.


.. py:function:: train_fn_td(loader, model, optimizer, loss_fn, scaler, device, *, residual_channels: bool = False, residual_weight: float = 1.0, residual_mode: str = 'true', use_amp: bool = False)

   Train the two-detector time-domain separation model for one epoch.

   Expects the DataLoader to yield:

   - ``data``: ``(B, 2, T)`` H1+L1 strain (standard-scaled inputs)
   - ``targets``: ``(B, 4, T)`` ``[bg_H1, bg_L1, sig_H1, sig_L1]`` (whitened)

   :param residual_channels: If True, the model outputs 6 channels; the extra 2 are a residual term that enforces input reconstruction.
   :param residual_weight: Weight applied to the residual loss term.
   :param residual_mode: How the residual loss is computed:

      - ``"true"``: the residual target is ``data - (tgt_bg + tgt_sig)``
      - ``"sum"``: the loss is taken on ``(pred_bg + pred_sig + pred_res)`` vs. ``data``
      - ``"sum_detach"``: same as ``"sum"``, but the bg and sig gradients are detached

   :param use_amp: Enable mixed-precision autocast. Disabled by default because training on Snellius found AMP unstable with whitened targets.
   :returns: ``(avg_total, avg_bg, avg_sig)`` if ``residual_channels=False``, or
      ``(avg_total, avg_bg, avg_sig, avg_res)`` if ``residual_channels=True``.


.. py:function:: eval_fn_td(loader, model, device, *, residual_channels: bool = False, residual_weight: float = 1.0) -> tuple

   Evaluate the two-detector time-domain separation model on a validation set.

   Mirrors the ``residual_channels`` and ``residual_weight`` options of :func:`train_fn_td`, but runs under ``torch.no_grad`` and does not update weights. The model is restored to train mode afterwards.

   :returns: ``(avg_total, avg_bg, avg_sig)`` if ``residual_channels=False``, or
      ``(avg_total, avg_bg, avg_sig, avg_res)`` if ``residual_channels=True``.
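
The three ``residual_mode`` options of :func:`train_fn_td` can be illustrated with the minimal sketch below; it is not the package's implementation. The helper ``td_losses`` is hypothetical, and the prediction channel ordering ``[bg_H1, bg_L1, sig_H1, sig_L1, res_H1, res_L1]``, the use of mean-squared error for every term, and the weighting ``total = bg + sig + residual_weight * res`` are all assumptions made for illustration.

.. code-block:: python

   import torch
   import torch.nn.functional as F


   def td_losses(pred, data, targets, *, residual_channels=False,
                 residual_weight=1.0, residual_mode="true"):
       """Hypothetical per-batch loss split for a two-detector separation model.

       Assumed shapes: data (B, 2, T); targets (B, 4, T) ordered
       [bg_H1, bg_L1, sig_H1, sig_L1]; pred (B, 4, T), or (B, 6, T) with
       two extra residual channels when residual_channels=True.
       """
       tgt_bg, tgt_sig = targets[:, 0:2], targets[:, 2:4]
       pred_bg, pred_sig = pred[:, 0:2], pred[:, 2:4]

       # Assumption: plain MSE for each separated component.
       loss_bg = F.mse_loss(pred_bg, tgt_bg)
       loss_sig = F.mse_loss(pred_sig, tgt_sig)
       total = loss_bg + loss_sig
       if not residual_channels:
           return total, loss_bg, loss_sig

       pred_res = pred[:, 4:6]
       if residual_mode == "true":
           # Residual target: the part of the input the true components miss.
           loss_res = F.mse_loss(pred_res, data - (tgt_bg + tgt_sig))
       elif residual_mode == "sum":
           # All three predicted components must add back up to the input.
           loss_res = F.mse_loss(pred_bg + pred_sig + pred_res, data)
       elif residual_mode == "sum_detach":
           # Same reconstruction, but only the residual head gets its gradient.
           recon = pred_bg.detach() + pred_sig.detach() + pred_res
           loss_res = F.mse_loss(recon, data)
       else:
           raise ValueError(f"unknown residual_mode: {residual_mode!r}")

       total = total + residual_weight * loss_res
       return total, loss_bg, loss_sig, loss_res

Under these assumptions, ``"true"`` and ``"sum_detach"`` give the background and signal channels identical gradients, while ``"sum"`` also lets the reconstruction term push on them directly. The documented functions return per-epoch averages, so per-batch values like these would be accumulated over the loader.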