Source code for deepextractor.utils.visualization
import os
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import numpy as np
import torch
from gwpy.timeseries import TimeSeries
[docs]
def save_predictions_as_plots(loader, model, folder="saved_predictions/", device="cuda"):
    """Save one prediction-vs-target line plot per sample in ``loader``.

    Parameters
    ----------
    loader : iterable
        Yields ``(x, y)`` batches. ``x`` is moved to ``device`` and fed to
        ``model``; ``y[i]`` is drawn as the target for the i-th prediction.
    model : torch.nn.Module
        Model to run. It is put in eval mode for the duration of the call and
        returned to whichever mode it was in beforehand.
    folder : str
        Output directory, created if it does not exist. Files are written as
        ``plot_{batch_index}_{sample_index}.png``.
    device : str or torch.device
        Device the input batch is moved to before the forward pass.
    """
    # Remember the caller's mode: unconditionally calling ``model.train()`` on
    # exit would silently re-enable dropout / batch-norm updates for a model
    # that was being used for inference. Objects without a ``training`` flag
    # fall back to the previous behaviour (train mode on exit).
    was_training = getattr(model, "training", True)
    model.eval()
    os.makedirs(folder, exist_ok=True)
    try:
        for batch_idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = model(x).cpu().numpy()

            for sample_idx, pred in enumerate(preds):
                fig = plt.figure(figsize=(10, 4))
                try:
                    plt.plot(pred.squeeze(), label="Prediction", color="b")
                    plt.plot(
                        y[sample_idx].cpu().numpy().squeeze(),
                        label="Target",
                        color="r",
                        linestyle="--",
                    )
                    plt.title(f"Time Series Prediction vs Target {batch_idx}_{sample_idx}")
                    plt.xlabel("Time Step")
                    plt.ylabel("Value")
                    plt.legend()
                    plt.tight_layout()
                    # os.path.join avoids the doubled separator produced by
                    # concatenating onto a folder that already ends in "/".
                    plt.savefig(os.path.join(folder, f"plot_{batch_idx}_{sample_idx}.png"))
                finally:
                    # Always release the figure, even if plotting/saving fails.
                    plt.close(fig)
    finally:
        model.train(was_training)
[docs]
def plot_examples(
    Difference_ts,
    clean_glitch_subtract,
    snrs,
    signal_type,
    PLOTS_PATH,
    indices_to_plot,
    noisy=False,
):
    """Plot example time series side by side and save the figure to disk.

    Each selected index gets one panel overlaying ``Difference_ts[idx]`` (red)
    and ``clean_glitch_subtract[idx]`` (blue), titled with ``signal_type`` and
    ``snrs[idx]`` rounded to two decimals.

    Parameters
    ----------
    Difference_ts, clean_glitch_subtract : indexable
        Collections of series; both are indexed with the entries of
        ``indices_to_plot``.
    snrs : indexable
        Per-example SNR values, indexed the same way.
    signal_type : str
        Label used in the panel titles and in the output file name.
    PLOTS_PATH : str or path-like
        Directory the figure is written to (created if missing).
    indices_to_plot : iterable
        Indices of the examples to draw. The layout is a single row of three
        panels; the row is widened when more than three indices are given.
    noisy : bool
        Selects the file-name suffix: ``noisy_example`` when true, otherwise
        ``example``.

    Notes
    -----
    The file name is ``{signal_type}_{suffix}`` with no extension, so the
    image format is left to Matplotlib's ``savefig`` defaults.
    """
    indices = list(indices_to_plot)
    # Keep the original 1x3 layout (and 18x5 inch canvas) for up to three
    # examples; grow the row instead of failing in plt.subplot beyond that.
    n_cols = max(3, len(indices))

    if PLOTS_PATH:
        os.makedirs(PLOTS_PATH, exist_ok=True)

    fig = plt.figure(figsize=(6 * n_cols, 5))
    try:
        for panel, idx in enumerate(indices, start=1):
            plt.subplot(1, n_cols, panel)
            plt.plot(Difference_ts[idx], label="Difference_ts", color="red", alpha=0.7)
            plt.plot(
                clean_glitch_subtract[idx],
                label="Clean Glitch Subtract",
                color="blue",
                alpha=0.5,
            )
            plt.title(f"Example {panel} for {signal_type} with SNR={np.round(snrs[idx], 2)}")
            plt.xlabel("Time")
            plt.ylabel("Amplitude")
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        suffix = "noisy_example" if noisy else "example"
        plt.savefig(os.path.join(PLOTS_PATH, f"{signal_type}_{suffix}"))
    finally:
        # Release the figure even when plotting or saving raises.
        plt.close(fig)
[docs]
def plot_q_transform(
    data,
    srate=4096.0,
    crop=None,
    whiten=True,
    ax=None,
    colourbar=True,
    qrange=(4, 64),
    frange=(10, 1200),
):
    """
    Plot the Q-transform of a time series using gwpy.

    Parameters
    ----------
    data : array-like
        Input time-domain data.
    srate : float
        Sample rate in Hz.
    crop : tuple or list, optional
        ``(center_time, duration)`` window in seconds for the Q-transform.
        ``center_time`` is measured from the start of ``data``.
    whiten : bool
        If True, apply whitening before the Q-transform.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot. A new figure is created if not provided.
    colourbar : bool
        If True, add a colorbar to the plot.
    qrange : sequence of two floats
        ``(low, high)`` range of Q values passed to the Q-transform.
    frange : sequence of two floats
        ``(low, high)`` frequency range in Hz. Used both for the Q-transform
        and for the frequency extent of the drawn image.

    Notes
    -----
    The Q-transform is computed with a fixed time resolution of 0.002 s and
    frequency resolution of 0.5 Hz, and the colour scale is clipped to
    ``[0, 25.5]``.
    """
    f_low, f_high = frange

    series = TimeSeries(data, sample_rate=srate)
    q_scan = series.q_transform(
        qrange=list(qrange),
        # Honour the caller's band; this used to be hard-coded to [10, 1200],
        # which made the ``frange`` argument a no-op.
        frange=[f_low, f_high],
        tres=0.002,
        fres=0.5,
        whiten=whiten,
    )

    cropped = isinstance(crop, (list, tuple))
    # NOTE(review): without a crop the time axis is still labelled 0-2 s
    # regardless of the input length -- confirm inputs are always 2 s long.
    t_span = 2
    if cropped:
        t_center, dur = crop
        # Shift the requested centre onto the series' own time axis.
        t_center = t_center + series.t0.value
        q_scan = q_scan.crop(t_center - dur / 2, t_center + dur / 2)
        t_span = dur

    if ax is None:
        _, ax = plt.subplots(dpi=120)

    # The extent relabels the image axes: time runs from 0 to the window
    # length and frequency spans the band the transform was computed over.
    im = ax.imshow(q_scan, aspect="auto", extent=[0, t_span, f_low, f_high])
    ax.set_yscale("log", base=2)
    ax.set_xscale("linear")

    if cropped:
        xticklabels = np.linspace(0, t_span, 5)
        ax.set_xticks(xticklabels)
        ax.set_xticklabels(xticklabels)

    ax.set_ylabel("Frequency (Hz)", fontsize=14)
    ax.set_xlabel("Time (s)", labelpad=0.1, fontsize=14)
    # Plain numbers rather than powers of two on the log-2 frequency axis.
    ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
    ax.tick_params(axis="both", which="major", labelsize=14)
    im.set_clim(0, 25.5)

    if colourbar:
        cb = ax.figure.colorbar(im, ax=ax, label="Normalized energy", pad=0.01)
        cb.ax.tick_params(labelsize=18)
        cb.set_label("Normalized energy", fontsize=24)