diff --git a/.gitignore b/.gitignore index e65366e..586ae23 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,4 @@ example_files/ _site/ .quarto/ **/*.quarto_ipynb +my_ttc/ diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 56dff6c..dfeecd2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -9,6 +9,7 @@ AttentionConfig, CategoricalVariableNet, ClassificationHead, + LabelAttentionConfig, TextEmbedder, TextEmbedderConfig, ) @@ -51,7 +52,14 @@ def model_params(): } -def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params): +def run_full_pipeline( + tokenizer, + sample_text_data, + categorical_data, + labels, + model_params, + label_attention_enabled: bool = False, +): """Helper function to run the complete pipeline for a given tokenizer.""" # Create dataset dataset = TextClassificationDataset( @@ -83,6 +91,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod embedding_dim=model_params["embedding_dim"], padding_idx=padding_idx, attention_config=attention_config, + label_attention_config=( + LabelAttentionConfig( + n_head=attention_config.n_head, + num_classes=model_params["num_classes"], + ) + if label_attention_enabled + else None + ), ) text_embedder = TextEmbedder(text_embedder_config=text_embedder_config) @@ -98,7 +114,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim classification_head = ClassificationHead( input_dim=expected_input_dim, - num_classes=model_params["num_classes"], + num_classes=model_params["num_classes"] if not label_attention_enabled else 1, ) # Create model @@ -136,6 +152,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod categorical_embedding_dims=model_params["categorical_embedding_dims"], num_classes=model_params["num_classes"], attention_config=attention_config, + label_attention_config=( + LabelAttentionConfig( + n_head=attention_config.n_head, + num_classes=model_params["num_classes"], + ) + if label_attention_enabled + else None + ), ) # Create training config @@ -163,13 +187,41 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod # Predict with explanations top_k = 5 - predictions = ttc.predict(X, top_k=top_k, explain=True) + + predictions = ttc.predict( + X, + top_k=top_k, + explain_with_label_attention=label_attention_enabled, + explain_with_captum=True, + ) + + # Test label attention assertions + if label_attention_enabled: + assert ( + predictions["label_attention_attributions"] is not None + ), "Label attention attributions should not be None when label_attention_enabled is True" + label_attention_attributions = predictions["label_attention_attributions"] + expected_shape = ( + len(sample_text_data), # batch_size + model_params["n_head"], # n_head + model_params["num_classes"], # num_classes + tokenizer.output_dim, # seq_len + ) + assert label_attention_attributions.shape == expected_shape, ( + f"Label attention attributions shape mismatch. 
" + f"Expected {expected_shape}, got {label_attention_attributions.shape}" + ) + else: + # When label attention is not enabled, the attributions should be None + assert ( + predictions.get("label_attention_attributions") is None + ), "Label attention attributions should be None when label_attention_enabled is False" # Test explainability functions text_idx = 0 text = sample_text_data[text_idx] offsets = predictions["offset_mapping"][text_idx] - attributions = predictions["attributions"][text_idx] + attributions = predictions["captum_attributions"][text_idx] word_ids = predictions["word_ids"][text_idx] words, word_attributions = map_attributions_to_word(attributions, text, word_ids, offsets) @@ -239,3 +291,26 @@ def test_ngram_tokenizer(sample_data, model_params): # Run full pipeline run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params) + + +def test_label_attention_enabled(sample_data, model_params): + """Test the full pipeline with label attention enabled (using WordPieceTokenizer).""" + sample_text_data, categorical_data, labels = sample_data + + vocab_size = 100 + tokenizer = WordPieceTokenizer(vocab_size, output_dim=50) + tokenizer.train(sample_text_data) + + # Check tokenizer works + result = tokenizer.tokenize(sample_text_data) + assert result.input_ids.shape[0] == len(sample_text_data) + + # Run full pipeline with label attention enabled + run_full_pipeline( + tokenizer, + sample_text_data, + categorical_data, + labels, + model_params, + label_attention_enabled=True, + ) diff --git a/torchTextClassifiers/model/components/__init__.py b/torchTextClassifiers/model/components/__init__.py index b14af0e..5cad342 100644 --- a/torchTextClassifiers/model/components/__init__.py +++ b/torchTextClassifiers/model/components/__init__.py @@ -8,5 +8,6 @@ CategoricalVariableNet as CategoricalVariableNet, ) from .classification_head import ClassificationHead as ClassificationHead +from .text_embedder import LabelAttentionConfig as LabelAttentionConfig from .text_embedder import TextEmbedder as TextEmbedder from .text_embedder import TextEmbedderConfig as TextEmbedderConfig diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index b317c91..9d5aaa5 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -1,19 +1,27 @@ import math from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional import torch -from torch import nn +import torch.nn as nn +from torch.nn import functional as F from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm +@dataclass +class LabelAttentionConfig: + n_head: int + num_classes: int + + @dataclass class TextEmbedderConfig: vocab_size: int embedding_dim: int padding_idx: int attention_config: Optional[AttentionConfig] = None + label_attention_config: Optional[LabelAttentionConfig] = None class TextEmbedder(nn.Module): @@ -26,8 +34,17 @@ def __init__(self, text_embedder_config: TextEmbedderConfig): if isinstance(self.attention_config, dict): self.attention_config = AttentionConfig(**self.attention_config) - if self.attention_config is not None: - self.attention_config.n_embd = text_embedder_config.embedding_dim + # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig + self.label_attention_config = text_embedder_config.label_attention_config + if isinstance(self.label_attention_config, dict): + 
self.label_attention_config = LabelAttentionConfig(**self.label_attention_config) + # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier) + # always see a LabelAttentionConfig instance rather than a raw dict. + self.config.label_attention_config = self.label_attention_config + + self.enable_label_attention = self.label_attention_config is not None + if self.enable_label_attention: + self.label_attention_module = LabelAttentionClassifier(self.config) self.vocab_size = text_embedder_config.vocab_size self.embedding_dim = text_embedder_config.embedding_dim @@ -40,6 +57,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig): ) if self.attention_config is not None: + self.attention_config.n_embd = text_embedder_config.embedding_dim self.transformer = nn.ModuleDict( { "h": nn.ModuleList( @@ -105,8 +123,33 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=1.0) - def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - """Converts input token IDs to their corresponding embeddings.""" + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + return_label_attention_matrix: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + """Converts input token IDs to their corresponding embeddings. + + Args: + input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens + return_label_attention_matrix (bool): Whether to return the label attention matrix. + + Returns: + dict: A dictionary with the following keys: + + - "sentence_embedding" (torch.Tensor): Text embeddings of shape + (batch_size, embedding_dim) if ``self.enable_label_attention`` is False, + else (batch_size, num_classes, embedding_dim), where ``num_classes`` + is the number of label classes. + + - "label_attention_matrix" (Optional[torch.Tensor]): Label attention + matrix of shape (batch_size, n_head, num_classes, seq_len) if + ``return_label_attention_matrix`` is True and label attention is + enabled, otherwise ``None``. The dimensions correspond to + (batch_size, attention heads, label classes, sequence length). + """ encoded_text = input_ids # clearer name if encoded_text.dtype != torch.long: @@ -138,23 +181,36 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torc token_embeddings = norm(token_embeddings) - text_embedding = self._get_sentence_embedding( - token_embeddings=token_embeddings, attention_mask=attention_mask + out = self._get_sentence_embedding( + token_embeddings=token_embeddings, + attention_mask=attention_mask, + return_label_attention_matrix=return_label_attention_matrix, ) - return text_embedding + text_embedding = out["sentence_embedding"] + label_attention_matrix = out["label_attention_matrix"] + return { + "sentence_embedding": text_embedding, + "label_attention_matrix": label_attention_matrix, + } def _get_sentence_embedding( - self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor - ) -> torch.Tensor: + self, + token_embeddings: torch.Tensor, + attention_mask: torch.Tensor, + return_label_attention_matrix: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: """ Compute sentence embedding from embedded tokens - "remove" second dimension. 
Args (output from dataset collate_fn): token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens + return_label_attention_matrix (bool): Whether to compute and return the label attention matrix Returns: - torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim) + Dict[str, Optional[torch.Tensor]]: A dictionary containing: + - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled + - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None """ # average over non-pad token embeddings @@ -163,16 +219,22 @@ def _get_sentence_embedding( # mask pad-tokens if self.attention_config is not None: - if self.attention_config.aggregation_method is not None: + if self.attention_config.aggregation_method is not None: # default is "mean" if self.attention_config.aggregation_method == "first": - return token_embeddings[:, 0, :] + return { + "sentence_embedding": token_embeddings[:, 0, :], + "label_attention_matrix": None, + } elif self.attention_config.aggregation_method == "last": lengths = attention_mask.sum(dim=1).clamp(min=1) # last non-pad token index + 1 - return token_embeddings[ - torch.arange(token_embeddings.size(0)), - lengths - 1, - :, - ] + return { + "sentence_embedding": token_embeddings[ + torch.arange(token_embeddings.size(0)), + lengths - 1, + :, + ], + "label_attention_matrix": None, + } else: if self.attention_config.aggregation_method != "mean": raise ValueError( @@ -181,25 +243,31 @@ def _get_sentence_embedding( assert self.attention_config is None or self.attention_config.aggregation_method == "mean" - mask = attention_mask.unsqueeze(-1).float() # (batch_size, seq_len, 1) - masked_embeddings = token_embeddings * mask # (batch_size, seq_len, embedding_dim) - - sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp( - min=1.0 - ) # avoid division by zero - - sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0) - - return sentence_embedding - - def __call__(self, *args, **kwargs): - out = super().__call__(*args, **kwargs) - if out.dim() != 2: - raise ValueError( - f"Output of {self.__class__.__name__}.forward must be 2D " - f"(got shape {tuple(out.shape)})" + if self.enable_label_attention: + label_attention_result = self.label_attention_module( + token_embeddings, + attention_mask=attention_mask, + compute_attention_matrix=return_label_attention_matrix, ) - return out + sentence_embedding = label_attention_result[ + "sentence_embedding" + ] # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix + label_attention_matrix = label_attention_result["attention_matrix"] + + else: # sentence embedding = mean of (non-pad) token embeddings + mask = attention_mask.unsqueeze(-1).float() # (batch_size, seq_len, 1) + masked_embeddings = token_embeddings * mask # (batch_size, seq_len, embedding_dim) + sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp( + min=1.0 + ) # avoid division by zero + + sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0) + label_attention_matrix = None + + return { + "sentence_embedding": sentence_embedding, + "label_attention_matrix": label_attention_matrix, + } def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # autodetect the 
device from model embeddings @@ -221,3 +289,113 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No ) # add batch and head dims for later broadcasting return cos, sin + + +class LabelAttentionClassifier(nn.Module): + """ + A head for aggregating token embeddings into label-specific sentence embeddings using cross-attention mechanism. + Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings. + + """ + + def __init__(self, config: TextEmbedderConfig): + super().__init__() + + label_attention_config = config.label_attention_config + self.embedding_dim = config.embedding_dim + self.num_classes = label_attention_config.num_classes + self.n_head = label_attention_config.n_head + + # Validate head configuration + self.head_dim = self.embedding_dim // self.n_head + + if self.head_dim * self.n_head != self.embedding_dim: + raise ValueError( + f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). " + f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}" + ) + + self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim) + + self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False) + self.c_k = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False) + self.c_v = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False) + self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False) + + def forward( + self, + token_embeddings, + attention_mask: Optional[torch.Tensor] = None, + compute_attention_matrix: Optional[bool] = False, + ): + """ + Args: + token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input. + attention_mask (torch.Tensor, optional), shape (batch, seq_len): Attention mask indicating non-pad tokens (1 for real tokens, 0 for padding). + compute_attention_matrix (bool): Whether to compute and return the attention matrix. + Returns: + dict: { + "sentence_embedding": torch.Tensor, shape (batch, num_classes, d_model): Label-specific sentence embeddings. + "attention_matrix": Optional[torch.Tensor], shape (batch, n_head, num_classes, seq_len): Attention weights if compute_attention_matrix is True, else None. + } + + """ + B, T, C = token_embeddings.size() + if isinstance(compute_attention_matrix, torch.Tensor): + compute_attention_matrix = compute_attention_matrix[0].item() + compute_attention_matrix = bool(compute_attention_matrix) + + # 1. Create label indices [0, 1, ..., C-1] for the whole batch + label_indices = torch.arange( + self.num_classes, dtype=torch.long, device=token_embeddings.device + ).expand(B, -1) + + all_label_embeddings = self.label_embeds( + label_indices + ) # Shape: [batch, num_classes, d_model] + all_label_embeddings = norm(all_label_embeddings) + + q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim) + k = self.c_k(token_embeddings).view(B, T, self.n_head, self.head_dim) + v = self.c_v(token_embeddings).view(B, T, self.n_head, self.head_dim) + + q, k = norm(q), norm(k) # QK norm + q, k, v = ( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) # make head be batch dim, i.e. 
(B, T, H, D) -> (B, H, T, D) + + # Prepare attention mask for scaled_dot_product_attention + # attention_mask: (B, T) with 1 for real tokens, 0 for padding + # scaled_dot_product_attention expects attn_mask: (B, H, Q, K) or broadcastable shape + # where True means "mask out" (ignore), False means "attend to" + attn_mask = None + if attention_mask is not None: + # Convert: 0 (padding) -> True (mask out), 1 (real) -> False (attend to) + attn_mask = attention_mask == 0 # (B, T) + # Expand to (B, 1, 1, T) for broadcasting across heads and queries + attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T) + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False) + + # Re-assemble the heads side by side and project back to residual stream + y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1) # (bs, n_labels, d_model) + y = self.c_proj(y) + + attention_matrix = None + if compute_attention_matrix: + # Compute attention scores + # size (B, n_head, n_labels, seq_len) + attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5) + + # Apply mask to attention scores before softmax + if attention_mask is not None: + # attn_mask is already in the right shape: (B, 1, 1, T) + # We need to apply it to scores of shape (B, n_head, n_labels, T) + # Set masked positions to -inf so they become 0 after softmax + attention_scores = attention_scores.masked_fill(attn_mask, float("-inf")) + + attention_matrix = torch.softmax(attention_scores, dim=-1) + + return {"sentence_embedding": y, "attention_matrix": attention_matrix} diff --git a/torchTextClassifiers/model/lightning.py b/torchTextClassifiers/model/lightning.py index 1ebc697..8726f20 100644 --- a/torchTextClassifiers/model/lightning.py +++ b/torchTextClassifiers/model/lightning.py @@ -102,6 +102,7 @@ def validation_step(self, batch, batch_idx: int): targets = batch["labels"] outputs = self.forward(batch) + loss = self.loss(outputs, targets) self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True) diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index d9cffbf..3630e62 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -1,12 +1,12 @@ -"""FastText model components. +"""TextClassification model components. This module contains the PyTorch model, Lightning module, and dataset classes -for FastText classification. Consolidates what was previously in pytorch_model.py, +for text classification. Consolidates what was previously in pytorch_model.py, lightning_module.py, and dataset.py. """ import logging -from typing import Annotated, Optional +from typing import Annotated, Optional, Union import torch from torch import nn @@ -17,6 +17,7 @@ ClassificationHead, TextEmbedder, ) +from torchTextClassifiers.model.components.attention import norm logger = logging.getLogger(__name__) @@ -67,8 +68,6 @@ def __init__( self._validate_component_connections() - self.num_classes = self.classification_head.num_classes - torch.nn.init.zeros_(self.classification_head.net.weight) if self.text_embedder is not None: self.text_embedder.init_weights() @@ -98,6 +97,17 @@ def _check_text_categorical_connection(self, text_embedder, cat_var_net): raise ValueError( "Classification head input dimension does not match expected dimension from text embedder and categorical variable net." 
) + if self.text_embedder.enable_label_attention: + self.enable_label_attention = True + if self.classification_head.num_classes != 1: + raise ValueError( + "Label attention is enabled. TextEmbedder outputs a (num_classes, embedding_dim) tensor, so the ClassificationHead should have an output dimension of 1." + ) + # if enable_label_attention is True, label_attention_config exists - and contains num_classes necessarily + self.num_classes = self.text_embedder.config.label_attention_config.num_classes + else: + self.enable_label_attention = False + self.num_classes = self.classification_head.num_classes else: logger.warning( "⚠️ No text embedder provided; assuming input text is already embedded or vectorized. Take care that the classification head input dimension matches the input text dimension." @@ -108,8 +118,9 @@ def forward( input_ids: Annotated[torch.Tensor, "batch seq_len"], attention_mask: Annotated[torch.Tensor, "batch seq_len"], categorical_vars: Annotated[torch.Tensor, "batch num_cats"], + return_label_attention_matrix: bool = False, **kwargs, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, dict[str, torch.Tensor]]: """ Memory-efficient forward pass implementation. @@ -117,35 +128,65 @@ def forward( input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features) + return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix Returns: - torch.Tensor: Model output scores for each class - shape (batch_size, num_classes) - Raw, not softmaxed. + Union[torch.Tensor, dict[str, torch.Tensor]]: + - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes) + containing raw logits (not softmaxed) + - If return_label_attention_matrix is True: dict with keys: + - "logits": torch.Tensor of shape (batch_size, num_classes) + - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len) """ encoded_text = input_ids # clearer name + label_attention_matrix = None if self.text_embedder is None: x_text = encoded_text.float() + if return_label_attention_matrix: + raise ValueError( + "return_label_attention_matrix=True requires a text_embedder with label attention enabled" + ) else: - x_text = self.text_embedder(input_ids=encoded_text, attention_mask=attention_mask) + text_embed_output = self.text_embedder( + input_ids=encoded_text, + attention_mask=attention_mask, + return_label_attention_matrix=return_label_attention_matrix, + ) + x_text = text_embed_output["sentence_embedding"] + if isinstance(return_label_attention_matrix, torch.Tensor): + return_label_attention_matrix = return_label_attention_matrix[0].item() + if return_label_attention_matrix: + label_attention_matrix = text_embed_output["label_attention_matrix"] if self.categorical_variable_net: x_cat = self.categorical_variable_net(categorical_vars) + if self.enable_label_attention: + # x_text is (batch_size, num_classes, embedding_dim) + # x_cat is (batch_size, cat_embedding_dim) + # We need to expand x_cat to (batch_size, num_classes, cat_embedding_dim) + # x_cat will be appended to x_text along the last dimension for each class + x_cat = x_cat.unsqueeze(1).expand(-1, self.num_classes, -1) + if ( self.categorical_variable_net.forward_type == CategoricalForwardType.AVERAGE_AND_CONCAT or 
self.categorical_variable_net.forward_type == CategoricalForwardType.CONCATENATE_ALL ): - x_combined = torch.cat((x_text, x_cat), dim=1) + x_combined = torch.cat((x_text, x_cat), dim=-1) else: assert ( self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT ) + x_combined = x_text + x_cat else: x_combined = x_text - logits = self.classification_head(x_combined) + logits = self.classification_head(norm(x_combined)).squeeze(-1) + + if return_label_attention_matrix: + return {"logits": logits, "label_attention_matrix": label_attention_matrix} return logits diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 79ce301..a4f2c55 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -29,6 +29,7 @@ CategoricalForwardType, CategoricalVariableNet, ClassificationHead, + LabelAttentionConfig, TextEmbedder, TextEmbedderConfig, ) @@ -53,6 +54,7 @@ class ModelConfig: categorical_embedding_dims: Optional[Union[List[int], int]] = None num_classes: Optional[int] = None attention_config: Optional[AttentionConfig] = None + label_attention_config: Optional[LabelAttentionConfig] = None def to_dict(self) -> Dict[str, Any]: return asdict(self) @@ -140,6 +142,7 @@ def __init__( self.embedding_dim = model_config.embedding_dim self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes self.num_classes = model_config.num_classes + self.enable_label_attention = model_config.label_attention_config is not None if self.tokenizer.output_vectorized: self.text_embedder = None @@ -153,6 +156,7 @@ def __init__( embedding_dim=self.embedding_dim, padding_idx=tokenizer.padding_idx, attention_config=model_config.attention_config, + label_attention_config=model_config.label_attention_config, ) self.text_embedder = TextEmbedder( text_embedder_config=text_embedder_config, @@ -174,7 +178,9 @@ def __init__( self.classification_head = ClassificationHead( input_dim=classif_head_input_dim, - num_classes=model_config.num_classes, + num_classes=1 + if self.enable_label_attention + else model_config.num_classes, # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim) ) self.pytorch_model = TextClassificationModel( @@ -486,13 +492,15 @@ def predict( self, X_test: np.ndarray, top_k=1, - explain=False, + explain_with_label_attention: bool = False, + explain_with_captum=False, ): """ Args: X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables top_k (int): for each sentence, return the top_k most likely predictions (default: 1) - explain (bool): launch gradient integration to have an explanation of the prediction (default: False) + explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False) + explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False) Returns: A dictionary containing the following fields: - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query. @@ -501,6 +509,7 @@ def predict( - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text. 
""" + explain = explain_with_label_attention or explain_with_captum if explain: return_offsets_mapping = True # to be passed to the tokenizer return_word_ids = True @@ -509,13 +518,19 @@ def predict( "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs." ) else: - if not HAS_CAPTUM: - raise ImportError( - "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." - ) - lig = LayerIntegratedGradients( - self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer - ) # initialize a Captum layer gradient integrator + if explain_with_captum: + if not HAS_CAPTUM: + raise ImportError( + "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." + ) + lig = LayerIntegratedGradients( + self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer + ) # initialize a Captum layer gradient integrator + if explain_with_label_attention: + if not self.enable_label_attention: + raise RuntimeError( + "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain." + ) else: return_offsets_mapping = False return_word_ids = False @@ -547,9 +562,19 @@ def predict( else: categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32) - pred = self.pytorch_model( - encoded_text, attention_mask, categorical_vars + model_output = self.pytorch_model( + encoded_text, + attention_mask, + categorical_vars, + return_label_attention_matrix=explain_with_label_attention, ) # forward pass, contains the prediction scores (len(text), num_classes) + pred = ( + model_output["logits"] if explain_with_label_attention else model_output + ) # (batch_size, num_classes) + + label_attention_matrix = ( + model_output["label_attention_matrix"] if explain_with_label_attention else None + ) label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities @@ -559,21 +584,28 @@ def predict( confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores if explain: - all_attributions = [] - for k in range(top_k): - attributions = lig.attribute( - (encoded_text, attention_mask, categorical_vars), - target=torch.Tensor(predictions[:, k]).long(), - ) # (batch_size, seq_len) - attributions = attributions.sum(dim=-1) - all_attributions.append(attributions.detach().cpu()) - - all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len) + if explain_with_captum: + # Captum explanations + captum_attributions = [] + for k in range(top_k): + attributions = lig.attribute( + (encoded_text, attention_mask, categorical_vars), + target=torch.Tensor(predictions[:, k]).long(), + ) # (batch_size, seq_len) + attributions = attributions.sum(dim=-1) + captum_attributions.append(attributions.detach().cpu()) + + captum_attributions = torch.stack( + captum_attributions, dim=1 + ) # (batch_size, top_k, seq_len) + else: + captum_attributions = None return { "prediction": predictions, "confidence": confidence, - "attributions": all_attributions, + "captum_attributions": captum_attributions, + "label_attention_attributions": label_attention_matrix, "offset_mapping": tokenize_output.offset_mapping, "word_ids": tokenize_output.word_ids, } @@ -665,6 +697,10 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif # Reconstruct 
model_config model_config = ModelConfig.from_dict(metadata["model_config"]) + if isinstance(model_config.label_attention_config, dict): + model_config.label_attention_config = LabelAttentionConfig( + **model_config.label_attention_config + ) # Create instance instance = cls(
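
Usage sketch (not part of the diff): the snippet below shows how the new label-attention path is wired together and what predict() returns, based only on the configs, keyword arguments and dictionary keys introduced above. Tokenizer setup, the AttentionConfig fields, the classifier constructor signature and the training call are not shown in this diff, so they are treated as assumptions and left commented out.

# Illustrative sketch for reviewers. Only LabelAttentionConfig, the ModelConfig field,
# the predict() keyword arguments and the returned keys/shapes are taken from this diff;
# everything commented out below is an assumption based on the existing test helpers.
from torchTextClassifiers.model.components import LabelAttentionConfig

num_classes = 10
n_head = 4  # must divide embedding_dim; LabelAttentionClassifier raises ValueError otherwise

label_attention_config = LabelAttentionConfig(n_head=n_head, num_classes=num_classes)

# model_config = ModelConfig(
#     embedding_dim=64,
#     num_classes=num_classes,
#     attention_config=attention_config,              # built as in the existing tests (assumption)
#     label_attention_config=label_attention_config,  # switches TextEmbedder to label attention
# )
# ttc = torchTextClassifiers(model_config, tokenizer)  # constructor signature: assumption
# ... train as in run_full_pipeline ...
#
# predictions = ttc.predict(
#     X_test,
#     top_k=5,
#     explain_with_label_attention=True,  # labels-x-tokens attention matrix
#     explain_with_captum=True,           # Captum LayerIntegratedGradients attributions (optional)
# )
# predictions["prediction"]                    # (batch, top_k) predicted class indices
# predictions["confidence"]                    # (batch, top_k) rounded softmax scores
# predictions["captum_attributions"]           # (batch, top_k, seq_len); None unless explain_with_captum
# predictions["label_attention_attributions"]  # (batch, n_head, num_classes, seq_len); None unless label attention is used

Design note: with label attention enabled, TextEmbedder returns one embedding per class, shape (batch, num_classes, embedding_dim), so ClassificationHead is built with output dimension 1 and TextClassificationModel.forward squeezes the last dimension to recover (batch, num_classes) logits, as the changes in model.py and torchTextClassifiers.py above show.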