From 2c3f17a832123e41186e0e791182bd2e9f54c434 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:10:34 +0000
Subject: [PATCH 1/3] Initial plan

From 9d5d7dafececeff99488f610520d2722037cd20b Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:12:54 +0000
Subject: [PATCH 2/3] Fix return type annotation for TextClassificationModel.forward

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 torchTextClassifiers/model/model.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py
index d253a66..eb23cca 100644
--- a/torchTextClassifiers/model/model.py
+++ b/torchTextClassifiers/model/model.py
@@ -6,7 +6,7 @@
 """
 
 import logging
-from typing import Annotated, Optional
+from typing import Annotated, Optional, Union
 
 import torch
 from torch import nn
@@ -120,7 +120,7 @@ def forward(
         categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
         return_label_attention_matrix: bool = False,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
         """
         Memory-efficient forward pass implementation.
 
@@ -127,12 +127,17 @@ def forward(
         Args:
             input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
             attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
             categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
+            return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix
 
         Returns:
-            torch.Tensor: Model output scores for each class - shape (batch_size, num_classes)
-                Raw, not softmaxed.
+            Union[torch.Tensor, dict[str, torch.Tensor]]:
+                - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes) 
+                  containing raw logits (not softmaxed)
+                - If return_label_attention_matrix is True: dict with keys:
+                    - "logits": torch.Tensor of shape (batch_size, num_classes)
+                    - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len)
         """
         encoded_text = input_ids  # clearer name
         if self.text_embedder is None:
             x_text = encoded_text.float()

From f2e42796721f8d8d899fd05215d7a3fd46d827ad Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:14:36 +0000
Subject: [PATCH 3/3] Address code review feedback: fix trailing whitespace and NameError

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 torchTextClassifiers/model/model.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py
index eb23cca..3630e62 100644
--- a/torchTextClassifiers/model/model.py
+++ b/torchTextClassifiers/model/model.py
@@ -132,15 +132,20 @@ def forward(
 
         Returns:
             Union[torch.Tensor, dict[str, torch.Tensor]]:
-                - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes) 
+                - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
                   containing raw logits (not softmaxed)
                 - If return_label_attention_matrix is True: dict with keys:
                     - "logits": torch.Tensor of shape (batch_size, num_classes)
                     - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len)
         """
         encoded_text = input_ids  # clearer name
+        label_attention_matrix = None
         if self.text_embedder is None:
             x_text = encoded_text.float()
+            if return_label_attention_matrix:
+                raise ValueError(
+                    "return_label_attention_matrix=True requires a text_embedder with label attention enabled"
+                )
         else:
             text_embed_output = self.text_embedder(
                 input_ids=encoded_text,
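
Not part of the patches above: a minimal caller-side sketch of the return contract that PATCH 2/3 documents and PATCH 3/3 guards. It assumes `model` is an already-constructed `TextClassificationModel` (an `nn.Module`) whose `text_embedder` supports label attention; the tensor sizes, value ranges, and variable names below are illustrative only.

```python
import torch

# Illustrative shapes only; the real vocabulary size, sequence length, and
# number of categorical features depend on how the model was configured.
batch_size, seq_len, num_cats = 4, 16, 2
input_ids = torch.randint(0, 1000, (batch_size, seq_len))            # tokenized + padded text
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)   # 1 = real token, 0 = padding
categorical_vars = torch.randint(0, 5, (batch_size, num_cats))       # additional categorical features

# Default call: a plain logits tensor of shape (batch_size, num_classes),
# raw scores (not softmaxed).
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    categorical_vars=categorical_vars,
)

# Opt-in call: a dict with "logits" and "label_attention_matrix"
# (shape (batch_size, num_classes, seq_len)). When the model has no
# text_embedder, PATCH 3/3 makes this fail early with a ValueError.
output = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    categorical_vars=categorical_vars,
    return_label_attention_matrix=True,
)
logits, label_attention = output["logits"], output["label_attention_matrix"]
```

Callers that relied on the old bare-tensor behaviour keep working unchanged, since the dict is only returned when `return_label_attention_matrix=True` is passed explicitly.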