diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 627bbf0..91788fd 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -198,7 +198,7 @@ def _get_sentence_embedding(
         token_embeddings: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> torch.Tensor:
+    ) -> dict[str, Optional[torch.Tensor]]:
         """
         Compute sentence embedding from embedded tokens - "remove" second dimension.
 
@@ -206,7 +206,7 @@ def _get_sentence_embedding(
             token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
             attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
         Returns:
-            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
+            dict: Dictionary with keys 'sentence_embedding' (torch.Tensor) and 'label_attention_matrix' (Optional[torch.Tensor])
         """
 
         # average over non-pad token embeddings
@@ -217,14 +217,20 @@ def _get_sentence_embedding(
         if self.attention_config is not None:
             if self.attention_config.aggregation_method is not None:  # default is "mean"
                 if self.attention_config.aggregation_method == "first":
-                    return token_embeddings[:, 0, :]
+                    return {
+                        "sentence_embedding": token_embeddings[:, 0, :],
+                        "label_attention_matrix": None,
+                    }
                 elif self.attention_config.aggregation_method == "last":
-                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
-                    return token_embeddings[
-                        torch.arange(token_embeddings.size(0)),
-                        lengths - 1,
-                        :,
-                    ]
+                    lengths = attention_mask.sum(dim=1).clamp(min=1).long()  # last non-pad token index + 1
+                    return {
+                        "sentence_embedding": token_embeddings[
+                            torch.arange(token_embeddings.size(0)),
+                            lengths - 1,
+                            :,
+                        ],
+                        "label_attention_matrix": None,
+                    }
                 else:
                     if self.attention_config.aggregation_method != "mean":
                         raise ValueError(
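
For context, a minimal, self-contained sketch of the aggregation logic and the new dict-shaped return value. The helper name `aggregate_tokens` is hypothetical (not the project's class or API), and the masked-mean branch is an assumption based on the `# average over non-pad token embeddings` comment, since that code lies outside the hunks above.

```python
# Standalone sketch, not the project's implementation.
from typing import Optional

import torch


def aggregate_tokens(
    token_embeddings: torch.Tensor,  # (batch_size, seq_len, embedding_dim)
    attention_mask: torch.Tensor,    # (batch_size, seq_len), 1 for real tokens, 0 for padding
    aggregation_method: str = "mean",
) -> dict[str, Optional[torch.Tensor]]:
    if aggregation_method == "first":
        # embedding of the first token (e.g. a [CLS]-style token)
        sentence_embedding = token_embeddings[:, 0, :]
    elif aggregation_method == "last":
        # index of the last non-pad token; clamp avoids index -1 on all-pad rows
        lengths = attention_mask.sum(dim=1).clamp(min=1).long()
        sentence_embedding = token_embeddings[
            torch.arange(token_embeddings.size(0)), lengths - 1, :
        ]
    else:  # "mean" (assumed): average over non-pad token embeddings
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
        sentence_embedding = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return {"sentence_embedding": sentence_embedding, "label_attention_matrix": None}


# toy usage: batch of 2, seq_len 3, embedding_dim 4; second row has one pad token
emb = torch.randn(2, 3, 4)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
out = aggregate_tokens(emb, mask, aggregation_method="last")
print(out["sentence_embedding"].shape)  # torch.Size([2, 4])
```

Returning a dict with a `label_attention_matrix` slot (here always `None`) keeps the method's return shape uniform across aggregation paths, so callers that expect an attention matrix from other paths don't need branch-specific handling.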