Source code for deepextractor.utils.io

import os
import logging

import numpy as np
import torch
import torch.nn as nn
from gwpy.timeseries import TimeSeries
from torch.utils.data import DataLoader


[docs] logger = logging.getLogger(__name__)
[docs] def get_loaders( train_dir, train_target_dir, val_dir, val_target_dir, batch_size, train_transform=False, val_transform=False, num_workers=4, pin_memory=True, time_domain=True, ): """Return train and validation DataLoaders.""" from deepextractor.data.datasets import SpectrogramDataset, TimeSeriesDataset if time_domain: train_ds = TimeSeriesDataset( input_npy=train_dir, target_npy=train_target_dir, transform=train_transform, ) val_ds = TimeSeriesDataset( input_npy=val_dir, target_npy=val_target_dir, transform=val_transform, ) else: print("Got spec datasets") train_ds = SpectrogramDataset( input_npy=train_dir, target_npy=train_target_dir, transform=train_transform, ) val_ds = SpectrogramDataset( input_npy=val_dir, target_npy=val_target_dir, transform=val_transform, ) train_loader = DataLoader( train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True, ) val_loader = DataLoader( val_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False, ) return train_loader, val_loader
[docs] def check_accuracy(loader, model, model_name, device="cuda"): """Compute MSE loss on the validation set and return average losses.""" mse_loss_fn = nn.MSELoss() total_mse_loss = 0 total_noise_loss = 0 total_constraint_loss = 0 num_samples = 0 model.eval() with torch.no_grad(): for x, y in loader: x = x.to(device) y = y.to(device) predictions = model(x) if model_name == "UNET1D_diff": noise_pred = predictions[:, 0:1, :] residual_pred = predictions[:, 1:2, :] reconstructed = noise_pred + residual_pred noise_loss = mse_loss_fn(noise_pred, y) constraint_loss = mse_loss_fn(reconstructed, x) total_noise_loss += noise_loss.item() * x.size(0) total_constraint_loss += constraint_loss.item() * x.size(0) total_loss = noise_loss + constraint_loss else: total_loss = mse_loss_fn(predictions, y) total_mse_loss += total_loss.item() * x.size(0) num_samples += x.size(0) model.train() avg_mse_loss = total_mse_loss / num_samples avg_noise_loss = total_noise_loss / num_samples if model_name == "UNET1D_diff" else None avg_constraint_loss = ( total_constraint_loss / num_samples if model_name == "UNET1D_diff" else None ) if model_name == "UNET1D_diff": print( f"Validation Losses - Total: {avg_mse_loss:.6f}, " f"Noise: {avg_noise_loss:.6f}, Constraint: {avg_constraint_loss:.6f}" ) else: print(f"Validation Loss: {avg_mse_loss:.6f}") return avg_mse_loss, avg_noise_loss, avg_constraint_loss
[docs] def numpy_to_gwf(strain, sample_times, channel, output_filename): """ Write a strain time series to a GWF (frame) file. Parameters ---------- strain : array-like The time-domain strain data. sample_times : array-like The corresponding time array. channel : str Channel name, e.g. ``'L1:STRAIN'``. output_filename : str Path for the output GWF file. """ frame_file = TimeSeries(strain, times=sample_times, channel=channel) frame_file.write(output_filename) return None
[docs] def gwf_to_lcf(start_time, duration, channel_name, gwf_file_location): """Write a minimal LCF (frame cache) file alongside the GWF file.""" output_string = ( f"{channel_name[0]} {channel_name} {int(start_time)} {int(duration - 2)} " f"file://localhost{gwf_file_location}" ) os.system( "echo %s > %s" % (output_string, f"{gwf_file_location.replace('gwf', 'lcf')}") ) return None
[docs] def load_tf_model(path, model_name): """ Load a TensorFlow/Keras SavedModel. Requires the ``[generative]`` optional dependencies: ``pip install deepextractor[generative]``. """ try: import tensorflow as tf except ImportError as e: raise ImportError( "TensorFlow is required for this function. " "Install it with: pip install deepextractor[generative]" ) from e return tf.keras.models.load_model(os.path.join(path, model_name))