Source code for deepextractor.api

"""
Top-level convenience functions for DeepExtractor inference.

Intended for one-shot use. For repeated inference on many signals, instantiate
:class:`DeepExtractorModel` or :class:`DeepExtractorSeparator` directly
to amortise the model load cost.
"""

from typing import Union
from pathlib import Path

import numpy as np

from deepextractor.model import DeepExtractorModel, DeepExtractorSeparator, SeparationResult
from deepextractor.utils.checkpoints import CHECKPOINT_BILBY


def reconstruct(
    noisy_input: np.ndarray,
    checkpoint: str = "DeepExtractor_257",
    checkpoint_filename: str = CHECKPOINT_BILBY,
    checkpoint_dir: str | None = None,
    device: str | None = None,
    scaler_path: str | None = None,
) -> np.ndarray:
    """Reconstruct the transient signal in a noisy gravitational-wave strain.

    One-shot helper: a fresh :class:`DeepExtractorModel` is built from the
    supplied settings and its ``reconstruct`` method is called once on
    ``noisy_input``. Because a new model is constructed on every call, code
    that processes many inputs should create a :class:`DeepExtractorModel`
    once and reuse it instead.

    Parameters
    ----------
    noisy_input : np.ndarray
        Single strain of shape ``(T,)`` or a batch of shape ``(N, T)``.
    checkpoint : str
        Name of the model. Default ``"DeepExtractor_257"``.
    checkpoint_filename : str
        Name of the checkpoint file. Defaults to the bilby-noise checkpoint
        (``CHECKPOINT_BILBY``).
    checkpoint_dir : str | None
        Directory holding local checkpoints. With ``None`` the checkpoint is
        fetched from the HuggingFace Hub.
    device : str | None
        Torch device string. Chosen automatically when ``None``.
    scaler_path : str | None
        Location of the scaler ``.pkl`` file. With ``None`` the bundled
        asset is used.

    Returns
    -------
    np.ndarray
        The reconstructed signal, shaped like ``noisy_input``.
    """
    # Every option is passed straight through to the model constructor by
    # keyword; nothing is interpreted or validated at this level.
    model_settings = {
        "checkpoint": checkpoint,
        "checkpoint_filename": checkpoint_filename,
        "checkpoint_dir": checkpoint_dir,
        "device": device,
        "scaler_path": scaler_path,
    }
    extractor = DeepExtractorModel(**model_settings)
    return extractor.reconstruct(noisy_input)
# Public alias: `extract` is bound to the very same function object as
# `reconstruct`, so the two names are interchangeable at the API level.
extract = reconstruct
def separate(
    h1: np.ndarray,
    l1: np.ndarray,
    checkpoint_path: Union[str, Path],
    scaler=None,
    device: str | None = None,
    model_kwargs: dict | None = None,
) -> SeparationResult:
    """Split H1 and L1 strain into signal and background (time domain).

    One-shot helper: a fresh :class:`DeepExtractorSeparator` is built from
    the supplied settings and its ``separate`` method is called once on the
    two strains. Because a new separator is constructed on every call, code
    that processes many inputs should create a
    :class:`DeepExtractorSeparator` once and reuse it instead.

    Parameters
    ----------
    h1 : np.ndarray
        H1 strain of shape ``(T,)`` or ``(N, T)``.
    l1 : np.ndarray
        L1 strain with the same shape as ``h1``.
    checkpoint_path : str | Path
        Location of the ``.pth.tar`` checkpoint written during training.
    scaler : ChannelStandardScaler | str | Path | None
        Per-channel input scaler: a fitted
        :class:`~deepextractor.data.ChannelStandardScaler`, a path to a
        pickled scaler, or ``None`` to skip scaling.
    device : str | None
        Torch device string. Chosen automatically when ``None``.
    model_kwargs : dict | None
        Keyword-argument overrides forwarded to :class:`UNET1D_LSTM_ATT`.

    Returns
    -------
    SeparationResult
        Named tuple with fields ``h1_signal``, ``l1_signal``,
        ``h1_background`` and ``l1_background``.
    """
    # Build the separator and run it in a single expression; all options
    # are handed to the constructor unchanged.
    return DeepExtractorSeparator(
        checkpoint_path=checkpoint_path,
        scaler=scaler,
        device=device,
        model_kwargs=model_kwargs,
    ).separate(h1, l1)