import numpy as np
import h5py
class ChannelStandardScaler:
    """Per-channel standard scaler for (N, C, T) time-series data.

    Fits one mean and std per channel across all samples and time points.
    The ``mean_`` and ``scale_`` attributes (shape ``(C,)``, dtype float32)
    are compatible with the HDF5Dataset ``input_scaler`` interface.
    Compatible with ``joblib.dump`` / ``pickle`` for serialisation.

    Attributes:
        mean_: Per-channel mean, shape ``(C,)``; ``None`` until fitted.
        scale_: Per-channel standard deviation, shape ``(C,)``, with
            zero-variance channels replaced by 1.0; ``None`` until fitted.
        n_channels_: Number of channels ``C``; ``None`` until fitted.
    """

    def __init__(self):
        # Initialise every fitted attribute so that _check_fitted() reports
        # "not fitted" with a RuntimeError instead of an AttributeError.
        self.mean_ = None
        self.scale_ = None
        self.n_channels_ = None

    def fit(self, X: np.ndarray) -> "ChannelStandardScaler":
        """Fit on array X of shape (N, C, T).

        Statistics are accumulated in float64 and stored as float32, the same
        as fit_from_hdf5, so both fitting paths produce identical attributes.
        For large datasets use fit_from_hdf5 instead to avoid loading
        everything into memory.

        Args:
            X: Array-like of shape (N, C, T) with N > 0 and T > 0.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: If X is not 3-D, or has no samples or no time points.
        """
        X = np.asarray(X)
        if X.ndim != 3:
            raise ValueError(f"Expected 3-D array (N, C, T), got shape {X.shape}")
        self._check_non_empty(X.shape)
        # Mean and std over axes 0 (samples) and 2 (time), per channel.
        mean = X.mean(axis=(0, 2), dtype=np.float64)  # (C,)
        std = X.std(axis=(0, 2), dtype=np.float64)  # (C,)
        self._set_stats(mean, std)
        return self

    def fit_from_hdf5(
        self, hdf5_path: str, key: str, chunk_size: int = 2048
    ) -> "ChannelStandardScaler":
        """Fit on a dataset too large to load at once.

        Computes per-channel mean and variance in two streaming passes over
        the HDF5 dataset: the first pass computes the mean, and the second
        computes the variance around that exact mean. This is more numerically
        stable than E[x^2] - E[x]^2. Peak extra memory is a few float64
        arrays of shape (chunk_size, C, T) rather than O(N * C * T).

        Args:
            hdf5_path: Path to the HDF5 file.
            key: Dataset key with shape (N, C, T).
            chunk_size: Number of samples to process at a time (must be > 0).

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: If chunk_size is not positive, or the dataset is not
                3-D, or it has no samples or no time points.
        """
        # Validate before touching the file; a negative step would otherwise
        # make range() empty and silently yield mean 0 / scale 1.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        with h5py.File(hdf5_path, "r") as f:
            ds = f[key]
            if len(ds.shape) != 3:
                raise ValueError(
                    f"Expected 3-D dataset (N, C, T), got shape {ds.shape}"
                )
            n, c, t = ds.shape
            self._check_non_empty(ds.shape)
            count = np.float64(n * t)

            # Pass 1: mean per channel.
            sum_ = np.zeros(c, dtype=np.float64)
            for chunk in self._iter_chunks(ds, chunk_size):
                sum_ += chunk.sum(axis=(0, 2))
            mean = sum_ / count

            # Pass 2: variance per channel around the pass-1 mean.
            sq_sum = np.zeros(c, dtype=np.float64)
            for chunk in self._iter_chunks(ds, chunk_size):
                diff = chunk - mean[np.newaxis, :, np.newaxis]
                sq_sum += np.square(diff).sum(axis=(0, 2))
            var = sq_sum / count

        self._set_stats(mean, np.sqrt(var))
        return self

    @staticmethod
    def _iter_chunks(ds, chunk_size: int):
        """Yield consecutive float64 slices of ``ds`` of at most chunk_size samples."""
        n = ds.shape[0]
        for start in range(0, n, chunk_size):
            end = min(start + chunk_size, n)
            yield ds[start:end].astype(np.float64)  # (batch, C, T)

    @staticmethod
    def _check_non_empty(shape) -> None:
        """Raise ValueError if a (N, C, T) shape has no samples or time points."""
        n, _, t = shape
        if n == 0 or t == 0:
            raise ValueError(f"Cannot fit on empty data of shape {tuple(shape)}")

    def _set_stats(self, mean: np.ndarray, std: np.ndarray) -> None:
        """Store per-channel statistics as float32, guarding zero std with 1.0."""
        self.mean_ = np.asarray(mean).astype(np.float32)
        scale = np.asarray(std).astype(np.float32)
        # Guard against zero-variance channels. This is applied after the
        # float32 cast so that stds which underflow to 0 are caught too.
        scale[scale == 0] = 1.0
        self.scale_ = scale
        self.n_channels_ = int(self.mean_.shape[0])

    def _check_fitted(self):
        """Raise RuntimeError unless fit() or fit_from_hdf5() has been called."""
        # getattr also tolerates instances unpickled from older versions
        # whose __init__ did not define mean_.
        if getattr(self, "mean_", None) is None:
            raise RuntimeError("Scaler is not fitted. Call fit() or fit_from_hdf5() first.")