# Source code for lisbet.modeling.backbones.tcn

"""TCN Backbone for LISBET.

Implements a Temporal Convolutional Network (TCN) with dilated convolutions and
residual connections.

Reference:
    Bai, S., Kolter, J. Z., & Koltun, V. (2018). An Empirical Evaluation of Generic
    Convolutional and Recurrent Networks for Sequence Modeling. arXiv:1803.01271.
"""

from typing import Any

import torch
from torch import nn

from lisbet.modeling.backbones.base import BackboneInterface


class Chomp1d(nn.Module):
    """Trim a fixed number of trailing steps from the last axis of a tensor.

    Intended to follow a ``Conv1d`` that pads both ends of the sequence:
    dropping the right-hand excess keeps the convolution causal.

    Parameters
    ----------
    chomp_size : int
        Number of trailing elements to remove. Values ``<= 0`` disable
        trimming and the input is returned unchanged.
    """

    def __init__(self, chomp_size: int):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        size = self.chomp_size
        if size <= 0:
            # Nothing to trim: hand back the very same tensor object.
            return x
        return x[..., :-size]
class TemporalBlock(nn.Module):
    """Residual TCN block built from two dilated, padded-then-chomped convolutions.

    Each of the two stages applies ``Conv1d -> Chomp1d -> ReLU -> Dropout``.
    The block input is added back through a shortcut (a 1x1 convolution when
    the channel counts differ, identity otherwise), and a final ReLU is applied
    to the sum.

    Parameters
    ----------
    in_channels : int
        Number of channels of the block input.
    out_channels : int
        Number of channels produced by both convolutions (and the block).
    kernel_size : int
        Kernel size of both convolutions.
    stride : int
        Stride of both convolutions. NOTE(review): the residual sum assumes the
        convolution path keeps the sequence length, so values other than 1
        look unsupported — confirm before using them.
    dilation : int
        Dilation of both convolutions.
    padding : int
        Zero padding applied by each convolution; the same amount is trimmed
        from the end of the sequence by the following ``Chomp1d``.
    dropout : float, optional
        Dropout probability applied after each stage. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        dilation: int,
        padding: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        def dilated_conv(n_in: int) -> nn.Conv1d:
            # Both stages share every hyper-parameter except the input width.
            return nn.Conv1d(
                n_in,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )

        # Submodule names and creation order are deliberately fixed: they
        # determine state_dict keys, parameter order and the order in which
        # the random weight initialisation consumes the RNG.
        self.conv1 = dilated_conv(in_channels)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = dilated_conv(out_channels)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # Shortcut path: project channels only when the widths disagree.
        if in_channels == out_channels:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        self.final_relu = nn.ReLU()

    def forward(self, x):
        stages = (
            (self.conv1, self.chomp1, self.relu1, self.dropout1),
            (self.conv2, self.chomp2, self.relu2, self.dropout2),
        )
        hidden = x
        for conv, chomp, activation, drop in stages:
            hidden = drop(activation(chomp(conv(hidden))))
        shortcut = self.downsample(x)
        return self.final_relu(hidden + shortcut)
class TCNBackbone(BackboneInterface):
    """
    Temporal Convolutional Network (TCN) backbone for sequence modeling.

    The network is a stack of ``num_layers`` :class:`TemporalBlock` modules.
    Block ``i`` uses dilation ``dilation_base ** i`` and padding
    ``(kernel_size - 1) * dilation``. Every block except the last outputs
    ``hidden_dim`` channels; the last block outputs ``embedding_dim`` channels.

    Parameters
    ----------
    feature_dim : int
        Dimension of the input features.
    embedding_dim : int
        Dimension of the output embeddings.
    hidden_dim : int
        Number of channels in the hidden layers.
    num_layers : int
        Number of temporal blocks (layers). Must be at least 1.
    kernel_size : int, optional
        Size of the convolutional kernel. Must be at least 1. Default: 3
    dilation_base : int, optional
        Base for the dilation factor. Must be at least 1. Default: 2
    dropout : float, optional
        Dropout rate. Default: 0.0

    Raises
    ------
    ValueError
        If ``num_layers``, ``kernel_size`` or ``dilation_base`` is smaller
        than 1.
    """

    def __init__(
        self,
        feature_dim: int,
        embedding_dim: int,
        hidden_dim: int,
        num_layers: int,
        kernel_size: int = 3,
        dilation_base: int = 2,
        dropout: float = 0.0,
    ):
        # Validate before building anything. With num_layers < 1 the network
        # would be an empty nn.Sequential and forward() would silently return
        # feature_dim channels instead of embedding_dim. A dilation_base < 1
        # would give dilations of 0 (or negative) for layers i >= 1.
        for name, value in (
            ("num_layers", num_layers),
            ("kernel_size", kernel_size),
            ("dilation_base", dilation_base),
        ):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        super().__init__()
        self.feature_dim = feature_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.dilation_base = dilation_base
        self.dropout = dropout

        layers = []
        in_channels = feature_dim
        for i in range(num_layers):
            # Only the last block projects to the requested embedding size.
            out_channels = embedding_dim if i == num_layers - 1 else hidden_dim
            dilation = dilation_base**i
            # Pad by the full dilated kernel extent; TemporalBlock trims the
            # same amount from the end so outputs keep the input length.
            padding = (kernel_size - 1) * dilation
            layers.append(
                TemporalBlock(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=padding,
                    dropout=dropout,
                )
            )
            in_channels = out_channels
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the TCN backbone.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, feature_dim).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        # Conv1d expects (batch, channels, seq).
        x = x.transpose(1, 2)
        out = self.network(x)
        # Convert back to (batch, seq, features).
        return out.transpose(1, 2)

    def get_config(self) -> dict[str, Any]:
        """Get the configuration dictionary for this backbone.

        Returns
        -------
        dict[str, Any]
            The constructor arguments, suitable for :meth:`from_config`.
        """
        return {
            "feature_dim": self.feature_dim,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_layers": self.num_layers,
            "kernel_size": self.kernel_size,
            "dilation_base": self.dilation_base,
            "dropout": self.dropout,
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "TCNBackbone":
        """Create a TCNBackbone instance from a configuration dictionary.

        Parameters
        ----------
        config : dict[str, Any]
            Must contain ``feature_dim``, ``embedding_dim``, ``hidden_dim`` and
            ``num_layers``. ``kernel_size``, ``dilation_base`` and ``dropout``
            are optional and fall back to 3, 2 and 0.0 respectively.

        Returns
        -------
        TCNBackbone
            A newly constructed backbone.

        Raises
        ------
        KeyError
            If a required key is missing from ``config``.
        ValueError
            If the configuration violates the constructor constraints.
        """
        return cls(
            feature_dim=config["feature_dim"],
            embedding_dim=config["embedding_dim"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
            kernel_size=config.get("kernel_size", 3),
            dilation_base=config.get("dilation_base", 2),
            dropout=config.get("dropout", 0.0),
        )