Skip to content

Conversation

Copy link
Contributor

Copilot AI commented Jan 27, 2026

The forward method returns either torch.Tensor or dict[str, torch.Tensor] depending on return_label_attention_matrix, but was annotated as only returning torch.Tensor.

Changes:

  • Updated return type to Union[torch.Tensor, dict[str, torch.Tensor]]
  • Enhanced docstring to document both return paths with dict key structure
  • Added validation: raises ValueError when return_label_attention_matrix=True without a configured text_embedder

Example:

# Now correctly typed for both cases
logits = model(input_ids, attention_mask, categorical_vars)  # torch.Tensor
result = model(input_ids, attention_mask, categorical_vars, return_label_attention_matrix=True)
# result: dict with keys "logits" and "label_attention_matrix"

💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.

Copilot AI and others added 2 commits January 27, 2026 10:12
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
Copilot AI changed the title [WIP] Update cross attention labels based on review feedback Fix return type annotation for TextClassificationModel.forward Jan 27, 2026
@meilame-tayebjee meilame-tayebjee marked this pull request as ready for review January 27, 2026 10:42
@meilame-tayebjee meilame-tayebjee merged commit 1bbff15 into 24-add-cross-attention-labels-text Jan 27, 2026
@meilame-tayebjee meilame-tayebjee deleted the copilot/sub-pr-60 branch January 27, 2026 10:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants