"""Neural-network architectures for deepextractor (module ``deepextractor.models.architectures``)."""

import logging
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# Module-level logger, registered under this module's dotted import path.
logger = logging.getLogger(name=__name__)
# ---------------------------------------------------------------------------
# Time-domain (TD) building blocks (used by UNET1D_LSTM_ATT)
# ---------------------------------------------------------------------------
class SE1D(nn.Module):
    """Squeeze-and-Excitation channel gating for 1D feature maps shaped (B, C, T).

    Each channel is rescaled by a learned weight in (0, 1) that is computed
    from the channel's temporal average.

    Args:
        channels: Number of input/output channels C.
        r: Reduction ratio of the hidden bottleneck. The hidden width is
            ``channels // r``, clamped to at least 1.
    """

    def __init__(self, channels, r=8):
        super().__init__()
        reduced = channels // r
        if reduced < 1:
            reduced = 1
        # Sequential indices (state_dict keys) are: 0 pool, 1 squeeze conv,
        # 2 ReLU, 3 excite conv, 4 sigmoid.
        gate_layers = [
            nn.AdaptiveAvgPool1d(1),          # (B, C, T) -> (B, C, 1)
            nn.Conv1d(channels, reduced, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(reduced, channels, 1),
            nn.Sigmoid(),                     # per-channel gate in (0, 1)
        ]
        self.net = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate (same shape as ``x``)."""
        gate = self.net(x)  # (B, C, 1), broadcast along T
        return x * gate
class MHSA1D(nn.Module):
    """Temporal multi-head self-attention on 1D feature maps shaped (B, C, T).

    Attention runs along T, and C is the embedding dimension. The pipeline is:
    1x1 conv -> self-attention over T -> residual add + LayerNorm over C -> 1x1 conv.
    The output has the same shape as the input.

    Args:
        channels: Channel count C, which is also the attention embedding size.
            It must be divisible by ``num_heads``; ``nn.MultiheadAttention``
            enforces this at construction time.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied inside the attention.
    """

    def __init__(self, channels, num_heads=4, dropout=0.0):
        super().__init__()
        self.proj_in = nn.Conv1d(channels, channels, 1)
        self.attn = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.proj_out = nn.Conv1d(channels, channels, 1)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        """Apply self-attention along T and return a tensor shaped like ``x``."""
        z = self.proj_in(x).transpose(1, 2)  # (B, T, C)
        # The attention weights are discarded. With need_weights=False, PyTorch
        # skips building and averaging the (B, T, T) weight matrix and can use
        # its optimized scaled-dot-product attention path.
        z2, _ = self.attn(z, z, z, need_weights=False)
        z = self.norm(z + z2).transpose(1, 2)  # (B, C, T)
        return self.proj_out(z)
class LSTMBottleneck1D(nn.Module):
    """LSTM applied along the time axis of (B, C, T) features.

    The LSTM reads T steps of C features each. Its outputs are linearly
    projected back to C channels and layer-normalised, so the output shape
    equals the input shape.

    Args:
        channels: Channel count C.
        hidden: LSTM hidden size per direction. A falsy value (``None`` or 0)
            selects ``channels // 2``.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run the LSTM in both directions, which doubles the
            LSTM output width.
        dropout: Inter-layer LSTM dropout. It is only forwarded when
            ``num_layers > 1``; otherwise 0.0 is used.
    """

    def __init__(
        self,
        channels,
        hidden: Optional[int] = None,
        num_layers: int = 1,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        hidden_size = hidden if hidden else channels // 2
        inter_layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size=channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )
        num_directions = 2 if bidirectional else 1
        self.proj = nn.Linear(hidden_size * num_directions, channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        """Return the LSTM-processed features, shaped (B, C, T) like ``x``."""
        seq = x.transpose(1, 2)  # (B, T, C) for the batch_first LSTM
        seq, _ = self.lstm(seq)  # (B, T, hidden * directions)
        seq = self.proj(seq)
        seq = self.norm(seq)     # (B, T, C)
        return seq.transpose(1, 2)
class _DoubleConv1D_norm(nn.Module): """DoubleConv1D with configurable normalisation (used by UNET1D_LSTM_ATT).""" def __init__(self, in_ch, out_ch, norm='bn'): super().__init__() def _norm(ch): return nn.BatchNorm1d(ch) if norm == 'bn' else nn.GroupNorm(ch, ch) self.conv = nn.Sequential( nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1, bias=False), _norm(out_ch), nn.ReLU(inplace=True), nn.Conv1d(out_ch, out_ch, kernel_size=3, padding=1, bias=False), _norm(out_ch), nn.ReLU(inplace=True), ) def forward(self, x): return self.conv(x)
class UNET1D_LSTM_ATT(nn.Module):
    """Time-domain 1D U-Net for two-detector signal/glitch separation.

    Input : (B, in_channels, T), by default 2 channels (H1 and L1 strain stacked).
    Output: (B, out_channels, T), by default 4 channels
            (h1_signal, h1_noise, l1_signal, l1_noise).

    Optional extras (every attention/LSTM block is added residually):
      * bottleneck: LSTM (bidirectional by default) and/or multi-head
        self-attention;
      * encoder: self-attention on the deepest pooled features, before the
        bottleneck convolution;
      * decoder: self-attention after the first (deepest) decoder stage;
      * skips: Squeeze-and-Excitation gating of every skip connection.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        features: Encoder widths, shallowest first. The bottleneck width is
            ``2 * features[-1]``.
        use_lstm_bottleneck: Add the residual LSTM block at the bottleneck.
        lstm_layers: Forwarded to LSTMBottleneck1D as ``num_layers``.
        lstm_bidirectional: Forwarded to LSTMBottleneck1D as ``bidirectional``.
        lstm_dropout: Forwarded to LSTMBottleneck1D as ``dropout``.
        use_mhsa_bottleneck: Add residual self-attention at the bottleneck.
        use_mhsa_encoder: Add residual self-attention before the bottleneck conv.
        use_mhsa_decoder: Add residual self-attention after the first decoder stage.
        attn_heads: Number of heads for every attention block.
        use_se_on_skips: Gate skip connections with SE1D.
        norm: Normalisation for the conv blocks. ``'bn'`` selects BatchNorm;
            other values select GroupNorm in ``_DoubleConv1D_norm``.
    """

    def __init__(
        self,
        in_channels: int = 2,
        out_channels: int = 4,
        features=(64, 128, 256, 512),
        use_lstm_bottleneck: bool = True,
        lstm_layers: int = 1,
        lstm_bidirectional: bool = True,
        lstm_dropout: float = 0.0,
        use_mhsa_bottleneck: bool = False,
        use_mhsa_encoder: bool = False,
        use_mhsa_decoder: bool = False,
        attn_heads: int = 4,
        use_se_on_skips: bool = True,
        norm: str = 'bn',
    ):
        super().__init__()
        # NOTE: submodules are created and registered in a fixed order. Keep it
        # that way so that parameter order (optimizer state) and the RNG draws
        # used for weight initialisation stay reproducible.
        self.pool = nn.MaxPool1d(2, 2)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.skip_se = nn.ModuleList() if use_se_on_skips else None
        self.use_lstm = use_lstm_bottleneck
        self.use_mhsa_bottleneck = use_mhsa_bottleneck
        self.use_mhsa_encoder = use_mhsa_encoder
        self.use_mhsa_decoder = use_mhsa_decoder

        # Encoder: one conv block, plus an optional SE gate, per level.
        prev_ch = in_channels
        for width in features:
            self.downs.append(_DoubleConv1D_norm(prev_ch, width, norm=norm))
            if self.skip_se is not None:
                self.skip_se.append(SE1D(width))
            prev_ch = width

        # Bottleneck: the conv doubles the deepest encoder width.
        deepest = features[-1]
        self.bottleneck_conv = _DoubleConv1D_norm(deepest, deepest * 2, norm=norm)
        cur_ch = deepest * 2
        if use_lstm_bottleneck:
            self.bottleneck_lstm = LSTMBottleneck1D(
                channels=cur_ch,
                num_layers=lstm_layers,
                bidirectional=lstm_bidirectional,
                dropout=lstm_dropout,
            )
        if use_mhsa_bottleneck:
            self.bottleneck_attn = MHSA1D(channels=cur_ch, num_heads=attn_heads)
        if use_mhsa_encoder:
            # Runs on the pooled output of the deepest encoder level.
            self.encoder_attn = MHSA1D(channels=deepest, num_heads=attn_heads)

        # Decoder: `ups` alternates [ConvTranspose1d, conv block] per level.
        for width in reversed(features):
            self.ups.append(nn.ConvTranspose1d(cur_ch, width, kernel_size=2, stride=2))
            self.ups.append(_DoubleConv1D_norm(width * 2, width, norm=norm))
            cur_ch = width
        if use_mhsa_decoder:
            # Follows the first decoder stage, whose output width is features[-1].
            self.decoder_attn = MHSA1D(channels=features[-1], num_heads=attn_heads)

        self.final_conv = nn.Conv1d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``x`` shaped (B, in_channels, T) to (B, out_channels, T)."""
        skip_feats = []
        h = x
        for level, block in enumerate(self.downs):
            h = block(h)
            gated = h if self.skip_se is None else self.skip_se[level](h)
            skip_feats.append(gated)
            h = self.pool(h)

        if self.use_mhsa_encoder:
            h = h + self.encoder_attn(h)
        h = self.bottleneck_conv(h)
        if self.use_lstm:
            h = h + self.bottleneck_lstm(h)
        if self.use_mhsa_bottleneck:
            h = h + self.bottleneck_attn(h)

        # Walk the decoder from the deepest level upwards, pairing each stage
        # with the skip connection from the same level (deepest first).
        upsamplers = self.ups[0::2]
        fusers = self.ups[1::2]
        for stage, (upsample, fuse, skip) in enumerate(
            zip(upsamplers, fusers, reversed(skip_feats))
        ):
            h = upsample(h)
            target_len = skip.shape[-1]
            if h.shape[-1] != target_len:
                # Pooling floors odd lengths, so the upsampled tensor can come
                # back shorter than the skip. Resize it before concatenating.
                h = F.interpolate(h, size=target_len)
            h = fuse(torch.cat([skip, h], dim=1))
            if self.use_mhsa_decoder and stage == 0:
                h = h + self.decoder_attn(h)

        return self.final_conv(h)
# 1D models
class DoubleConv1D(nn.Module):
    """Two successive (Conv1d k=3 -> BatchNorm1d -> ReLU) stages.

    Kernel size 3 with padding 1 keeps the sequence length unchanged. The
    convolutions have no bias because BatchNorm follows each one.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Sequential indices: 0 conv, 1 BN, 2 ReLU, 3 conv, 4 BN, 5 ReLU.
        layers = []
        for src in (in_channels, out_channels):
            layers.append(nn.Conv1d(src, out_channels, kernel_size=3, padding=1, bias=False))
            layers.append(nn.BatchNorm1d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map (B, in_channels, T) to (B, out_channels, T)."""
        return self.conv(x)
class UNET1D(nn.Module):
    """1D U-Net with concatenative skip connections.

    Each encoder level is a DoubleConv1D followed by 2x max-pooling. A
    bottleneck DoubleConv1D doubles the deepest width. Each decoder level is a
    2x ConvTranspose1d, concatenation with the matching encoder features, and a
    DoubleConv1D. The output is a 1x1 conv.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        features: Encoder widths, shallowest first. The widths do not have to
            double from level to level; the decoder tracks the actual widths.

    Input (B, in_channels, T) -> output (B, out_channels, T).
    """

    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv1D(in_channels, feature))
            in_channels = feature

        # Track the decoder's running width instead of assuming it equals
        # 2 * feature. The first upsampler sees the bottleneck output
        # (2 * features[-1]); each later one sees the previous decoder level's
        # width. For doubling widths (the default) this equals feature * 2.
        up_in = features[-1] * 2
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose1d(up_in, feature, kernel_size=2, stride=2))
            # Input is the skip (feature ch) concatenated with the upsampled (feature ch).
            self.ups.append(DoubleConv1D(feature * 2, feature))
            up_in = feature

        self.bottleneck = DoubleConv1D(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv1d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the U-Net on ``x`` shaped (B, in_channels, T)."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest level first

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Pooling floors odd lengths, so the upsampled tensor can come back
            # shorter than the skip. Resize it before concatenating.
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
class DnCNN1D(nn.Module):
    """1D DnCNN: a plain stack of convolutions.

    Layout:
      * Conv (with bias) -> ReLU;
      * ``depth - 2`` x [Conv -> BatchNorm1d -> ReLU], where BatchNorm1d is
        only present when ``use_bnorm`` is set;
      * a final Conv (no bias) back to ``image_channels``.

    ``forward`` returns the raw stack output; no residual subtraction is done here.

    Args:
        depth: Total number of convolution layers, including the first and last.
        n_channels: Width of the hidden layers.
        image_channels: Input/output channel count.
        use_bnorm: Insert BatchNorm1d after each hidden convolution. When it is
            disabled, the hidden convolutions get a bias instead.
        kernel_size: Convolution kernel size. Padding is ``kernel_size // 2``,
            so odd sizes preserve the sequence length.
    """

    def __init__(self, depth=12, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super().__init__()
        # "Same" padding for odd kernels. A fixed padding of 1 would shrink the
        # signal at every layer for kernel_size > 3.
        padding = kernel_size // 2
        layers = [
            nn.Conv1d(
                in_channels=image_channels,
                out_channels=n_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=True,
            ),
            nn.ReLU(inplace=True),
        ]
        for _ in range(depth - 2):
            layers.append(
                nn.Conv1d(
                    in_channels=n_channels,
                    out_channels=n_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    # A bias is redundant before BatchNorm, whose shift absorbs it.
                    bias=not use_bnorm,
                )
            )
            if use_bnorm:
                layers.append(nn.BatchNorm1d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(
            nn.Conv1d(
                in_channels=n_channels,
                out_channels=image_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=False,
            )
        )
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        """Apply the convolution stack to ``x`` shaped (B, image_channels, T)."""
        return self.dncnn(x)

    def _initialize_weights(self):
        """Orthogonal init for conv weights, zero conv biases, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                init.orthogonal_(m.weight)
                logger.debug("init weight")
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                # This model uses BatchNorm1d; checking BatchNorm2d never matched here.
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
class Autoencoder1D(nn.Module):
    """1D convolutional autoencoder (U-Net shaped, but without skip connections).

    Each encoder level is a DoubleConv1D followed by 2x max-pooling. A
    bottleneck DoubleConv1D doubles the deepest width. Each decoder level is a
    2x ConvTranspose1d followed by a DoubleConv1D. The output is a 1x1 conv.

    No resizing is performed. The output length equals T only when T is
    divisible by ``2 ** len(features)``; otherwise pooling floors the length
    and the output is shorter than the input.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        features: Encoder widths, shallowest first. The widths do not have to
            double from level to level.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv1D(in_channels, feature))
            in_channels = feature

        # Running decoder width: the bottleneck output first, then each level's
        # output. For doubling widths (the default) this equals feature * 2.
        up_in = features[-1] * 2
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose1d(up_in, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv1D(feature, feature))
            up_in = feature

        self.bottleneck = DoubleConv1D(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv1d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Encode and decode ``x`` shaped (B, in_channels, T)."""
        for down in self.downs:
            x = down(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)       # 2x upsample
            x = self.ups[idx + 1](x)   # refine
        return self.final_conv(x)
# 2D models
class DoubleConv2D(nn.Module):
    """Two successive (Conv2d 3x3 -> BatchNorm2d -> ReLU) stages.

    A 3x3 kernel with padding 1 keeps the spatial size unchanged. The
    convolutions have no bias because BatchNorm follows each one.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Sequential indices: 0 conv, 1 BN, 2 ReLU, 3 conv, 4 BN, 5 ReLU.
        layers = []
        for src in (in_channels, out_channels):
            layers.append(nn.Conv2d(src, out_channels, kernel_size=3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map (B, in_channels, H, W) to (B, out_channels, H, W)."""
        return self.conv(x)
class UNET2D(nn.Module):
    """2D U-Net with concatenative skip connections.

    Each encoder level is a DoubleConv2D followed by 2x2 max-pooling. A
    bottleneck DoubleConv2D doubles the deepest width. Each decoder level is a
    2x ConvTranspose2d, concatenation with the matching encoder features, and a
    DoubleConv2D. The output is a 1x1 conv.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        features: Encoder widths, shallowest first. The widths do not have to
            double from level to level; the decoder tracks the actual widths.

    Input (B, in_channels, H, W) -> output (B, out_channels, H, W).
    """

    def __init__(self, in_channels=2, out_channels=2, features=(64, 128, 256, 512)):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv2D(in_channels, feature))
            in_channels = feature

        # Track the decoder's running width instead of assuming it equals
        # 2 * feature. The first upsampler sees the bottleneck output
        # (2 * features[-1]); each later one sees the previous decoder level's
        # width. For doubling widths (the default) this equals feature * 2.
        up_in = features[-1] * 2
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(up_in, feature, kernel_size=2, stride=2))
            # Input is the skip (feature ch) concatenated with the upsampled (feature ch).
            self.ups.append(DoubleConv2D(feature * 2, feature))
            up_in = feature

        self.bottleneck = DoubleConv2D(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the U-Net on ``x`` shaped (B, in_channels, H, W)."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest level first

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Pooling floors odd sizes, so the upsampled tensor can come back
            # smaller than the skip. Resize it before concatenating.
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
class Autoencoder2D(nn.Module):
    """2D convolutional autoencoder (U-Net shaped, but without skip connections).

    Each encoder level is a DoubleConv2D followed by 2x2 max-pooling. A
    bottleneck DoubleConv2D doubles the deepest width. Each decoder level is a
    2x ConvTranspose2d followed by a DoubleConv2D. The output is a 1x1 conv.

    After each upsampling step, the feature map is resized (nearest) to the
    spatial size of the matching encoder level. A final resize restores the
    input's spatial size, so inputs need not be divisible by ``2 ** len(features)``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        features: Encoder widths, shallowest first. The widths do not have to
            double from level to level.
    """

    def __init__(self, in_channels=2, out_channels=2, features=(64, 128, 256, 512)):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv2D(in_channels, feature))
            in_channels = feature

        # Running decoder width: the bottleneck output first, then each level's
        # output. For doubling widths (the default) this equals feature * 2.
        up_in = features[-1] * 2
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(up_in, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv2D(feature, feature))
            up_in = feature

        self.bottleneck = DoubleConv2D(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Encode and decode ``x`` shaped (B, in_channels, H, W)."""
        original_size = x.size()[-2:]
        # Only the spatial size of each encoder level is needed, not the tensor.
        level_sizes = []
        for down in self.downs:
            x = down(x)
            level_sizes.append(x.size()[-2:])
            x = self.pool(x)

        x = self.bottleneck(x)
        level_sizes = level_sizes[::-1]  # deepest level first

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            target_size = level_sizes[idx // 2]
            if x.size()[-2:] != target_size:
                x = F.interpolate(x, size=target_size)
            x = self.ups[idx + 1](x)

        if x.size()[-2:] != original_size:
            x = F.interpolate(x, size=original_size)
        return self.final_conv(x)
class ModifiedAutoencoder2D(nn.Module):
    """2D autoencoder whose decoder uses bilinear upsampling + 3x3 conv.

    Each encoder level is a DoubleConv2D followed by 2x2 max-pooling. The
    bottleneck is a single 3x3 conv that doubles the deepest width. Each
    decoder level is a 2x bilinear upsample (align_corners=True) followed by a
    3x3 conv. After each decoder level, the feature map is resized (bilinear,
    align_corners=False) to the matching encoder level's spatial size. A final
    nearest resize restores the input's spatial size before the 1x1 output conv.

    NOTE(review): neither the bottleneck nor the decoder stages contain an
    activation, so everything after the encoder is affine in the encoder
    output. Confirm that this is intended.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        features: Encoder widths, shallowest first. The widths do not have to
            double from level to level.
    """

    def __init__(self, in_channels=2, out_channels=2, features=(64, 128, 256)):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv2D(in_channels, feature))
            in_channels = feature

        # Running decoder width: the bottleneck output first, then each level's
        # output. For doubling widths (the default) this equals feature * 2.
        up_in = features[-1] * 2
        for feature in reversed(features):
            self.ups.append(
                nn.Sequential(
                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                    nn.Conv2d(up_in, feature, kernel_size=3, stride=1, padding=1),
                )
            )
            up_in = feature

        self.bottleneck = nn.Conv2d(
            features[-1], features[-1] * 2, kernel_size=3, stride=1, padding=1
        )
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Encode and decode ``x`` shaped (B, in_channels, H, W)."""
        original_size = x.size()[-2:]
        # Only the spatial size of each encoder level is needed, not the tensor.
        level_sizes = []
        for down in self.downs:
            x = down(x)
            level_sizes.append(x.size()[-2:])
            x = self.pool(x)

        x = self.bottleneck(x)
        level_sizes = level_sizes[::-1]  # deepest level first

        for idx, up in enumerate(self.ups):
            x = up(x)
            target_size = level_sizes[idx]
            if x.size()[-2:] != target_size:
                x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)

        if x.size()[-2:] != original_size:
            x = F.interpolate(x, size=original_size)
        return self.final_conv(x)