Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions torchTextClassifiers/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

import logging
from typing import Annotated, Optional
from typing import Annotated, Optional, Union

import torch
from torch import nn
Expand Down Expand Up @@ -120,23 +120,32 @@ def forward(
categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
return_label_attention_matrix: bool = False,
**kwargs,
) -> torch.Tensor:
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
"""
Memory-efficient forward pass implementation.

        Args (all produced by the dataset's collate_fn):
input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix

Returns:
torch.Tensor: Model output scores for each class - shape (batch_size, num_classes)
Raw, not softmaxed.
Union[torch.Tensor, dict[str, torch.Tensor]]:
- If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
containing raw logits (not softmaxed)
- If return_label_attention_matrix is True: dict with keys:
- "logits": torch.Tensor of shape (batch_size, num_classes)
- "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len)
"""
encoded_text = input_ids # clearer name
label_attention_matrix = None
if self.text_embedder is None:
x_text = encoded_text.float()
if return_label_attention_matrix:
raise ValueError(
"return_label_attention_matrix=True requires a text_embedder with label attention enabled"
)
else:
text_embed_output = self.text_embedder(
input_ids=encoded_text,
Expand Down