diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py index d253a66..5e4cc66 100644 --- a/torchTextClassifiers/model/model.py +++ b/torchTextClassifiers/model/model.py @@ -134,6 +134,7 @@ def forward( Raw, not softmaxed. """ encoded_text = input_ids # clearer name + label_attention_matrix = None if self.text_embedder is None: x_text = encoded_text.float() else: