deepextractor.model

High-level DeepExtractor model wrapper for inference.

Module Contents

deepextractor.model.logger
class deepextractor.model.DeepExtractorModel(checkpoint: str = 'DeepExtractor_257', checkpoint_filename: str = CHECKPOINT_BILBY, checkpoint_dir: str | None = None, device: str | torch.device | None = None, scaler_path: str | None = None, n_fft: int = 512, win_length: int = 64, hop_length: int = 32)

High-level wrapper around a pretrained DeepExtractor UNET2D model.

Bundles the PyTorch model, StandardScaler, and STFT parameters into a single object so callers don’t need to manage them separately.
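Keeping the scaler inside the wrapper matters because anything the network predicts in standardized space must be mapped back using the same fitted statistics. scikit-learn's StandardScaler computes z = (x - mean_) / scale_ and inverts it with x = z * scale_ + mean_. This page does not say which quantity DeepExtractor standardizes, so the sketch below only illustrates those two formulas. ScalerSketch is a hypothetical numpy stand-in, not the bundled scaler.

```python
import numpy as np

class ScalerSketch:
    """Hypothetical stand-in mirroring StandardScaler's transform / inverse_transform formulas."""

    def __init__(self, mean_, scale_):
        # Fitted per-feature statistics (a fitted StandardScaler exposes these as mean_ and scale_).
        self.mean_ = np.asarray(mean_, dtype=np.float64)
        self.scale_ = np.asarray(scale_, dtype=np.float64)

    def transform(self, x):
        # Standardize: subtract the mean, divide by the scale (broadcast over the last axis).
        return (np.asarray(x, dtype=np.float64) - self.mean_) / self.scale_

    def inverse_transform(self, z):
        # Undo the standardization with the same statistics.
        return np.asarray(z, dtype=np.float64) * self.scale_ + self.mean_
```

If the statistics used in inverse_transform differed from those used in transform, the round trip would no longer return the original values. Bundling the scaler with the model prevents that mismatch.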

Parameters:
  • checkpoint (str) – Model name / checkpoint key. Defaults to "DeepExtractor_257".

  • checkpoint_filename (str) – Checkpoint file name within the model subdirectory, either on the HuggingFace Hub or in the local checkpoint_dir. Defaults to CHECKPOINT_BILBY.

  • checkpoint_dir (str | None) – Local directory searched for checkpoint files before falling back to the HuggingFace Hub. Pass None to always use the Hub.

  • device (str | torch.device | None) – Compute device. When None, CUDA is selected automatically if it is available.

  • scaler_path (str | None) – Path to the scaler .pkl file. Defaults to the bundled assets/scaler_bilby.pkl.

  • n_fft (int) – STFT FFT size. Default 512.

  • win_length (int) – STFT window length. Default 64.

  • hop_length (int) – STFT hop length. Default 32.
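The STFT defaults fix the shape of the time-frequency representation. With n_fft=512 the one-sided spectrum has n_fft // 2 + 1 = 257 frequency bins. With hop_length=32 and centred framing, a signal of T samples yields 1 + T // 32 frames. The 257 in the default checkpoint name may refer to this bin count, but that is an inference.

The numpy sketch below assumes torch.stft-style conventions: center=True, reflect padding, and a window of length win_length zero-padded on both sides to n_fft. It also assumes a periodic Hann window. This page documents neither the window nor the padding that DeepExtractorModel actually uses, and stft_sketch is a hypothetical helper.

```python
import numpy as np

def stft_sketch(x, n_fft=512, win_length=64, hop_length=32):
    """Hypothetical one-sided STFT of a 1-D signal (assumed Hann window, center=True framing)."""
    x = np.asarray(x, dtype=np.float64)
    # Periodic Hann window of length win_length (assumption), centred inside an n_fft-long buffer.
    n = np.arange(win_length)
    window = 0.5 - 0.5 * np.cos(2 * np.pi * n / win_length)
    left = (n_fft - win_length) // 2
    full_window = np.zeros(n_fft)
    full_window[left:left + win_length] = window
    # center=True: reflect-pad n_fft // 2 samples on each side so frame t is centred on sample t * hop_length.
    padded = np.pad(x, n_fft // 2, mode="reflect")
    n_frames = 1 + (len(padded) - n_fft) // hop_length
    frames = np.stack([padded[t * hop_length:t * hop_length + n_fft] for t in range(n_frames)])
    # Real FFT of each windowed frame gives n_fft // 2 + 1 frequency bins.
    spec = np.fft.rfft(frames * full_window, axis=-1)
    return spec.T  # shape: (n_fft // 2 + 1, n_frames)
```

For example, a 4096-sample input gives a (257, 129) complex array. A short win_length inside a long n_fft keeps time resolution fine while the zero-padding interpolates the frequency axis onto 257 bins.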

device
n_fft = 512
win_length = 64
hop_length = 32
background(noisy_input: numpy.ndarray) → numpy.ndarray

Estimate the background (noise-only) component.

Parameters:

noisy_input (np.ndarray) – 1-D array of shape (T,) or 2-D batch of shape (N, T).

Returns:

Background estimate, same shape as noisy_input.

Return type:

np.ndarray
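background() accepts either a single series of shape (T,) or a batch of shape (N, T), and it returns an array with the same shape as its input. The sketch below illustrates that contract. It assumes the wrapper promotes 1-D input to a batch of one internally and removes the batch axis afterwards. apply_per_row and smooth_background are hypothetical names, and the moving average only stands in for the network's prediction.

```python
import numpy as np

def apply_per_row(noisy_input, predict_batch):
    """Hypothetical helper: run a batched predictor on (T,) or (N, T) input, preserving the input shape."""
    x = np.asarray(noisy_input, dtype=np.float64)
    if x.ndim not in (1, 2):
        raise ValueError(f"expected shape (T,) or (N, T), got {x.shape}")
    batch = x[np.newaxis, :] if x.ndim == 1 else x  # promote (T,) to (1, T)
    out = predict_batch(batch)
    if out.shape != batch.shape:
        raise ValueError("predictor must preserve the (N, T) shape")
    return out[0] if x.ndim == 1 else out  # drop the batch axis again for 1-D input

def smooth_background(batch, width=65):
    """Stand-in background predictor (moving average per row), NOT the DeepExtractor network."""
    kernel = np.ones(width) / width
    return np.stack([np.convolve(row, kernel, mode="same") for row in batch])
```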

reconstruct(noisy_input: numpy.ndarray) → numpy.ndarray

Extract the transient signal by subtracting the predicted background.

Parameters:

noisy_input (np.ndarray) – 1-D array of shape (T,) or 2-D batch of shape (N, T).

Returns:

Reconstructed signal, same shape as noisy_input.

Return type:

np.ndarray
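According to the description above, the transient estimate is the input minus the predicted background, so the two outputs should sum back to the input. The sketch below shows that relationship with a trivial stand-in background: the per-row mean. reconstruct_from_background and mean_background are hypothetical helpers and are not part of deepextractor.

```python
import numpy as np

def reconstruct_from_background(noisy_input, background_fn):
    """Hypothetical helper: transient estimate = noisy input minus predicted background."""
    x = np.asarray(noisy_input, dtype=np.float64)
    return x - background_fn(x)

def mean_background(x):
    """Stand-in background model (per-row mean), NOT the DeepExtractor network."""
    return np.broadcast_to(x.mean(axis=-1, keepdims=True), x.shape)
```

With the real wrapper, assuming inference is deterministic, one would expect model.reconstruct(x) + model.background(x) to match x up to floating-point precision.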