diff --git a/.gitignore b/.gitignore index e65366e..586ae23 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,4 @@ example_files/ _site/ .quarto/ **/*.quarto_ipynb +my_ttc/ diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 7bd030c..b5ce92e 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -324,7 +324,9 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F compute_attention_matrix = bool(compute_attention_matrix) # 1. Create label indices [0, 1, ..., C-1] for the whole batch - label_indices = torch.arange(self.num_classes).expand(B, -1) + label_indices = torch.arange( + self.num_classes, dtype=torch.long, device=token_embeddings.device + ).expand(B, -1) all_label_embeddings = self.label_embeds( label_indices