From fe35fc2342f859d3eec6c8f92cf0f38a050613ac Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:13:27 +0000
Subject: [PATCH 1/3] Initial plan

From 63cb6466daadaccaf820d1f0c85339bfd9b22ec0 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:15:35 +0000
Subject: [PATCH 2/3] Update _get_sentence_embedding return type annotation and docstring

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 torchTextClassifiers/model/components/text_embedder.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 627bbf0..9222030 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -1,6 +1,6 @@
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -198,15 +198,18 @@ def _get_sentence_embedding(
         token_embeddings: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> torch.Tensor:
+    ) -> Dict[str, Optional[torch.Tensor]]:
         """
         Compute sentence embedding from embedded tokens - "remove" second dimension.
 
         Args (output from dataset collate_fn):
             token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
             attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
 
         Returns:
-            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
+            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
+                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
         """
         # average over non-pad token embeddings

From e9635ea7baef08435bc749329411ff95b296402d Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:16:35 +0000
Subject: [PATCH 3/3] Fix early returns to match dictionary return type

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 .../model/components/text_embedder.py         | 20 ++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 9222030..2deb2a8 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -129,7 +129,7 @@ def forward(
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> dict[str, Optional[torch.Tensor]]:
+    ) -> Dict[str, Optional[torch.Tensor]]:
         """Converts input token IDs to their corresponding embeddings.
 
         Args:
@@ -220,14 +220,20 @@ def _get_sentence_embedding(
         if self.attention_config is not None:
             if self.attention_config.aggregation_method is not None:  # default is "mean"
                 if self.attention_config.aggregation_method == "first":
-                    return token_embeddings[:, 0, :]
+                    return {
+                        "sentence_embedding": token_embeddings[:, 0, :],
+                        "label_attention_matrix": None,
+                    }
                 elif self.attention_config.aggregation_method == "last":
                     lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
-                    return token_embeddings[
-                        torch.arange(token_embeddings.size(0)),
-                        lengths - 1,
-                        :,
-                    ]
+                    return {
+                        "sentence_embedding": token_embeddings[
+                            torch.arange(token_embeddings.size(0)),
+                            lengths - 1,
+                            :,
+                        ],
+                        "label_attention_matrix": None,
+                    }
                 else:
                     if self.attention_config.aggregation_method != "mean":
                         raise ValueError(
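
Note (not part of the patches above): since these patches change the return contract of the embedder from a bare tensor to a dictionary, downstream callers need to unpack the result. Below is a minimal sketch of that consumption pattern, assuming an already-constructed embedder module; the helper name consume_embedder_output, its parameters, and the print call are hypothetical and only illustrate the keys and shapes documented in the patched docstring.

    from typing import Dict, Optional

    import torch


    def consume_embedder_output(
        embedder: torch.nn.Module,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Illustrative only: unpack the dict returned by the embedder's forward."""
        out: Dict[str, Optional[torch.Tensor]] = embedder(
            input_ids, attention_mask, return_label_attention_matrix=True
        )
        # 'sentence_embedding' has shape (batch_size, embedding_dim), or
        # (batch_size, n_labels, embedding_dim) when label attention is enabled.
        sentence_embedding = out["sentence_embedding"]
        # 'label_attention_matrix' is None unless label attention is enabled and
        # return_label_attention_matrix=True is passed.
        label_attention_matrix = out["label_attention_matrix"]
        if label_attention_matrix is not None:
            print(label_attention_matrix.shape)
        return sentence_embedding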