From 3ef803ea8a037006a7e7f532282fc3c25ac07ccb Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:13:20 +0000
Subject: [PATCH 1/3] Initial plan

From ed56ec55407c0efe859c21b23513a2cd2b2ebcd5 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:19:28 +0000
Subject: [PATCH 2/3] Fix: return consistent dict structure from
 _get_sentence_embedding for all aggregation methods

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 .../model/components/text_embedder.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 627bbf0..94b445c 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -217,14 +217,20 @@ def _get_sentence_embedding(
         if self.attention_config is not None:
             if self.attention_config.aggregation_method is not None:  # default is "mean"
                 if self.attention_config.aggregation_method == "first":
-                    return token_embeddings[:, 0, :]
+                    return {
+                        "sentence_embedding": token_embeddings[:, 0, :],
+                        "label_attention_matrix": None,
+                    }
                 elif self.attention_config.aggregation_method == "last":
-                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
-                    return token_embeddings[
-                        torch.arange(token_embeddings.size(0)),
-                        lengths - 1,
-                        :,
-                    ]
+                    lengths = attention_mask.sum(dim=1).clamp(min=1).long()  # last non-pad token index + 1
+                    return {
+                        "sentence_embedding": token_embeddings[
+                            torch.arange(token_embeddings.size(0)),
+                            lengths - 1,
+                            :,
+                        ],
+                        "label_attention_matrix": None,
+                    }
                 else:
                     if self.attention_config.aggregation_method != "mean":
                         raise ValueError(

From 51c6d6beca14cf4ef70d4314eeb90fac42785e5f Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:20:20 +0000
Subject: [PATCH 3/3] Update return type annotation for
 _get_sentence_embedding to reflect dict return

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 torchTextClassifiers/model/components/text_embedder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 94b445c..91788fd 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -198,7 +198,7 @@ def _get_sentence_embedding(
         token_embeddings: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> torch.Tensor:
+    ) -> dict[str, Optional[torch.Tensor]]:
         """
         Compute sentence embedding from embedded tokens - "remove" second dimension.
 
@@ -206,7 +206,7 @@ def _get_sentence_embedding(
             token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
             attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
         Returns:
-            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
+            dict: Dictionary with keys 'sentence_embedding' (torch.Tensor) and 'label_attention_matrix' (Optional[torch.Tensor])
         """
 
         # average over non-pad token embeddings
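
After this series, every aggregation branch of _get_sentence_embedding returns
the same dict shape, which is what the PATCH 3/3 annotation
dict[str, Optional[torch.Tensor]] encodes. Below is a minimal standalone sketch
of that contract, not the library's actual code: the function name "aggregate"
is hypothetical, and the "mean" branch is an assumption inferred from the
docstring's "average over non-pad token embeddings" comment (the patches only
show the "first" and "last" branches).

    import torch
    from typing import Optional

    def aggregate(
        token_embeddings: torch.Tensor,  # (batch, seq_len, dim)
        attention_mask: torch.Tensor,    # (batch, seq_len), 1 = real token
        method: str = "mean",
    ) -> dict[str, Optional[torch.Tensor]]:
        # Hypothetical stand-in for _get_sentence_embedding's aggregation
        # logic; the point is that every branch returns the same dict keys.
        if method == "first":
            sentence = token_embeddings[:, 0, :]
        elif method == "last":
            # index of the last non-pad token per row; clamp guards all-pad rows
            lengths = attention_mask.sum(dim=1).clamp(min=1).long()
            sentence = token_embeddings[
                torch.arange(token_embeddings.size(0)), lengths - 1, :
            ]
        elif method == "mean":
            # assumed masked mean over non-pad tokens
            mask = attention_mask.unsqueeze(-1).float()
            sentence = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            raise ValueError(f"Unknown aggregation_method: {method!r}")
        return {"sentence_embedding": sentence, "label_attention_matrix": None}

    # Callers can unpack uniformly, whatever method is configured:
    x = torch.randn(2, 5, 8)
    m = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
    for method in ("first", "last", "mean"):
        out = aggregate(x, m, method)
        assert out["sentence_embedding"].shape == (2, 8)
        assert out["label_attention_matrix"] is None

The .long() added to the "last" branch in PATCH 2/3 matters for the advanced
indexing: index tensors must be integer-typed, and summing a mask that is not
already int64 could otherwise produce an invalid index dtype.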