deepextractor.model
High-level DeepExtractor model wrapper for inference.
Module Contents
- class deepextractor.model.DeepExtractorModel(checkpoint: str = 'DeepExtractor_257', checkpoint_filename: str = CHECKPOINT_BILBY, checkpoint_dir: str | None = None, device: str | torch.device | None = None, scaler_path: str | None = None, n_fft: int = 512, win_length: int = 64, hop_length: int = 32)
High-level wrapper around a pretrained DeepExtractor UNET2D model.
Bundles the PyTorch model, StandardScaler, and STFT parameters into a single object so callers don’t need to manage them separately.
- Parameters:
checkpoint (str) – Model name / checkpoint key. Defaults to "DeepExtractor_257".
checkpoint_filename (str) – Checkpoint file name within the model subdirectory on HuggingFace Hub or in the local checkpoint_dir. Defaults to CHECKPOINT_BILBY.
checkpoint_dir (str | None) – Local directory to search for checkpoint files before falling back to HuggingFace Hub. Pass None to always use the Hub.
device (str | torch.device | None) – Compute device. When None, auto-detects CUDA if available.
scaler_path (str | None) – Path to the scaler .pkl file. Defaults to the bundled assets/scaler_bilby.pkl.
n_fft (int) – STFT FFT size. Defaults to 512.
win_length (int) – STFT window length. Defaults to 64.
hop_length (int) – STFT hop length. Defaults to 32.
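To make the STFT parameters concrete, the sketch below computes the time-frequency grid they produce. It is an illustrative NumPy re-implementation, not DeepExtractor's own code, and it assumes torch.stft-style conventions (center=True with reflect padding, the win_length window zero-padded to n_fft, a one-sided spectrum) plus a periodic Hann window; the function name stft_sketch is hypothetical. With n_fft=512 the one-sided spectrum has n_fft // 2 + 1 = 257 frequency bins, the same number that appears in the default checkpoint name, and a length-T input yields 1 + T // hop_length frames.

```python
import numpy as np

def stft_sketch(signal, n_fft=512, win_length=64, hop_length=32):
    """Hypothetical NumPy STFT mimicking torch.stft(center=True) conventions.

    Assumptions (not confirmed by the DeepExtractor docs): periodic Hann
    window, reflect padding, window centred inside the n_fft frame.
    Returns a complex array of shape (n_fft // 2 + 1, 1 + T // hop_length).
    """
    x = np.asarray(signal, dtype=np.float64)
    # Periodic Hann window of length win_length (same as torch.hann_window).
    window = np.hanning(win_length + 1)[:-1]
    # Zero-pad the window symmetrically up to n_fft samples.
    left = (n_fft - win_length) // 2
    padded_window = np.zeros(n_fft)
    padded_window[left:left + win_length] = window
    # center=True: reflect-pad n_fft // 2 samples on each side (needs T > n_fft // 2).
    x = np.pad(x, n_fft // 2, mode="reflect")
    n_frames = 1 + (len(x) - n_fft) // hop_length
    frames = np.stack(
        [x[i * hop_length : i * hop_length + n_fft] for i in range(n_frames)]
    )
    # One-sided FFT of each windowed frame, transposed to (freq, time).
    return np.fft.rfft(frames * padded_window, axis=-1).T
```

For a 4096-sample input this gives a (257, 129) array. Zero-padding the 64-sample window to 512 points interpolates the spectrum more finely but does not add frequency resolution; that is still set by win_length.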
- background(noisy_input: numpy.ndarray) → numpy.ndarray
Estimate the background (noise-only) component.
- Parameters:
noisy_input (np.ndarray) – 1-D array of shape (T,) or 2-D batch of shape (N, T).
- Returns:
Background estimate, same shape as noisy_input.
- Return type:
np.ndarray
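Both background and reconstruct accept either a single series of shape (T,) or a batch of shape (N, T) and return an array of the same shape. One common way to support both is to promote a 1-D input to a batch of one, run the batched computation, and drop the batch axis again. The sketch below shows that pattern; apply_batched and the predict_batch callable are hypothetical names standing in for whatever the wrapper does internally, which this page does not document.

```python
import numpy as np
from typing import Callable

def apply_batched(
    noisy_input: np.ndarray,
    predict_batch: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
    """Hypothetical helper: accept (T,) or (N, T), return the same shape.

    `predict_batch` is a stand-in for the model forward pass and must map
    an (N, T) array to another (N, T) array.
    """
    x = np.asarray(noisy_input, dtype=np.float64)
    if x.ndim not in (1, 2):
        raise ValueError(f"expected shape (T,) or (N, T), got {x.shape}")
    was_1d = x.ndim == 1
    batch = x[np.newaxis, :] if was_1d else x  # promote (T,) -> (1, T)
    out = predict_batch(batch)
    if out.shape != batch.shape:
        raise ValueError("predict_batch must preserve the (N, T) shape")
    return out[0] if was_1d else out  # drop the batch axis for 1-D input
```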
- reconstruct(noisy_input: numpy.ndarray) → numpy.ndarray
Extract the transient signal by subtracting the predicted background.
- Parameters:
noisy_input (np.ndarray) – 1-D array of shape (T,) or 2-D batch of shape (N, T).
- Returns:
Reconstructed signal, same shape as noisy_input.
- Return type:
np.ndarray
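Because reconstruct is documented as subtracting the predicted background, the two methods are complementary: assuming both calls produce the same background prediction for the same input, reconstruct(x) + background(x) should recover x up to floating-point error. The sketch below illustrates that relationship with a deliberately simple stand-in estimator; moving_average_background is hypothetical and is not the UNET2D network DeepExtractor uses.

```python
import numpy as np

def moving_average_background(x: np.ndarray, width: int = 65) -> np.ndarray:
    """Hypothetical stand-in background estimator (NOT DeepExtractor's network).

    Smooths each series with a moving average along the last axis, so it
    handles both (T,) and (N, T) inputs. Assumes T >= width.
    """
    kernel = np.ones(width) / width
    return np.apply_along_axis(
        lambda row: np.convolve(row, kernel, mode="same"), -1, x
    )

def reconstruct(noisy_input: np.ndarray,
                background_fn=moving_average_background) -> np.ndarray:
    """Transient estimate = noisy input minus the predicted background."""
    x = np.asarray(noisy_input, dtype=np.float64)
    return x - background_fn(x)
```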