Source code for lisbet.modeling.heads.embedding
"""Embedding head for extracting behavior embeddings."""
from typing import Any
import torch
from torch import nn
[docs]
class EmbeddingHead(nn.Module):
    """Embedding head for extracting behavior embeddings.

    This head selects a specific token from the sequence (typically the last
    one) and returns it as the behavior embedding without any additional
    transformation. It registers no parameters or submodules.

    Parameters
    ----------
    output_token_idx : int
        Index of the token to use for embedding extraction (e.g., -1 for the
        last token).

    Attributes
    ----------
    output_token_idx : int
        Index of the token used for embedding extraction.
    """

    def __init__(self, output_token_idx: int) -> None:
        super().__init__()
        self.output_token_idx = output_token_idx

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the embedding head.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns
        -------
        torch.Tensor
            Embedding tensor of shape (batch_size, embedding_dim).
        """
        # Integer indexing along dim 1 drops the sequence axis and keeps
        # every other axis untouched.
        return x[:, self.output_token_idx]

    def extra_repr(self) -> str:
        """Return the text shown inside the parentheses of ``repr(self)``.

        Without this override the module prints as ``EmbeddingHead()``, which
        hides which token is selected when inspecting a model summary.

        Returns
        -------
        str
            The configured token index formatted as ``output_token_idx=<value>``.
        """
        return f"output_token_idx={self.output_token_idx}"

    def get_config(self) -> dict[str, Any]:
        """Get the configuration dictionary for this head.

        Returns
        -------
        dict[str, Any]
            Configuration dictionary containing all parameters needed
            to recreate this head instance.
        """
        return {"output_token_idx": self.output_token_idx}

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "EmbeddingHead":
        """Create an EmbeddingHead instance from a configuration dictionary.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing all parameters needed
            to create the head instance. Its entries are forwarded to the
            constructor as keyword arguments.

        Returns
        -------
        EmbeddingHead
            New EmbeddingHead instance created from the configuration.
        """
        return cls(**config)