From 601fa4658f884c1f49ad54d021ece5be45f07d84 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 26 Jan 2026 16:05:20 +0000 Subject: [PATCH 01/22] feat!(label attention): enable label attention - module and config created to do that - mainly attached the TextEmbedder (it aggregates the token embedding to produce a sentence embedding - instead of naive averaging) - rest of the code has been adapted, especially categorical var handling in TextClassificationModel --- .../model/components/__init__.py | 1 + .../model/components/text_embedder.py | 175 +++++++++++++++--- torchTextClassifiers/model/lightning.py | 1 + torchTextClassifiers/model/model.py | 30 ++- torchTextClassifiers/torchTextClassifiers.py | 8 +- 5 files changed, 179 insertions(+), 36 deletions(-) diff --git a/torchTextClassifiers/model/components/__init__.py b/torchTextClassifiers/model/components/__init__.py index b14af0e..5cad342 100644 --- a/torchTextClassifiers/model/components/__init__.py +++ b/torchTextClassifiers/model/components/__init__.py @@ -8,5 +8,6 @@ CategoricalVariableNet as CategoricalVariableNet, ) from .classification_head import ClassificationHead as ClassificationHead +from .text_embedder import LabelAttentionConfig as LabelAttentionConfig from .text_embedder import TextEmbedder as TextEmbedder from .text_embedder import TextEmbedderConfig as TextEmbedderConfig diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index b317c91..de7aff7 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -3,17 +3,26 @@ from typing import Optional import torch -from torch import nn +import torch.nn as nn +from torch.nn import functional as F from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm +@dataclass +class LabelAttentionConfig: + n_head: int + n_kv_head: int + num_classes: int + + @dataclass class TextEmbedderConfig: vocab_size: int embedding_dim: int padding_idx: int attention_config: Optional[AttentionConfig] = None + label_attention_config: Optional[LabelAttentionConfig] = None class TextEmbedder(nn.Module): @@ -26,8 +35,9 @@ def __init__(self, text_embedder_config: TextEmbedderConfig): if isinstance(self.attention_config, dict): self.attention_config = AttentionConfig(**self.attention_config) - if self.attention_config is not None: - self.attention_config.n_embd = text_embedder_config.embedding_dim + self.enable_label_attention = text_embedder_config.label_attention_config is not None + if self.enable_label_attention: + self.label_attention_module = LabelAttentionClassifier(self.config) self.vocab_size = text_embedder_config.vocab_size self.embedding_dim = text_embedder_config.embedding_dim @@ -40,6 +50,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig): ) if self.attention_config is not None: + self.attention_config.n_embd = text_embedder_config.embedding_dim self.transformer = nn.ModuleDict( { "h": nn.ModuleList( @@ -105,8 +116,23 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=1.0) - def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - """Converts input token IDs to their corresponding embeddings.""" + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + return_label_attention_matrix: bool = False, + ) -> torch.Tensor: + """Converts input token IDs to their 
corresponding embeddings. + + Args: + input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens + return_label_attention_matrix (bool): Whether to return the label attention matrix + Returns: + torch.Tensor: Text embeddings, shape (batch_size, embedding_dim) if self.enable_label_attention is False, else (batch_size, num_labels, embedding_dim) + torch.Tensor: Label attention matrix, shape (batch_size, num_labels, seq_len) if return_label_attention_matrix is True, else None. + Also None if label attention is disabled (even if return_label_attention_matrix is True) + """ encoded_text = input_ids # clearer name if encoded_text.dtype != torch.long: @@ -138,14 +164,25 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torc token_embeddings = norm(token_embeddings) - text_embedding = self._get_sentence_embedding( - token_embeddings=token_embeddings, attention_mask=attention_mask - ) + text_embedding, label_attention_matrix = self._get_sentence_embedding( + token_embeddings=token_embeddings, + attention_mask=attention_mask, + return_label_attention_matrix=return_label_attention_matrix, + ).values() - return text_embedding + if return_label_attention_matrix: + return ( + text_embedding, + label_attention_matrix, + ) # label_attention_matrix is None if label attention is disabled + else: + return text_embedding def _get_sentence_embedding( - self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor + self, + token_embeddings: torch.Tensor, + attention_mask: torch.Tensor, + return_label_attention_matrix: bool = False, ) -> torch.Tensor: """ Compute sentence embedding from embedded tokens - "remove" second dimension. 
@@ -163,7 +200,7 @@ def _get_sentence_embedding( # mask pad-tokens if self.attention_config is not None: - if self.attention_config.aggregation_method is not None: + if self.attention_config.aggregation_method is not None: # default is "mean" if self.attention_config.aggregation_method == "first": return token_embeddings[:, 0, :] elif self.attention_config.aggregation_method == "last": @@ -181,25 +218,29 @@ def _get_sentence_embedding( assert self.attention_config is None or self.attention_config.aggregation_method == "mean" - mask = attention_mask.unsqueeze(-1).float() # (batch_size, seq_len, 1) - masked_embeddings = token_embeddings * mask # (batch_size, seq_len, embedding_dim) - - sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp( - min=1.0 - ) # avoid division by zero - - sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0) - - return sentence_embedding - - def __call__(self, *args, **kwargs): - out = super().__call__(*args, **kwargs) - if out.dim() != 2: - raise ValueError( - f"Output of {self.__class__.__name__}.forward must be 2D " - f"(got shape {tuple(out.shape)})" + if self.enable_label_attention: + label_attention_result = self.label_attention_module( + token_embeddings, compute_attention_matrix=return_label_attention_matrix ) - return out + sentence_embedding = label_attention_result[ + "sentence_embedding" + ] # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix + label_attention_matrix = label_attention_result["attention_matrix"] + + else: # sentence embedding = mean of (non-pad) token embeddings + mask = attention_mask.unsqueeze(-1).float() # (batch_size, seq_len, 1) + masked_embeddings = token_embeddings * mask # (batch_size, seq_len, embedding_dim) + sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp( + min=1.0 + ) # avoid division by zero + + sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0) + label_attention_matrix = None + + return { + "sentence_embedding": sentence_embedding, + "label_attention_matrix": label_attention_matrix, + } def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # autodetect the device from model embeddings @@ -221,3 +262,79 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No ) # add batch and head dims for later broadcasting return cos, sin + + +class LabelAttentionClassifier(nn.Module): + """ + A head for aggregating token embeddings into label-specific sentence embeddings using cross-attention mechanism. + Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings. 
+ + """ + + def __init__(self, config: TextEmbedderConfig): + super().__init__() + + label_attention_config = config.label_attention_config + self.embedding_dim = config.embedding_dim + self.num_classes = label_attention_config.num_classes + self.n_head = label_attention_config.n_head + self.n_kv_head = label_attention_config.n_kv_head + self.enable_gqa = ( + self.n_head != self.n_kv_head + ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired + self.head_dim = self.embedding_dim // self.n_head + + self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim) + + self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False) + self.c_k = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False) + self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False) + self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False) + + def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = False): + """ + Args: + token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input. + compute_attention_matrix (bool): Whether to compute and return the attention matrix. + Returns: + dict: { + "sentence_embedding": torch.Tensor, shape (batch, num_classes, d_model): Label-specific sentence embeddings. + "attention_matrix": Optional[torch.Tensor], shape (batch, n_head, num_classes, seq_len): Attention weights if compute_attention_matrix is True, else None. + } + + """ + B, T, C = token_embeddings.size() + + # 1. Create label indices [0, 1, ..., C-1] for the whole batch + label_indices = torch.arange(self.num_classes).expand(B, -1) + + all_label_embeddings = self.label_embeds( + label_indices + ) # Shape: [batch, num_classes, d_model] + all_label_embeddings = norm(all_label_embeddings) + + q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim) + k = self.c_k(token_embeddings).view(B, T, self.n_kv_head, self.head_dim) + v = self.c_v(token_embeddings).view(B, T, self.n_kv_head, self.head_dim) + + q, k = norm(q), norm(k) # QK norm + q, k, v = ( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) # make head be batch dim, i.e. 
(B, T, H, D) -> (B, H, T, D) + + y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa) + + # Re-assemble the heads side by side and project back to residual stream + y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1) # (bs, n_labels, d_model) + y = self.c_proj(y) + + attention_matrix = None + if compute_attention_matrix: + # size (B, n_head, n_labels, seq_len) - we let the user handle aggregation over heads if desired + attention_matrix = torch.softmax( + torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5), dim=-1 + ) + + return {"sentence_embedding": y, "attention_matrix": attention_matrix} diff --git a/torchTextClassifiers/model/lightning.py b/torchTextClassifiers/model/lightning.py index 1ebc697..8726f20 100644 --- a/torchTextClassifiers/model/lightning.py +++ b/torchTextClassifiers/model/lightning.py @@ -102,6 +102,7 @@ def validation_step(self, batch, batch_idx: int): targets = batch["labels"] outputs = self.forward(batch) + loss = self.loss(outputs, targets) self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True) diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index d9cffbf..8bd4b82 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -1,7 +1,7 @@ -"""FastText model components. +"""TextClassification model components. This module contains the PyTorch model, Lightning module, and dataset classes -for FastText classification. Consolidates what was previously in pytorch_model.py, +for TextClassification classification. Consolidates what was previously in pytorch_model.py, lightning_module.py, and dataset.py. """ @@ -17,6 +17,7 @@ ClassificationHead, TextEmbedder, ) +from torchTextClassifiers.model.components.attention import norm logger = logging.getLogger(__name__) @@ -67,8 +68,6 @@ def __init__( self._validate_component_connections() - self.num_classes = self.classification_head.num_classes - torch.nn.init.zeros_(self.classification_head.net.weight) if self.text_embedder is not None: self.text_embedder.init_weights() @@ -98,6 +97,17 @@ def _check_text_categorical_connection(self, text_embedder, cat_var_net): raise ValueError( "Classification head input dimension does not match expected dimension from text embedder and categorical variable net." ) + if self.text_embedder.enable_label_attention: + self.enable_label_attention = True + if self.classification_head.num_classes != 1: + raise ValueError( + "Label attention is enabled. TextEmbedder outputs a (num_classes, embedding_dim) tensor, so the ClassificationHead should have an output dimension of 1." + ) + # if enable_label_attention is True, label_attention_config exists - and contains num_classes necessarily + self.num_classes = self.text_embedder.config.label_attention_config.num_classes + else: + self.enable_label_attention = False + self.num_classes = self.classification_head.num_classes else: logger.warning( "⚠️ No text embedder provided; assuming input text is already embedded or vectorized. Take care that the classification head input dimension matches the input text dimension." 
@@ -131,21 +141,29 @@ def forward( if self.categorical_variable_net: x_cat = self.categorical_variable_net(categorical_vars) + if self.enable_label_attention: + # x_text is (batch_size, num_classes, embedding_dim) + # x_cat is (batch_size, cat_embedding_dim) + # We need to expand x_cat to (batch_size, num_classes, cat_embedding_dim) + # x_cat will be appended to x_text along the last dimension for each class + x_cat = x_cat.unsqueeze(1).expand(-1, self.num_classes, -1) + if ( self.categorical_variable_net.forward_type == CategoricalForwardType.AVERAGE_AND_CONCAT or self.categorical_variable_net.forward_type == CategoricalForwardType.CONCATENATE_ALL ): - x_combined = torch.cat((x_text, x_cat), dim=1) + x_combined = torch.cat((x_text, x_cat), dim=-1) else: assert ( self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT ) + x_combined = x_text + x_cat else: x_combined = x_text - logits = self.classification_head(x_combined) + logits = self.classification_head(norm(x_combined)).squeeze(-1) return logits diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 79ce301..4955cf1 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -29,6 +29,7 @@ CategoricalForwardType, CategoricalVariableNet, ClassificationHead, + LabelAttentionConfig, TextEmbedder, TextEmbedderConfig, ) @@ -53,6 +54,7 @@ class ModelConfig: categorical_embedding_dims: Optional[Union[List[int], int]] = None num_classes: Optional[int] = None attention_config: Optional[AttentionConfig] = None + label_attention_config: Optional[LabelAttentionConfig] = None def to_dict(self) -> Dict[str, Any]: return asdict(self) @@ -140,6 +142,7 @@ def __init__( self.embedding_dim = model_config.embedding_dim self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes self.num_classes = model_config.num_classes + self.enable_label_attention = model_config.label_attention_config is not None if self.tokenizer.output_vectorized: self.text_embedder = None @@ -153,6 +156,7 @@ def __init__( embedding_dim=self.embedding_dim, padding_idx=tokenizer.padding_idx, attention_config=model_config.attention_config, + label_attention_config=model_config.label_attention_config, ) self.text_embedder = TextEmbedder( text_embedder_config=text_embedder_config, @@ -174,7 +178,9 @@ def __init__( self.classification_head = ClassificationHead( input_dim=classif_head_input_dim, - num_classes=model_config.num_classes, + num_classes=1 + if self.enable_label_attention + else model_config.num_classes, # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim) ) self.pytorch_model = TextClassificationModel( From 30ef8af101bd6bb6b7f255c77ea541c06ef45deb Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 26 Jan 2026 16:24:33 +0000 Subject: [PATCH 02/22] fix(load): restore LabelAttentionConfig object from loaded dict used as a namespace after, so no converting it throws a bug --- torchTextClassifiers/torchTextClassifiers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 4955cf1..e1c6b29 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -671,6 +671,10 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif # Reconstruct model_config model_config = ModelConfig.from_dict(metadata["model_config"]) 
+ if type(model_config.label_attention_config) is dict: + model_config.label_attention_config = LabelAttentionConfig( + **model_config.label_attention_config + ) # Create instance instance = cls( From 0a1880b5bf61e09d5551672fcb596b11be72bfda Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 26 Jan 2026 16:25:06 +0000 Subject: [PATCH 03/22] test(label attention): add a test_pipeline with label attention activated --- tests/test_pipeline.py | 53 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 56dff6c..744731e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -9,6 +9,7 @@ AttentionConfig, CategoricalVariableNet, ClassificationHead, + LabelAttentionConfig, TextEmbedder, TextEmbedderConfig, ) @@ -51,7 +52,14 @@ def model_params(): } -def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params): +def run_full_pipeline( + tokenizer, + sample_text_data, + categorical_data, + labels, + model_params, + label_attention_enabled: bool = False, +): """Helper function to run the complete pipeline for a given tokenizer.""" # Create dataset dataset = TextClassificationDataset( @@ -83,6 +91,15 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod embedding_dim=model_params["embedding_dim"], padding_idx=padding_idx, attention_config=attention_config, + label_attention_config=( + LabelAttentionConfig( + n_head=attention_config.n_head, + n_kv_head=attention_config.n_kv_head, + num_classes=model_params["num_classes"], + ) + if label_attention_enabled + else None + ), ) text_embedder = TextEmbedder(text_embedder_config=text_embedder_config) @@ -98,7 +115,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim classification_head = ClassificationHead( input_dim=expected_input_dim, - num_classes=model_params["num_classes"], + num_classes=model_params["num_classes"] if not label_attention_enabled else 1, ) # Create model @@ -136,6 +153,15 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod categorical_embedding_dims=model_params["categorical_embedding_dims"], num_classes=model_params["num_classes"], attention_config=attention_config, + label_attention_config=( + LabelAttentionConfig( + n_head=attention_config.n_head, + n_kv_head=attention_config.n_kv_head, + num_classes=model_params["num_classes"], + ) + if label_attention_enabled + else None + ), ) # Create training config @@ -239,3 +265,26 @@ def test_ngram_tokenizer(sample_data, model_params): # Run full pipeline run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params) + + +def test_label_attention_enabled(sample_data, model_params): + """Test the full pipeline with label attention enabled (using WordPieceTokenizer).""" + sample_text_data, categorical_data, labels = sample_data + + vocab_size = 100 + tokenizer = WordPieceTokenizer(vocab_size, output_dim=50) + tokenizer.train(sample_text_data) + + # Check tokenizer works + result = tokenizer.tokenize(sample_text_data) + assert result.input_ids.shape[0] == len(sample_text_data) + + # Run full pipeline with label attention enabled + run_full_pipeline( + tokenizer, + sample_text_data, + categorical_data, + labels, + model_params, + label_attention_enabled=True, + ) From a2fe33e1a8be02d736d773dfcace34ca0c0ae03b Mon Sep 17 00:00:00 2001 From: 
meilame-tayebjee Date: Mon, 26 Jan 2026 17:15:47 +0000 Subject: [PATCH 04/22] feat(explainability): add new expl. pipe. with label attention - given a parameter, retrieve the attention matrix - compatible with captum attributions - update tests accordingly --- tests/test_pipeline.py | 10 ++- .../model/components/text_embedder.py | 14 ++-- torchTextClassifiers/model/model.py | 15 +++- torchTextClassifiers/torchTextClassifiers.py | 70 +++++++++++++------ 4 files changed, 77 insertions(+), 32 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 744731e..765e289 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -189,13 +189,19 @@ def run_full_pipeline( # Predict with explanations top_k = 5 - predictions = ttc.predict(X, top_k=top_k, explain=True) + + predictions = ttc.predict( + X, + top_k=top_k, + explain_with_label_attention=label_attention_enabled, + explain_with_captum=True, + ) # Test explainability functions text_idx = 0 text = sample_text_data[text_idx] offsets = predictions["offset_mapping"][text_idx] - attributions = predictions["attributions"][text_idx] + attributions = predictions["captum_attributions"][text_idx] word_ids = predictions["word_ids"][text_idx] words, word_attributions = map_attributions_to_word(attributions, text, word_ids, offsets) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index de7aff7..e406192 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -170,13 +170,10 @@ def forward( return_label_attention_matrix=return_label_attention_matrix, ).values() - if return_label_attention_matrix: - return ( - text_embedding, - label_attention_matrix, - ) # label_attention_matrix is None if label attention is disabled - else: - return text_embedding + return { + "sentence_embedding": text_embedding, + "label_attention_matrix": label_attention_matrix, + } def _get_sentence_embedding( self, @@ -304,6 +301,9 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F """ B, T, C = token_embeddings.size() + if isinstance(compute_attention_matrix, torch.Tensor): + compute_attention_matrix = compute_attention_matrix[0].item() + compute_attention_matrix = bool(compute_attention_matrix) # 1. 
Create label indices [0, 1, ..., C-1] for the whole batch label_indices = torch.arange(self.num_classes).expand(B, -1) diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index 8bd4b82..e2a2880 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -118,6 +118,7 @@ def forward( input_ids: Annotated[torch.Tensor, "batch seq_len"], attention_mask: Annotated[torch.Tensor, "batch seq_len"], categorical_vars: Annotated[torch.Tensor, "batch num_cats"], + return_label_attention_matrix: bool = False, **kwargs, ) -> torch.Tensor: """ @@ -136,7 +137,16 @@ def forward( if self.text_embedder is None: x_text = encoded_text.float() else: - x_text = self.text_embedder(input_ids=encoded_text, attention_mask=attention_mask) + text_embed_output = self.text_embedder( + input_ids=encoded_text, + attention_mask=attention_mask, + return_label_attention_matrix=return_label_attention_matrix, + ) + x_text = text_embed_output["sentence_embedding"] + if isinstance(return_label_attention_matrix, torch.Tensor): + return_label_attention_matrix = return_label_attention_matrix[0].item() + if return_label_attention_matrix: + label_attention_matrix = text_embed_output["label_attention_matrix"] if self.categorical_variable_net: x_cat = self.categorical_variable_net(categorical_vars) @@ -166,4 +176,7 @@ def forward( logits = self.classification_head(norm(x_combined)).squeeze(-1) + if return_label_attention_matrix: + return {"logits": logits, "label_attention_matrix": label_attention_matrix} + return logits diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index e1c6b29..8b3d0cd 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -492,13 +492,15 @@ def predict( self, X_test: np.ndarray, top_k=1, - explain=False, + explain_with_label_attention: bool = False, + explain_with_captum=False, ): """ Args: X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables top_k (int): for each sentence, return the top_k most likely predictions (default: 1) - explain (bool): launch gradient integration to have an explanation of the prediction (default: False) + explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False) + explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False) Returns: A dictionary containing the following fields: - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query. @@ -507,6 +509,7 @@ def predict( - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text. """ + explain = explain_with_label_attention or explain_with_captum if explain: return_offsets_mapping = True # to be passed to the tokenizer return_word_ids = True @@ -515,13 +518,19 @@ def predict( "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs." ) else: - if not HAS_CAPTUM: - raise ImportError( - "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." 
- ) - lig = LayerIntegratedGradients( - self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer - ) # initialize a Captum layer gradient integrator + if explain_with_captum: + if not HAS_CAPTUM: + raise ImportError( + "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." + ) + lig = LayerIntegratedGradients( + self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer + ) # initialize a Captum layer gradient integrator + if explain_with_label_attention: + if not self.enable_label_attention: + raise RuntimeError( + "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain." + ) else: return_offsets_mapping = False return_word_ids = False @@ -553,9 +562,19 @@ def predict( else: categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32) - pred = self.pytorch_model( - encoded_text, attention_mask, categorical_vars + model_output = self.pytorch_model( + encoded_text, + attention_mask, + categorical_vars, + return_label_attention_matrix=explain_with_label_attention, ) # forward pass, contains the prediction scores (len(text), num_classes) + pred = ( + model_output["logits"] if explain_with_label_attention else model_output + ) # (batch_size, num_classes) + + label_attention_matrix = ( + model_output["label_attention_matrix"] if explain_with_label_attention else None + ) label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities @@ -565,21 +584,28 @@ def predict( confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores if explain: - all_attributions = [] - for k in range(top_k): - attributions = lig.attribute( - (encoded_text, attention_mask, categorical_vars), - target=torch.Tensor(predictions[:, k]).long(), - ) # (batch_size, seq_len) - attributions = attributions.sum(dim=-1) - all_attributions.append(attributions.detach().cpu()) - - all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len) + if explain_with_captum: + # Captum explanations + captum_attributions = [] + for k in range(top_k): + attributions = lig.attribute( + (encoded_text, attention_mask, categorical_vars), + target=torch.Tensor(predictions[:, k]).long(), + ) # (batch_size, seq_len) + attributions = attributions.sum(dim=-1) + captum_attributions.append(attributions.detach().cpu()) + + captum_attributions = torch.stack( + captum_attributions, dim=1 + ) # (batch_size, top_k, seq_len) + else: + captum_attributions = None return { "prediction": predictions, "confidence": confidence, - "attributions": all_attributions, + "captum_attributions": captum_attributions, + "label_attention_attributions": label_attention_matrix, "offset_mapping": tokenize_output.offset_mapping, "word_ids": tokenize_output.word_ids, } From 3d034f43445ba15228a7755b3ee92609471d11ee Mon Sep 17 00:00:00 2001 From: Meilame Tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> Date: Mon, 26 Jan 2026 18:24:57 +0100 Subject: [PATCH 05/22] fix typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- torchTextClassifiers/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index e2a2880..d253a66 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -1,7 +1,7 @@ """TextClassification model 
components. This module contains the PyTorch model, Lightning module, and dataset classes -for TextClassification classification. Consolidates what was previously in pytorch_model.py, +for text classification. Consolidates what was previously in pytorch_model.py, lightning_module.py, and dataset.py. """ From ec6742cb4bf4300ab5d76233f8ea3f35f81adfd8 Mon Sep 17 00:00:00 2001 From: Meilame Tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> Date: Mon, 26 Jan 2026 18:27:25 +0100 Subject: [PATCH 06/22] fix docstring in TextEmbedder forward Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../model/components/text_embedder.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index e406192..6875112 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -121,17 +121,25 @@ def forward( input_ids: torch.Tensor, attention_mask: torch.Tensor, return_label_attention_matrix: bool = False, - ) -> torch.Tensor: + ) -> dict[str, Optional[torch.Tensor]]: """Converts input token IDs to their corresponding embeddings. Args: input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens - return_label_attention_matrix (bool): Whether to return the label attention matrix + return_label_attention_matrix (bool): Whether to return the label attention matrix. + Returns: - torch.Tensor: Text embeddings, shape (batch_size, embedding_dim) if self.enable_label_attention is False, else (batch_size, num_labels, embedding_dim) - torch.Tensor: Label attention matrix, shape (batch_size, num_labels, seq_len) if return_label_attention_matrix is True, else None. - Also None if label attention is disabled (even if return_label_attention_matrix is True) + dict: A dictionary with the following keys: + + - "sentence_embedding" (torch.Tensor): Text embeddings of shape + (batch_size, embedding_dim) if ``self.enable_label_attention`` is False, + else (batch_size, num_labels, embedding_dim). + + - "label_attention_matrix" (Optional[torch.Tensor]): Label attention + matrix of shape (batch_size, num_labels, seq_len) if + ``return_label_attention_matrix`` is True and label attention is + enabled, otherwise ``None``. 
""" encoded_text = input_ids # clearer name From f991b6b6bf393fe598bab6d0ac8ccc00f2a62622 Mon Sep 17 00:00:00 2001 From: Meilame Tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> Date: Mon, 26 Jan 2026 18:31:04 +0100 Subject: [PATCH 07/22] fix: convert to LabelAttentionConfig object when dict Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- torchTextClassifiers/model/components/text_embedder.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 6875112..3184cc9 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -35,7 +35,15 @@ def __init__(self, text_embedder_config: TextEmbedderConfig): if isinstance(self.attention_config, dict): self.attention_config = AttentionConfig(**self.attention_config) - self.enable_label_attention = text_embedder_config.label_attention_config is not None + # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig + self.label_attention_config = text_embedder_config.label_attention_config + if isinstance(self.label_attention_config, dict): + self.label_attention_config = LabelAttentionConfig(**self.label_attention_config) + # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier) + # always see a LabelAttentionConfig instance rather than a raw dict. + self.config.label_attention_config = self.label_attention_config + + self.enable_label_attention = self.label_attention_config is not None if self.enable_label_attention: self.label_attention_module = LabelAttentionClassifier(self.config) From 525b48276f73357740fbd902c6554ecb755ee804 Mon Sep 17 00:00:00 2001 From: Meilame Tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> Date: Mon, 26 Jan 2026 18:32:14 +0100 Subject: [PATCH 08/22] chore: replace type checking with isinstance Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- torchTextClassifiers/torchTextClassifiers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 8b3d0cd..a4f2c55 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -697,7 +697,7 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif # Reconstruct model_config model_config = ModelConfig.from_dict(metadata["model_config"]) - if type(model_config.label_attention_config) is dict: + if isinstance(model_config.label_attention_config, dict): model_config.label_attention_config = LabelAttentionConfig( **model_config.label_attention_config ) From 2374df8c47d054218acdc8496b6a4f7830776e8c Mon Sep 17 00:00:00 2001 From: Meilame Tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> Date: Mon, 26 Jan 2026 18:33:28 +0100 Subject: [PATCH 09/22] chore: better dict handling in TextEmbedder forward output Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- torchTextClassifiers/model/components/text_embedder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 3184cc9..7bd030c 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ 
b/torchTextClassifiers/model/components/text_embedder.py @@ -180,12 +180,14 @@ def forward( token_embeddings = norm(token_embeddings) - text_embedding, label_attention_matrix = self._get_sentence_embedding( + out = self._get_sentence_embedding( token_embeddings=token_embeddings, attention_mask=attention_mask, return_label_attention_matrix=return_label_attention_matrix, - ).values() + ) + text_embedding = out["sentence_embedding"] + label_attention_matrix = out["label_attention_matrix"] return { "sentence_embedding": text_embedding, "label_attention_matrix": label_attention_matrix, From d266572192599127947fa5d7fb86486f6860e268 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:40:04 +0000 Subject: [PATCH 10/22] fix: ensure label_indices uses correct device and dtype Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- .gitignore | 1 + torchTextClassifiers/model/components/text_embedder.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e65366e..586ae23 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,4 @@ example_files/ _site/ .quarto/ **/*.quarto_ipynb +my_ttc/ diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 7bd030c..b5ce92e 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -324,7 +324,9 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F compute_attention_matrix = bool(compute_attention_matrix) # 1. Create label indices [0, 1, ..., C-1] for the whole batch - label_indices = torch.arange(self.num_classes).expand(B, -1) + label_indices = torch.arange( + self.num_classes, dtype=torch.long, device=token_embeddings.device + ).expand(B, -1) all_label_embeddings = self.label_embeds( label_indices From 4aada37a01ce16f6bfbc6abdfef36f28515f4591 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:34:49 +0000 Subject: [PATCH 11/22] Add validation for LabelAttentionClassifier head configuration Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- .../model/components/text_embedder.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index b5ce92e..125e6a6 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -297,6 +297,20 @@ def __init__(self, config: TextEmbedderConfig): self.enable_gqa = ( self.n_head != self.n_kv_head ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired + + # Validate head configuration + if self.embedding_dim % self.n_head != 0: + raise ValueError( + f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). " + f"Got head_dim = {self.embedding_dim / self.n_head}" + ) + + if self.n_head % self.n_kv_head != 0: + raise ValueError( + f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. 
" + f"Got n_head / n_kv_head = {self.n_head / self.n_kv_head}" + ) + self.head_dim = self.embedding_dim // self.n_head self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim) From d516e6b24fec433fb360c0a25ad7af3ad928f57f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:36:05 +0000 Subject: [PATCH 12/22] Improve validation to follow TextEmbedder pattern and clarify error messages Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- torchTextClassifiers/model/components/text_embedder.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 125e6a6..ddad80a 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -299,19 +299,19 @@ def __init__(self, config: TextEmbedderConfig): ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired # Validate head configuration - if self.embedding_dim % self.n_head != 0: + self.head_dim = self.embedding_dim // self.n_head + + if self.head_dim * self.n_head != self.embedding_dim: raise ValueError( f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). " - f"Got head_dim = {self.embedding_dim / self.n_head}" + f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}" ) if self.n_head % self.n_kv_head != 0: raise ValueError( f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. " - f"Got n_head / n_kv_head = {self.n_head / self.n_kv_head}" + f"Got remainder {self.n_head % self.n_kv_head}" ) - - self.head_dim = self.embedding_dim // self.n_head self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim) From 87d672ff27ec4a180c6389958e8374a57a5a2760 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 09:50:02 +0000 Subject: [PATCH 13/22] Apply attention mask in LabelAttentionClassifier cross-attention Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- .../model/components/text_embedder.py | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index ddad80a..da21c30 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -235,7 +235,9 @@ def _get_sentence_embedding( if self.enable_label_attention: label_attention_result = self.label_attention_module( - token_embeddings, compute_attention_matrix=return_label_attention_matrix + token_embeddings, + attention_mask=attention_mask, + compute_attention_matrix=return_label_attention_matrix, ) sentence_embedding = label_attention_result[ "sentence_embedding" @@ -320,10 +322,11 @@ def __init__(self, config: TextEmbedderConfig): self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False) - def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = False): + def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = None, compute_attention_matrix: Optional[bool] = False): """ 
Args: token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input. + attention_mask (torch.Tensor, optional), shape (batch, seq_len): Attention mask indicating non-pad tokens (1 for real tokens, 0 for padding). compute_attention_matrix (bool): Whether to compute and return the attention matrix. Returns: dict: { @@ -358,7 +361,18 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F v.transpose(1, 2), ) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D) - y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa) + # Prepare attention mask for scaled_dot_product_attention + # attention_mask: (B, T) with 1 for real tokens, 0 for padding + # scaled_dot_product_attention expects attn_mask: (B, H, Q, K) or broadcastable shape + # where True means "mask out" (ignore), False means "attend to" + attn_mask = None + if attention_mask is not None: + # Convert: 0 (padding) -> True (mask out), 1 (real) -> False (attend to) + attn_mask = (attention_mask == 0) # (B, T) + # Expand to (B, 1, 1, T) for broadcasting across heads and queries + attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T) + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=self.enable_gqa) # Re-assemble the heads side by side and project back to residual stream y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1) # (bs, n_labels, d_model) @@ -366,9 +380,17 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F attention_matrix = None if compute_attention_matrix: - # size (B, n_head, n_labels, seq_len) - we let the user handle aggregation over heads if desired - attention_matrix = torch.softmax( - torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5), dim=-1 - ) + # Compute attention scores + # size (B, n_head, n_labels, seq_len) + attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5) + + # Apply mask to attention scores before softmax + if attention_mask is not None: + # attn_mask is already in the right shape: (B, 1, 1, T) + # We need to apply it to scores of shape (B, n_head, n_labels, T) + # Set masked positions to -inf so they become 0 after softmax + attention_scores = attention_scores.masked_fill(attn_mask, float('-inf')) + + attention_matrix = torch.softmax(attention_scores, dim=-1) return {"sentence_embedding": y, "attention_matrix": attention_matrix} From 7a988c3d4f56730d5a9d8bdccf1e60b1d63fdc2e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 09:51:20 +0000 Subject: [PATCH 14/22] Fix trailing whitespace in attention matrix computation Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- torchTextClassifiers/model/components/text_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index da21c30..627bbf0 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -383,14 +383,14 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non # Compute attention scores # size (B, n_head, n_labels, seq_len) attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5) - + # Apply mask to attention scores before softmax if attention_mask is 
not None: # attn_mask is already in the right shape: (B, 1, 1, T) # We need to apply it to scores of shape (B, n_head, n_labels, T) # Set masked positions to -inf so they become 0 after softmax attention_scores = attention_scores.masked_fill(attn_mask, float('-inf')) - + attention_matrix = torch.softmax(attention_scores, dim=-1) return {"sentence_embedding": y, "attention_matrix": attention_matrix} From 70b79c91c884b3a3807ec1edcf49b0effe870458 Mon Sep 17 00:00:00 2001 From: Meilame Tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> Date: Tue, 27 Jan 2026 11:14:11 +0100 Subject: [PATCH 15/22] doc: fix docstring on shape of label attention matrix Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- torchTextClassifiers/model/components/text_embedder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 627bbf0..4caadea 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -142,12 +142,14 @@ def forward( - "sentence_embedding" (torch.Tensor): Text embeddings of shape (batch_size, embedding_dim) if ``self.enable_label_attention`` is False, - else (batch_size, num_labels, embedding_dim). + else (batch_size, num_classes, embedding_dim), where ``num_classes`` + is the number of label classes. - "label_attention_matrix" (Optional[torch.Tensor]): Label attention - matrix of shape (batch_size, num_labels, seq_len) if + matrix of shape (batch_size, n_head, num_classes, seq_len) if ``return_label_attention_matrix`` is True and label attention is - enabled, otherwise ``None``. + enabled, otherwise ``None``. The dimensions correspond to + (batch_size, attention heads, label classes, sequence length). """ encoded_text = input_ids # clearer name From 0558f9716eb2995afa7a8833458a7334e7648c3d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:23:40 +0000 Subject: [PATCH 16/22] Add assertions for label attention attributions in tests Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- tests/test_pipeline.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 765e289..dabe76e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -197,6 +197,28 @@ def run_full_pipeline( explain_with_captum=True, ) + # Test label attention assertions + if label_attention_enabled: + assert predictions["label_attention_attributions"] is not None, ( + "Label attention attributions should not be None when label_attention_enabled is True" + ) + label_attention_attributions = predictions["label_attention_attributions"] + expected_shape = ( + len(sample_text_data), # batch_size + model_params["n_head"], # n_head + model_params["num_classes"], # num_classes + tokenizer.output_dim, # seq_len + ) + assert label_attention_attributions.shape == expected_shape, ( + f"Label attention attributions shape mismatch. 
" + f"Expected {expected_shape}, got {label_attention_attributions.shape}" + ) + else: + # When label attention is not enabled, the attributions should be None + assert predictions.get("label_attention_attributions") is None, ( + "Label attention attributions should be None when label_attention_enabled is False" + ) + # Test explainability functions text_idx = 0 text = sample_text_data[text_idx] From 44d9345ed09403b5d7335f55176cdbc4d5707ae6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:15:35 +0000 Subject: [PATCH 17/22] Update _get_sentence_embedding return type annotation and docstring Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- torchTextClassifiers/model/components/text_embedder.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 4caadea..7fd6b5b 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional import torch import torch.nn as nn @@ -200,15 +200,18 @@ def _get_sentence_embedding( token_embeddings: torch.Tensor, attention_mask: torch.Tensor, return_label_attention_matrix: bool = False, - ) -> torch.Tensor: + ) -> Dict[str, Optional[torch.Tensor]]: """ Compute sentence embedding from embedded tokens - "remove" second dimension. Args (output from dataset collate_fn): token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens + return_label_attention_matrix (bool): Whether to compute and return the label attention matrix Returns: - torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim) + Dict[str, Optional[torch.Tensor]]: A dictionary containing: + - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled + - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None """ # average over non-pad token embeddings From 86a0715b835e8211bdd7333252e4ef7f6510ba16 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:16:35 +0000 Subject: [PATCH 18/22] Fix early returns to match dictionary return type Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- .../model/components/text_embedder.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 7fd6b5b..1dfce87 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -129,7 +129,7 @@ def forward( input_ids: torch.Tensor, attention_mask: torch.Tensor, return_label_attention_matrix: bool = False, - ) -> dict[str, Optional[torch.Tensor]]: + ) -> Dict[str, Optional[torch.Tensor]]: """Converts input token IDs to their corresponding embeddings. 
Args: @@ -222,14 +222,20 @@ def _get_sentence_embedding( if self.attention_config is not None: if self.attention_config.aggregation_method is not None: # default is "mean" if self.attention_config.aggregation_method == "first": - return token_embeddings[:, 0, :] + return { + "sentence_embedding": token_embeddings[:, 0, :], + "label_attention_matrix": None, + } elif self.attention_config.aggregation_method == "last": lengths = attention_mask.sum(dim=1).clamp(min=1) # last non-pad token index + 1 - return token_embeddings[ - torch.arange(token_embeddings.size(0)), - lengths - 1, - :, - ] + return { + "sentence_embedding": token_embeddings[ + torch.arange(token_embeddings.size(0)), + lengths - 1, + :, + ], + "label_attention_matrix": None, + } else: if self.attention_config.aggregation_method != "mean": raise ValueError( From a721397ad1d654e4e0a043f047c9a9f9c44322ed Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:16:07 +0000 Subject: [PATCH 19/22] Initialize label_attention_matrix to None before text_embedder branch Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- torchTextClassifiers/model/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index d253a66..5e4cc66 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -134,6 +134,7 @@ def forward( Raw, not softmaxed. """ encoded_text = input_ids # clearer name + label_attention_matrix = None if self.text_embedder is None: x_text = encoded_text.float() else: From 798ff8c1dbea51614e9dc3f5cd30b63261233c93 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:12:54 +0000 Subject: [PATCH 20/22] Fix return type annotation for TextClassificationModel.forward Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- torchTextClassifiers/model/model.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index 5e4cc66..599d5f5 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -6,7 +6,7 @@ """ import logging -from typing import Annotated, Optional +from typing import Annotated, Optional, Union import torch from torch import nn @@ -120,7 +120,7 @@ def forward( categorical_vars: Annotated[torch.Tensor, "batch num_cats"], return_label_attention_matrix: bool = False, **kwargs, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, dict[str, torch.Tensor]]: """ Memory-efficient forward pass implementation. @@ -128,10 +128,15 @@ def forward( input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features) + return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix Returns: - torch.Tensor: Model output scores for each class - shape (batch_size, num_classes) - Raw, not softmaxed. 
+ Union[torch.Tensor, dict[str, torch.Tensor]]: + - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes) + containing raw logits (not softmaxed) + - If return_label_attention_matrix is True: dict with keys: + - "logits": torch.Tensor of shape (batch_size, num_classes) + - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len) """ encoded_text = input_ids # clearer name label_attention_matrix = None From 1bbff1593852a540d8e925fd88c9be81f9393a4d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:14:36 +0000 Subject: [PATCH 21/22] Address code review feedback: fix trailing whitespace and NameError Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- torchTextClassifiers/model/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index 599d5f5..3630e62 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -132,7 +132,7 @@ def forward( Returns: Union[torch.Tensor, dict[str, torch.Tensor]]: - - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes) + - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes) containing raw logits (not softmaxed) - If return_label_attention_matrix is True: dict with keys: - "logits": torch.Tensor of shape (batch_size, num_classes) @@ -142,6 +142,10 @@ def forward( label_attention_matrix = None if self.text_embedder is None: x_text = encoded_text.float() + if return_label_attention_matrix: + raise ValueError( + "return_label_attention_matrix=True requires a text_embedder with label attention enabled" + ) else: text_embed_output = self.text_embedder( input_ids=encoded_text, From 308fddcc0867ac1538c1e9f3b281247e6de1d5ba Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Tue, 27 Jan 2026 10:53:21 +0000 Subject: [PATCH 22/22] chore: disable gqa for label attention for simpl. 
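
- LabelAttentionConfig no longer takes n_kv_head: key/value projections now use
  the same n_head as the queries (plain multi-head attention instead of GQA)
- minimal sketch of the updated config as exposed via
  torchTextClassifiers.model.components (values below are illustrative only;
  embedding_dim must remain divisible by n_head):

      from torchTextClassifiers.model.components import LabelAttentionConfig

      # illustrative values, not defaults
      label_attention_config = LabelAttentionConfig(n_head=4, num_classes=10)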
--- tests/test_pipeline.py | 14 ++++---- .../model/components/text_embedder.py | 36 ++++++++----------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index dabe76e..dfeecd2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -94,7 +94,6 @@ def run_full_pipeline( label_attention_config=( LabelAttentionConfig( n_head=attention_config.n_head, - n_kv_head=attention_config.n_kv_head, num_classes=model_params["num_classes"], ) if label_attention_enabled @@ -156,7 +155,6 @@ def run_full_pipeline( label_attention_config=( LabelAttentionConfig( n_head=attention_config.n_head, - n_kv_head=attention_config.n_kv_head, num_classes=model_params["num_classes"], ) if label_attention_enabled @@ -199,9 +197,9 @@ def run_full_pipeline( # Test label attention assertions if label_attention_enabled: - assert predictions["label_attention_attributions"] is not None, ( - "Label attention attributions should not be None when label_attention_enabled is True" - ) + assert ( + predictions["label_attention_attributions"] is not None + ), "Label attention attributions should not be None when label_attention_enabled is True" label_attention_attributions = predictions["label_attention_attributions"] expected_shape = ( len(sample_text_data), # batch_size @@ -215,9 +213,9 @@ def run_full_pipeline( ) else: # When label attention is not enabled, the attributions should be None - assert predictions.get("label_attention_attributions") is None, ( - "Label attention attributions should be None when label_attention_enabled is False" - ) + assert ( + predictions.get("label_attention_attributions") is None + ), "Label attention attributions should be None when label_attention_enabled is False" # Test explainability functions text_idx = 0 diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 1dfce87..9d5aaa5 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -12,7 +12,6 @@ @dataclass class LabelAttentionConfig: n_head: int - n_kv_head: int num_classes: int @@ -306,34 +305,29 @@ def __init__(self, config: TextEmbedderConfig): self.embedding_dim = config.embedding_dim self.num_classes = label_attention_config.num_classes self.n_head = label_attention_config.n_head - self.n_kv_head = label_attention_config.n_kv_head - self.enable_gqa = ( - self.n_head != self.n_kv_head - ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired - + # Validate head configuration self.head_dim = self.embedding_dim // self.n_head - + if self.head_dim * self.n_head != self.embedding_dim: raise ValueError( f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). " f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}" ) - - if self.n_head % self.n_kv_head != 0: - raise ValueError( - f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. 
" - f"Got remainder {self.n_head % self.n_kv_head}" - ) self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim) self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False) - self.c_k = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False) - self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False) + self.c_k = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False) + self.c_v = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False) - def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = None, compute_attention_matrix: Optional[bool] = False): + def forward( + self, + token_embeddings, + attention_mask: Optional[torch.Tensor] = None, + compute_attention_matrix: Optional[bool] = False, + ): """ Args: token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input. @@ -362,8 +356,8 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non all_label_embeddings = norm(all_label_embeddings) q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim) - k = self.c_k(token_embeddings).view(B, T, self.n_kv_head, self.head_dim) - v = self.c_v(token_embeddings).view(B, T, self.n_kv_head, self.head_dim) + k = self.c_k(token_embeddings).view(B, T, self.n_head, self.head_dim) + v = self.c_v(token_embeddings).view(B, T, self.n_head, self.head_dim) q, k = norm(q), norm(k) # QK norm q, k, v = ( @@ -379,11 +373,11 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non attn_mask = None if attention_mask is not None: # Convert: 0 (padding) -> True (mask out), 1 (real) -> False (attend to) - attn_mask = (attention_mask == 0) # (B, T) + attn_mask = attention_mask == 0 # (B, T) # Expand to (B, 1, 1, T) for broadcasting across heads and queries attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=self.enable_gqa) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False) # Re-assemble the heads side by side and project back to residual stream y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1) # (bs, n_labels, d_model) @@ -400,7 +394,7 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non # attn_mask is already in the right shape: (B, 1, 1, T) # We need to apply it to scores of shape (B, n_head, n_labels, T) # Set masked positions to -inf so they become 0 after softmax - attention_scores = attention_scores.masked_fill(attn_mask, float('-inf')) + attention_scores = attention_scores.masked_fill(attn_mask, float("-inf")) attention_matrix = torch.softmax(attention_scores, dim=-1)