# Source code for deepextractor.evaluation.simulate

"""
Simulated evaluation of DeepExtractor models.

Usage::

    deepextractor-evaluate --model DeepExtractor_257 --checkpoint-dir checkpoints/ \\
        --data-dir data/ --output-dir evaluation/

"""

import argparse
import logging
import os
import pickle
import random

import numpy as np
import torch
from tqdm import tqdm

from deepextractor.generation.glitch_functions import (
    generate_cdvgan_glitch,
    generate_chirp,
    generate_gaussian_pulse,
    generate_gengli_glitch,
    generate_sine,
    generate_sine_gaussian,
    ringdown,
)
from deepextractor.models.architectures import (
    Autoencoder1D,
    Autoencoder2D,
    DnCNN1D,
    ModifiedAutoencoder2D,
    UNET1D,
    UNET2D,
)
from deepextractor.utils.checkpoints import CHECKPOINT_BILBY, load_torch_model
from deepextractor.utils.io import load_tf_model
from deepextractor.utils.signal import generate_gaussian_noise, whitened_snr_scaling

# Root logging is configured at import time so INFO-level progress is emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sampling configuration (presumably Hz -- confirm against the generators).
SAMPLE_RATE = 4096
NYQUIST_FREQ = SAMPLE_RATE // 2

# Lower / upper bound of the duration drawn for the parametric generators.
T_MIN = 0.125
T_MAX = 2

# Segment duration, injection time (segment midpoint) and the resulting
# segment length in samples.
T = 2.0
T_INJ = T / 2
LENGTH = int(T * SAMPLE_RATE)
# Maps every model name accepted on the command line to a freshly constructed
# (untrained) network instance; the registry is handed to ``load_torch_model``
# together with the model name.  Each key owns its own instance, and entries
# are built in the order listed here.
MODEL_REGISTRY = {
    # 1-D architectures.
    "UNET1D": UNET1D(in_channels=1, out_channels=1),
    "UNET1D_glitch": UNET1D(in_channels=1, out_channels=1),
    "UNET1D_diff": UNET1D(in_channels=1, out_channels=2),
    "UNET1D_2channel": UNET1D(in_channels=1, out_channels=2),
    "DnCNN1D": DnCNN1D(),
    "Autoencoder1D": Autoencoder1D(in_channels=1, out_channels=1),
    # 2-D architectures (two input / two output channels).
    "Autoencoder2D": Autoencoder2D(in_channels=2, out_channels=2),
    "ModifiedAutoencoder2D": ModifiedAutoencoder2D(in_channels=2, out_channels=2),
    # The DeepExtractor variants share one architecture and differ only by name.
    **{
        f"DeepExtractor_{suffix}": UNET2D(in_channels=2, out_channels=2)
        for suffix in (65, 129, 257)
    },
}
def _get_stft_params(model_name):
    """Return the STFT configuration associated with ``model_name``.

    Three configurations exist:

    * ``n_fft=256, win_length=128, hop_length=64``
    * ``n_fft=128, win_length=128, hop_length=64``
    * ``n_fft=512, win_length=64, hop_length=32`` (default for any other name)

    The first two were previously reachable only through the legacy names
    ``UNET2D_noise`` / ``UNET2D_65_noise`` (and ``ModifiedAutoencoder2D``),
    none of which -- apart from ``ModifiedAutoencoder2D`` -- is a key of
    ``MODEL_REGISTRY``.  The registered ``DeepExtractor_129`` and
    ``DeepExtractor_65`` therefore silently received the 512-point default.
    They are now accepted as aliases of the matching legacy names.

    NOTE(review): the alias mapping assumes the numeric suffix is the number
    of one-sided frequency bins (``n_fft // 2 + 1``: 128 -> 65, 256 -> 129,
    512 -> 257).  Confirm against the settings the checkpoints were trained
    with.

    Args:
        model_name: Name of the model being evaluated.

    Returns:
        Tuple ``(n_fft, hop_length, win_length, window)`` where ``window`` is
        a Hann window tensor of length ``win_length``.
    """
    if model_name in ("UNET2D_noise", "DeepExtractor_129", "ModifiedAutoencoder2D"):
        n_fft = 256
        win_length = n_fft // 2
    elif model_name in ("UNET2D_65_noise", "DeepExtractor_65"):
        n_fft = 256 // 2
        win_length = n_fft
    else:
        n_fft = 256 * 2
        win_length = n_fft // 8

    # Every configuration uses 50% overlap between consecutive windows.
    hop_length = win_length // 2
    window = torch.hann_window(win_length)
    return n_fft, hop_length, win_length, window


# These imports sit below the helper in the original module; they are kept in
# place because the functions further down rely on the names they bind.
from deepextractor.generation.generate_timeseries import SNR_SCALING_FACTOR_BILBY
from deepextractor.utils.stft import apply_stft, apply_istft  # noqa: F401
def prepare_data_for_stft(data, scaler, n_fft, hop_length, win_length, window):
    """Scale the time series of one signal type and transform the noisy ones.

    Args:
        data: Mapping providing ``"noisy_glitch_ts"`` and ``"pure_noise_ts"``.
        scaler: Object exposing ``transform``; it is applied to the samples
            laid out as a single feature column.
        n_fft: FFT size forwarded to ``apply_stft``.
        hop_length: Hop length forwarded to ``apply_stft``.
        win_length: Window length forwarded to ``apply_stft``.
        window: Window forwarded to ``apply_stft``.

    Returns:
        Tuple ``(noisy_stft, noisy_glitch_scaled, pure_noise_scaled)``.  Only
        the noisy (noise + glitch) series is passed through ``apply_stft``;
        the pure-noise series is returned scaled but untransformed.
    """

    def _rescale(series):
        # Flatten to one feature column for the scaler, then restore the
        # original array layout.
        samples = np.asarray(series)
        flat_scaled = scaler.transform(samples.reshape(-1, 1))
        return flat_scaled.reshape(samples.shape)

    scaled_noisy = _rescale(data["noisy_glitch_ts"])
    scaled_noise = _rescale(data["pure_noise_ts"])
    spectrogram = apply_stft(scaled_noisy, n_fft, hop_length, win_length, window)
    return spectrogram, scaled_noisy, scaled_noise
def calculate_match(true_signal, predicted_signal, sample_rate=SAMPLE_RATE):
    """Return the PyCBC match between a reference and a predicted series.

    Both inputs are wrapped as double-precision PyCBC ``TimeSeries``.
    Previously only the predicted series was cast, so a non-float64 reference
    (for example float32) produced two series of different precision.

    Args:
        true_signal: Reference samples.
        predicted_signal: Predicted samples.
        sample_rate: Samples per unit time; ``delta_t`` is ``1 / sample_rate``.

    Returns:
        The first element of the tuple returned by
        ``pycbc.filter.matchedfilter.match``.
    """
    # PyCBC is imported lazily so the module can be imported without it.
    from pycbc.filter.matchedfilter import match
    from pycbc.types import TimeSeries as TimeSeries_pycbc

    delta_t = 1.0 / sample_rate
    true_ts = TimeSeries_pycbc(true_signal, delta_t=delta_t, dtype="double")
    pred_ts = TimeSeries_pycbc(predicted_signal, delta_t=delta_t, dtype="double")
    return match(true_ts, pred_ts)[0]
def generate_glitch_data(
    signal_type,
    gaussian_noise_samples,
    signal_function_map,
    snr_min=7.5,
    snr_max=100,
    bilby_noise=False,
):
    """Inject one glitch of ``signal_type`` into each noise sample.

    For every noise sample a target SNR is drawn uniformly from
    ``[snr_min, snr_max)``, a waveform is generated, mean-subtracted, scaled
    with ``whitened_snr_scaling`` and added to a copy of the noise, centred on
    the injection time.

    Samples are skipped (nothing is appended for them) when the generated
    waveform contains NaNs, is empty, or is longer than the noise segment.
    The last two cases used to make the in-place addition fail with a shape
    mismatch.

    Args:
        signal_type: Key into ``signal_function_map``.
        gaussian_noise_samples: Iterable of 1-D noise arrays.
        signal_function_map: Maps signal names to generator callables.  The
            parametric generators are called with a duration; the others are
            called without arguments.  Each returns a pair whose second item
            is the waveform.
        snr_min: Lower bound of the target SNR.
        snr_max: Upper bound of the target SNR.
        bilby_noise: When true, the target SNR is divided by
            ``SNR_SCALING_FACTOR_BILBY`` before scaling.

    Returns:
        Dict of equally long lists: ``"glitches_ts"`` (scaled glitches at
        their own length), ``"clean_glitch_subtract"`` (noisy minus
        background, full segment length), ``"noisy_glitch_ts"``,
        ``"pure_noise_ts"`` and ``"snr"`` (the drawn, undivided SNR values).
    """
    parametric_types = ("chirp", "sine", "sine_gaussian", "gaussian_pulse", "ringdown")
    generator_fn = signal_function_map[signal_type]

    glitches_ts, clean_glitch_subtract, noisy_glitch_ts, pure_noise_ts, snrs_inj = (
        [], [], [], [], []
    )

    for noise_sample in tqdm(gaussian_noise_samples):
        snr_to_scale = np.random.uniform(snr_min, snr_max)
        background = noise_sample.copy()
        noisy_glitch = background.copy()

        if signal_type in parametric_types:
            duration = np.random.uniform(T_MIN, T_MAX)
            _, signal_injection = generator_fn(duration)
        else:
            _, signal_injection = generator_fn()
        signal_injection = signal_injection.squeeze()

        if np.isnan(signal_injection).any():
            continue

        len_glitch = len(signal_injection)
        if not 0 < len_glitch <= len(noise_sample):
            # The slice below could not hold the glitch; skip like a NaN draw.
            logger.warning(
                "Skipping %s injection of %d samples (segment has %d samples).",
                signal_type, len_glitch, len(noise_sample),
            )
            continue

        # Centre the glitch on the injection time within the segment.
        id_start = int((T_INJ * SAMPLE_RATE / LENGTH) * len(noise_sample)) - len_glitch // 2

        glitch = signal_injection - np.mean(signal_injection)
        effective_snr = snr_to_scale / SNR_SCALING_FACTOR_BILBY if bilby_noise else snr_to_scale
        glitch = whitened_snr_scaling(glitch, snr=effective_snr)

        noisy_glitch[id_start : id_start + len_glitch] += glitch
        clean_glitch = noisy_glitch - background

        glitches_ts.append(glitch)
        clean_glitch_subtract.append(clean_glitch)
        noisy_glitch_ts.append(noisy_glitch)
        pure_noise_ts.append(background)
        snrs_inj.append(snr_to_scale)

    return {
        "glitches_ts": glitches_ts,
        "clean_glitch_subtract": clean_glitch_subtract,
        "noisy_glitch_ts": noisy_glitch_ts,
        "pure_noise_ts": pure_noise_ts,
        "snr": snrs_inj,
    }
def generate_hybrid_glitch_data(gaussian_noise_samples, signal_function_map, snr_min=7.5, snr_max=100):
    """Inject several randomly chosen, randomly placed glitches per sample.

    For each noise sample between 2 and 7 injections are attempted.  Each one
    picks a signal type at random, draws its own target SNR, and is added at
    a uniformly random position where it fits entirely inside the segment.

    Fixes relative to the previous implementation:

    * ``np.random.randint`` excludes its upper bound, so the last valid
      position (glitch ending exactly at the segment end) was never drawn and
      a glitch as long as the segment raised ``ValueError``.  The bound is now
      inclusive.
    * Waveforms that contain NaNs, are empty, or are longer than the segment
      are skipped, matching ``generate_glitch_data``; previously a NaN
      waveform contaminated the whole sample.

    Args:
        gaussian_noise_samples: Iterable of 1-D noise arrays.
        signal_function_map: Maps signal names to generator callables.  The
            parametric generators are called with a duration drawn from
            ``[T_MIN, 1)``; the others are called without arguments.
        snr_min: Lower bound of the per-injection target SNR.
        snr_max: Upper bound of the per-injection target SNR.

    Returns:
        Dict with the same keys as ``generate_glitch_data``.
        ``"glitches_ts"`` and ``"snr"`` are always empty lists here.
    """
    parametric_types = ("chirp", "sine", "sine_gaussian", "gaussian_pulse", "ringdown")
    hybrid_signals = list(signal_function_map.keys())
    clean_glitch_subtract, noisy_glitch_ts, pure_noise_ts = [], [], []

    for noise_sample in tqdm(gaussian_noise_samples):
        background = noise_sample.copy()
        noisy_glitch = background.copy()
        segment_len = len(background)

        n_injs = np.random.randint(2, 8)
        for _ in range(n_injs):
            snr_to_scale = np.random.uniform(snr_min, snr_max)
            # NOTE: the type is drawn with the stdlib ``random`` module while
            # everything else uses ``np.random``; reproducible runs need both
            # generators seeded.
            s_type = random.choice(hybrid_signals)

            if s_type in parametric_types:
                duration = np.random.uniform(T_MIN, 1)
                _, signal_injection = signal_function_map[s_type](duration)
            else:
                _, signal_injection = signal_function_map[s_type]()
            signal_injection = signal_injection.squeeze()

            len_glitch = len(signal_injection)
            if np.isnan(signal_injection).any() or not 0 < len_glitch <= segment_len:
                continue

            glitch = signal_injection - np.mean(signal_injection)
            glitch = whitened_snr_scaling(glitch, snr=snr_to_scale)

            # Valid start indices are 0 .. segment_len - len_glitch inclusive;
            # randint's upper bound is exclusive, hence the +1.
            start = np.random.randint(0, segment_len - len_glitch + 1)
            noisy_glitch[start : start + len_glitch] += glitch

        clean_glitch = noisy_glitch - background
        clean_glitch_subtract.append(clean_glitch)
        noisy_glitch_ts.append(noisy_glitch)
        pure_noise_ts.append(background)

    return {
        "glitches_ts": [],
        "clean_glitch_subtract": clean_glitch_subtract,
        "noisy_glitch_ts": noisy_glitch_ts,
        "pure_noise_ts": pure_noise_ts,
        "snr": [],
    }
def evaluate_model(
    model_name,
    model_registry,
    scaler,
    glitch_data,
    output_path,
    checkpoint_dir,
    device,
    batch_size=8,
    checkpoint_filename=CHECKPOINT_BILBY,
):
    """Run one model over every signal type and collect match metrics.

    For each signal type the noisy series are scaled and transformed with
    ``prepare_data_for_stft``, passed through the model in batches, inverted
    with ``apply_istft`` and un-scaled.  The model output is treated as the
    background estimate; the extracted glitch is the noisy input minus that
    background.

    Fixes relative to the previous implementation:

    * A batch holding a single sample (``batch_size == 1`` or a trailing
      remainder of one) lost its batch axis through ``.squeeze()`` on both the
      model input and the output, after which per-sample indexing yielded
      scalars.  The batch axis is now restored in that case.
    * The inverse scaling uses the same single-feature ``(-1, 1)`` layout as
      the forward scaling in ``prepare_data_for_stft``.
    * ``batch_size < 1`` raises ``ValueError`` up front, and a signal type
      with no samples yields empty results instead of reaching the scaler.

    Args:
        model_name: Registry key of the model to evaluate.
        model_registry: Mapping of model names to model instances.
        scaler: Object exposing ``transform`` / ``inverse_transform``.
        glitch_data: Mapping ``signal_type -> data dict`` as produced by
            ``generate_glitch_data`` / ``generate_hybrid_glitch_data``.
        output_path: Unused by this function; kept for interface
            compatibility.
        checkpoint_dir: Forwarded to ``load_torch_model``.
        device: Device the batches are moved to; also forwarded to
            ``load_torch_model``.
        batch_size: Number of samples per forward pass.
        checkpoint_filename: Forwarded to ``load_torch_model``.

    Returns:
        Dict ``signal_type -> {"metrics": {...}, "time_series": {...}}`` where
        metrics holds the lists ``"match_background"``, ``"match_glitch"`` and
        ``"mismatch_glitch"`` (``(1 - match) * 100``), and time_series holds
        ``"extracted_glitches"`` and ``"background_outputs"`` as nested lists.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model = load_torch_model(
        model_name, model_registry, checkpoint_dir, device,
        checkpoint_filename=checkpoint_filename,
    )
    n_fft, hop_length, win_length, window = _get_stft_params(model_name)

    model_data_dict = {}
    for signal_type, data in glitch_data.items():
        metrics_dict = {"match_background": [], "match_glitch": [], "mismatch_glitch": []}
        extracted_signals, background_output = [], []

        if len(data["noisy_glitch_ts"]) == 0:
            # Every sample of this type was skipped during generation.
            model_data_dict[signal_type] = {
                "metrics": metrics_dict,
                "time_series": {"extracted_glitches": [], "background_outputs": []},
            }
            continue

        noisy_stft, _, _ = prepare_data_for_stft(
            data, scaler, n_fft, hop_length, win_length, window
        )
        noisy_glitch_ts = np.asarray(data["noisy_glitch_ts"])
        pure_noise_ts = np.asarray(data["pure_noise_ts"])
        clean_glitch_ts = data["clean_glitch_subtract"]

        for m in range(0, len(noisy_stft), batch_size):
            batch = noisy_stft[m : m + batch_size].to(device)
            noisy_batch = noisy_glitch_ts[m : m + batch_size]
            pure_batch = pure_noise_ts[m : m + batch_size]
            clean_batch = clean_glitch_ts[m : m + batch_size]
            single_sample = batch.shape[0] == 1

            # squeeze() drops every singleton axis, including a batch axis of
            # size one; put that axis back so the model always sees a batch.
            model_input = batch.squeeze()
            if single_sample:
                model_input = model_input.unsqueeze(0)

            with torch.no_grad():
                output_val = model(model_input)
                output_istft = apply_istft(output_val, n_fft, hop_length, win_length, window)

            output_np = output_istft.cpu().numpy().squeeze()
            if single_sample:
                output_np = output_np[np.newaxis, ...]

            # Undo the scaling with the same single-column layout used when
            # the inputs were scaled.
            backgrounds_inv = scaler.inverse_transform(
                output_np.reshape(-1, 1)
            ).reshape(output_np.shape)
            diff = noisy_batch - backgrounds_inv

            for clean, noise_only, extracted, background in zip(
                clean_batch, pure_batch, diff, backgrounds_inv
            ):
                extracted_signals.append(extracted)
                background_output.append(background)
                match_bg = calculate_match(noise_only, background)
                match_gl = calculate_match(clean, extracted)
                metrics_dict["match_background"].append(match_bg)
                metrics_dict["match_glitch"].append(match_gl)
                metrics_dict["mismatch_glitch"].append((1 - match_gl) * 100)

        model_data_dict[signal_type] = {
            "metrics": metrics_dict,
            "time_series": {
                "extracted_glitches": [ts.tolist() for ts in extracted_signals],
                "background_outputs": [ts.tolist() for ts in background_output],
            },
        }

    return model_data_dict
def main():
    """Command-line entry point: simulate glitches and evaluate models.

    Steps: parse and validate arguments, load the scaler, optionally load the
    CDVGAN generator, generate Gaussian noise, inject every available signal
    type (plus a hybrid set), evaluate each requested model and pickle the
    data and model outputs to ``<output-dir>/simulation_results.pkl``.

    Changes relative to the previous implementation: the warning message that
    was split across two source lines (an unterminated string literal) is
    restored, obviously invalid numeric arguments are rejected through
    ``parser.error``, and log calls pass their arguments lazily.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate DeepExtractor models on simulated glitch data",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model", nargs="+", default=["DeepExtractor_257"],
        help="One or more model names to evaluate.",
    )
    parser.add_argument("--checkpoint-dir", type=str, default=None,
                        help="Local checkpoint directory. Falls back to Hugging Face Hub if not set.")
    parser.add_argument("--checkpoint-filename", type=str, default=CHECKPOINT_BILBY,
                        help="Checkpoint file name within the model subdirectory.")
    parser.add_argument("--assets-dir", type=str, default="assets/",
                        help="Directory containing scaler .pkl files.")
    parser.add_argument("--scaler-path", type=str, default=None,
                        help="Path to scaler .pkl. Defaults to <assets-dir>/scaler_bilby.pkl.")
    parser.add_argument("--data-dir", type=str, default="data/")
    parser.add_argument("--output-dir", type=str, default="evaluation/")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-samples", type=int, default=512)
    parser.add_argument("--snr-min", type=float, default=7.5)
    parser.add_argument("--snr-max", type=float, default=100.0)
    parser.add_argument(
        "--device", default=None,
        help="Device to use. Auto-detected if not set.",
    )
    args = parser.parse_args()

    # Reject values that would otherwise fail later with an obscure error or
    # silently produce no results.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.num_samples < 1:
        parser.error("--num-samples must be a positive integer")
    if args.snr_min > args.snr_max:
        parser.error("--snr-min must not exceed --snr-max")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)

    # Load scaler — default to bilby scaler (matches simulated evaluation)
    # SECURITY: unpickling executes arbitrary code; only point this at scaler
    # files from a trusted source.
    scaler_path = args.scaler_path or os.path.join(args.assets_dir, "scaler_bilby.pkl")
    with open(scaler_path, "rb") as f:
        scaler = pickle.load(f)

    # Load CDVGAN generator (optional).  The broad except is deliberate: any
    # failure simply disables the CDVGAN signal types.
    try:
        generator = load_tf_model(args.data_dir, "cdvgan")
    except Exception as e:
        logger.warning(
            "Could not load CDVGAN generator: %s. CDVGAN signals will be unavailable.", e
        )
        generator = None

    signal_function_map = {
        "chirp": generate_chirp,
        "sine": generate_sine,
        "sine_gaussian": generate_sine_gaussian,
        "gaussian_pulse": generate_gaussian_pulse,
        "ringdown": ringdown,
        "gengli_H1": lambda: generate_gengli_glitch(ifo="H1"),
        "gengli_L1": lambda: generate_gengli_glitch(ifo="L1"),
    }
    if generator is not None:
        signal_function_map.update({
            "cdvgan_blip": lambda: generate_cdvgan_glitch("blip", generator),
            "cdvgan_tomte": lambda: generate_cdvgan_glitch("tomte", generator),
            "cdvgan_bbh": lambda: generate_cdvgan_glitch("bbh", generator),
            "cdvgan_simplex": lambda: generate_cdvgan_glitch("simplex", generator),
            "cdvgan_uniform": lambda: generate_cdvgan_glitch("uniform", generator),
        })

    # Generate noise samples
    mean, std_dev = 0, 50
    gaussian_noise_samples = generate_gaussian_noise(mean, std_dev, args.num_samples, (LENGTH,))

    # Generate glitch data per signal type
    glitch_data = {}
    for signal_type in signal_function_map:
        logger.info("Generating data for: %s", signal_type)
        glitch_data[signal_type] = generate_glitch_data(
            signal_type, gaussian_noise_samples, signal_function_map,
            args.snr_min, args.snr_max,
        )
    glitch_data["hybrid"] = generate_hybrid_glitch_data(
        gaussian_noise_samples, signal_function_map, args.snr_min, args.snr_max
    )

    # Evaluate models
    data_dict = {"data": glitch_data, "model_outputs": {}}
    for model_name in args.model:
        logger.info("Evaluating model: %s", model_name)
        model_data = evaluate_model(
            model_name, MODEL_REGISTRY, scaler, glitch_data,
            args.output_dir, args.checkpoint_dir, device, args.batch_size,
            checkpoint_filename=args.checkpoint_filename,
        )
        data_dict["model_outputs"][model_name] = model_data

    out_file = os.path.join(args.output_dir, "simulation_results.pkl")
    with open(out_file, "wb") as f:
        pickle.dump(data_dict, f)
    logger.info("Evaluation complete. Results saved to %s", out_file)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__": main()