Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
601fa46
feat!(label attention): enable label attention
meilame-tayebjee Jan 26, 2026
30ef8af
fix(load): restore LabelAttentionConfig object from loaded dict
meilame-tayebjee Jan 26, 2026
0a1880b
test(label attention): add a test_pipeline with label attention activ…
meilame-tayebjee Jan 26, 2026
a2fe33e
feat(explainability): add new expl. pipe. with label attention
meilame-tayebjee Jan 26, 2026
3d034f4
fix typo
meilame-tayebjee Jan 26, 2026
ec6742c
fix docstring in TextEmbedder forward
meilame-tayebjee Jan 26, 2026
f991b6b
fix: convert to LabelAttentionConfig object when dict
meilame-tayebjee Jan 26, 2026
525b482
chore: replace type checking with isinstance
meilame-tayebjee Jan 26, 2026
2374df8
chore: better dict handling in TextEmbedder forward output
meilame-tayebjee Jan 26, 2026
d266572
fix: ensure label_indices uses correct device and dtype
Copilot Jan 26, 2026
4aada37
Add validation for LabelAttentionClassifier head configuration
Copilot Jan 26, 2026
d516e6b
Improve validation to follow TextEmbedder pattern and clarify error m…
Copilot Jan 26, 2026
87d672f
Apply attention mask in LabelAttentionClassifier cross-attention
Copilot Jan 27, 2026
7a988c3
Fix trailing whitespace in attention matrix computation
Copilot Jan 27, 2026
70b79c9
doc: fix docstring on shape of label attention matrix
meilame-tayebjee Jan 27, 2026
0558f97
Add assertions for label attention attributions in tests
Copilot Jan 27, 2026
44d9345
Update _get_sentence_embedding return type annotation and docstring
Copilot Jan 27, 2026
86a0715
Fix early returns to match dictionary return type
Copilot Jan 27, 2026
a721397
Initialize label_attention_matrix to None before text_embedder branch
Copilot Jan 27, 2026
798ff8c
Fix return type annotation for TextClassificationModel.forward
Copilot Jan 27, 2026
1bbff15
Address code review feedback: fix trailing whitespace and NameError
Copilot Jan 27, 2026
308fddc
chore: disable gqa for label attention for simpl.
meilame-tayebjee Jan 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,4 @@ example_files/
_site/
.quarto/
**/*.quarto_ipynb
my_ttc/
83 changes: 79 additions & 4 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AttentionConfig,
CategoricalVariableNet,
ClassificationHead,
LabelAttentionConfig,
TextEmbedder,
TextEmbedderConfig,
)
Expand Down Expand Up @@ -51,7 +52,14 @@ def model_params():
}


def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
def run_full_pipeline(
tokenizer,
sample_text_data,
categorical_data,
labels,
model_params,
label_attention_enabled: bool = False,
):
"""Helper function to run the complete pipeline for a given tokenizer."""
# Create dataset
dataset = TextClassificationDataset(
Expand Down Expand Up @@ -83,6 +91,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
embedding_dim=model_params["embedding_dim"],
padding_idx=padding_idx,
attention_config=attention_config,
label_attention_config=(
LabelAttentionConfig(
n_head=attention_config.n_head,
num_classes=model_params["num_classes"],
)
if label_attention_enabled
else None
),
)

text_embedder = TextEmbedder(text_embedder_config=text_embedder_config)
Expand All @@ -98,7 +114,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim
classification_head = ClassificationHead(
input_dim=expected_input_dim,
num_classes=model_params["num_classes"],
num_classes=model_params["num_classes"] if not label_attention_enabled else 1,
)

# Create model
Expand Down Expand Up @@ -136,6 +152,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
categorical_embedding_dims=model_params["categorical_embedding_dims"],
num_classes=model_params["num_classes"],
attention_config=attention_config,
label_attention_config=(
LabelAttentionConfig(
n_head=attention_config.n_head,
num_classes=model_params["num_classes"],
)
if label_attention_enabled
else None
),
)

# Create training config
Expand Down Expand Up @@ -163,13 +187,41 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod

# Predict with explanations
top_k = 5
predictions = ttc.predict(X, top_k=top_k, explain=True)

predictions = ttc.predict(
X,
top_k=top_k,
explain_with_label_attention=label_attention_enabled,
explain_with_captum=True,
)

# Test label attention assertions
if label_attention_enabled:
assert (
predictions["label_attention_attributions"] is not None
), "Label attention attributions should not be None when label_attention_enabled is True"
label_attention_attributions = predictions["label_attention_attributions"]
expected_shape = (
len(sample_text_data), # batch_size
model_params["n_head"], # n_head
model_params["num_classes"], # num_classes
tokenizer.output_dim, # seq_len
)
assert label_attention_attributions.shape == expected_shape, (
f"Label attention attributions shape mismatch. "
f"Expected {expected_shape}, got {label_attention_attributions.shape}"
)
else:
# When label attention is not enabled, the attributions should be None
assert (
predictions.get("label_attention_attributions") is None
), "Label attention attributions should be None when label_attention_enabled is False"

# Test explainability functions
text_idx = 0
text = sample_text_data[text_idx]
offsets = predictions["offset_mapping"][text_idx]
attributions = predictions["attributions"][text_idx]
attributions = predictions["captum_attributions"][text_idx]
word_ids = predictions["word_ids"][text_idx]

words, word_attributions = map_attributions_to_word(attributions, text, word_ids, offsets)
Expand Down Expand Up @@ -239,3 +291,26 @@ def test_ngram_tokenizer(sample_data, model_params):

# Run full pipeline
run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)


def test_label_attention_enabled(sample_data, model_params):
    """Exercise the end-to-end pipeline with label attention switched on.

    A WordPieceTokenizer is trained on the fixture texts, sanity-checked,
    and then passed to ``run_full_pipeline`` with
    ``label_attention_enabled=True``.
    """
    texts, categorical_features, targets = sample_data

    # Train a small WordPiece tokenizer on the fixture texts.
    tokenizer = WordPieceTokenizer(100, output_dim=50)
    tokenizer.train(texts)

    # Sanity check: the first dimension of the token ids must match the
    # number of input texts.
    tokenized = tokenizer.tokenize(texts)
    n_rows = tokenized.input_ids.shape[0]
    assert n_rows == len(texts)

    # Delegate the remaining checks to the shared pipeline helper.
    run_full_pipeline(
        tokenizer=tokenizer,
        sample_text_data=texts,
        categorical_data=categorical_features,
        labels=targets,
        model_params=model_params,
        label_attention_enabled=True,
    )
1 change: 1 addition & 0 deletions torchTextClassifiers/model/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
CategoricalVariableNet as CategoricalVariableNet,
)
from .classification_head import ClassificationHead as ClassificationHead
from .text_embedder import LabelAttentionConfig as LabelAttentionConfig
from .text_embedder import TextEmbedder as TextEmbedder
from .text_embedder import TextEmbedderConfig as TextEmbedderConfig
Loading