"""
Convert time-domain .npy arrays to STFT spectrograms (magnitude + phase).
Also provides a utility to concatenate chunked spectrogram files.
Usage::
deepextractor-specgen --input-dir data/pycbc_noise/time_domain/ --output-dir data/pycbc_noise/spectrogram_domain/
"""
import argparse
import os
import numpy as np
import torch
# Default STFT parameters (257x257 output shape).
# n_fft = 512 gives n_fft // 2 + 1 = 257 one-sided frequency bins. The 257
# time frames additionally depend on the input length (presumably 8192
# samples at a hop of 32 -- TODO confirm against the data generation step).

# FFT size used for each STFT frame.
DEFAULT_N_FFT = 256 * 2

# Analysis window length in samples (one eighth of the FFT size).
DEFAULT_WIN_LENGTH = DEFAULT_N_FFT // 8

# Hop between successive frames in samples (half of the window length).
DEFAULT_HOP_LENGTH = DEFAULT_WIN_LENGTH // 2
def apply_stft_and_save(
    array_path, save_path, n_fft, hop_length, win_length, window, chunk_size=5000
):
    """Apply an STFT to a ``.npy`` array in chunks and save magnitude and phase.

    The array stored at ``array_path`` is sliced along its first axis into
    pieces of at most ``chunk_size`` rows. Each piece is converted to a
    ``float32`` tensor and passed to :func:`torch.stft`. The magnitude and
    phase of the complex result are stacked along a new axis 1, the per-chunk
    results are concatenated along axis 0, and the final array is written with
    :func:`numpy.save`.

    Args:
        array_path: Path of the input ``.npy`` file. Chunking along axis 0
            assumes that axis indexes independent signals -- TODO confirm
            the expected layout is ``(n_signals, n_samples)``.
        save_path: Output path. ``numpy.save`` appends ``.npy`` when the
            path does not already end with it.
        n_fft: FFT size forwarded to ``torch.stft``.
        hop_length: Hop length forwarded to ``torch.stft``.
        win_length: Window length forwarded to ``torch.stft``.
        window: Window tensor forwarded to ``torch.stft``.
        chunk_size: Maximum number of rows transformed per ``torch.stft``
            call, which bounds the size of the intermediate tensors.

    Returns:
        None. The result is written to disk and progress is printed.
    """
    array = np.load(array_path)
    print(f"Loaded {array_path}, shape: {array.shape}")

    n_rows = array.shape[0]
    # Ceiling division: a trailing partial chunk still counts as a chunk, so
    # the progress counter never reports more chunks done than the total.
    total_chunks = -(-n_rows // chunk_size)

    stft_list = []
    for chunk_idx, start in enumerate(range(0, n_rows, chunk_size), start=1):
        tensor = torch.tensor(array[start : start + chunk_size], dtype=torch.float32)
        stft_result = torch.stft(
            tensor,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            return_complex=True,
        )
        # Store magnitude and phase as two real-valued channels on axis 1.
        magnitude = torch.abs(stft_result)
        phase = torch.angle(stft_result)
        stft_list.append(torch.stack([magnitude, phase], dim=1))

        # Drop the intermediates before the next chunk to keep peak memory low.
        del tensor, stft_result, magnitude, phase
        torch.cuda.empty_cache()
        print(f"Processed chunk {chunk_idx}/{total_chunks}")

    stft_numpy = torch.cat(stft_list, dim=0).cpu().numpy()
    np.save(save_path, stft_numpy)

    # Report the file name numpy actually wrote: the suffix is only appended
    # when it is missing.
    saved_as = str(save_path)
    if not saved_as.endswith(".npy"):
        saved_as += ".npy"
    print(f"STFT saved to {saved_as}, final shape: {stft_numpy.shape}")

    torch.cuda.empty_cache()
def load_and_concatenate_chunks(data_dir, base_filename, total_chunks):
    """Load and concatenate chunked numpy arrays saved as ``{base}_chunk_{i}.npy``.

    Chunk indices ``0 .. total_chunks - 1`` are probed in order inside
    ``data_dir``. Files that do not exist are reported and skipped; the
    remaining arrays are concatenated along axis 0.

    Args:
        data_dir: Directory containing the chunk files.
        base_filename: Prefix of the chunk files (the part before ``_chunk_``).
        total_chunks: Number of chunk indices to look for.

    Returns:
        The loaded chunks concatenated along axis 0.

    Raises:
        ValueError: If none of the expected chunk files exists.
    """
    loaded = []
    for i in range(total_chunks):
        chunk_filename = f"{base_filename}_chunk_{i}.npy"
        chunk_path = os.path.join(data_dir, chunk_filename)
        if not os.path.exists(chunk_path):
            print(f"Chunk {chunk_filename} not found. Skipping.")
            continue
        print(f"Loading {chunk_filename}...")
        loaded.append(np.load(chunk_path))

    if not loaded:
        # np.concatenate([]) would raise a ValueError that does not say which
        # dataset was missing; raise the same type with an actionable message.
        raise ValueError(
            f"No chunk files matching '{base_filename}_chunk_<i>.npy' "
            f"(i in 0..{total_chunks - 1}) were found in '{data_dir}'."
        )

    print("Concatenating chunks...")
    return np.concatenate(loaded, axis=0)
def main():
    """Command-line entry point: generate spectrograms or combine chunk files.

    Without ``--combine-chunks``, the four ``*_scaled.npy`` datasets in
    ``--input-dir`` are transformed with ``apply_stft_and_save`` and written
    to ``--output-dir`` as ``*_scaled_mag_phase.npy``.

    With ``--combine-chunks``, previously written
    ``*_scaled_mag_phase_chunk_{i}.npy`` files are read from ``--output-dir``,
    concatenated per dataset, and saved as ``*_scaled_mag_phase_combined.npy``
    in the same directory.
    """
    parser = argparse.ArgumentParser(
        description="Convert time-domain .npy arrays to STFT spectrogram arrays",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Directory containing the time-domain .npy files.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save the spectrogram .npy files.",
    )
    # ArgumentDefaultsHelpFormatter only appends "(default: ...)" to options
    # that have a help string, so every tunable gets one.
    parser.add_argument(
        "--n-fft", type=int, default=DEFAULT_N_FFT,
        help="FFT size (n_fft) passed to torch.stft.",
    )
    parser.add_argument(
        "--win-length", type=int, default=DEFAULT_WIN_LENGTH,
        help="Window length in samples; also the length of the Hann window.",
    )
    parser.add_argument(
        "--hop-length", type=int, default=DEFAULT_HOP_LENGTH,
        help="Hop between successive STFT frames, in samples.",
    )
    parser.add_argument(
        "--chunk-size", type=int, default=5000,
        help="Number of rows of the input array transformed per torch.stft call.",
    )
    parser.add_argument(
        "--combine-chunks",
        action="store_true",
        help="Combine pre-existing chunk files instead of generating new spectrograms.",
    )
    parser.add_argument(
        "--chunks-glitch-train", type=int, default=16,
        help="Number of chunks for glitch_train (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-background-train", type=int, default=16,
        help="Number of chunks for background_train (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-glitch-val", type=int, default=2,
        help="Number of chunks for glitch_val (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-background-val", type=int, default=2,
        help="Number of chunks for background_val (used with --combine-chunks).",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.combine_chunks:
        # Chunk files are both read from and written back to the output directory.
        chunked_datasets = [
            ("glitch_train_scaled_mag_phase", args.chunks_glitch_train),
            ("background_train_scaled_mag_phase", args.chunks_background_train),
            ("glitch_val_scaled_mag_phase", args.chunks_glitch_val),
            ("background_val_scaled_mag_phase", args.chunks_background_val),
        ]
        for base, n_chunks in chunked_datasets:
            combined = load_and_concatenate_chunks(args.output_dir, base, n_chunks)
            out_path = os.path.join(args.output_dir, f"{base}_combined.npy")
            np.save(out_path, combined)
            print(f"Saved combined {base} to {out_path}")
        print("All combined datasets saved.")
        return

    # The analysis window is only needed when spectrograms are generated.
    window = torch.hann_window(args.win_length)
    datasets = [
        ("glitch_train_scaled.npy", "glitch_train_scaled_mag_phase"),
        ("background_train_scaled.npy", "background_train_scaled_mag_phase"),
        ("glitch_val_scaled.npy", "glitch_val_scaled_mag_phase"),
        ("background_val_scaled.npy", "background_val_scaled_mag_phase"),
    ]
    for in_name, out_name in datasets:
        in_path = os.path.join(args.input_dir, in_name)
        out_path = os.path.join(args.output_dir, out_name)
        apply_stft_and_save(
            in_path, out_path,
            args.n_fft, args.hop_length, args.win_length, window,
            args.chunk_size,
        )
    print("All STFT results saved.")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()