From f94066deaa39d08b794f9e64072c9de7ad672d63 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Mon, 26 Jan 2026 17:30:11 +0000
Subject: [PATCH 1/3] Initial plan

From 21285a7c7429f6d910dd29dcecbb687afd20a322 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Mon, 26 Jan 2026 17:34:49 +0000
Subject: [PATCH 2/3] Add validation for LabelAttentionClassifier head
 configuration

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 .../model/components/text_embedder.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 6875112..c116610 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -287,6 +287,20 @@ def __init__(self, config: TextEmbedderConfig):
         self.enable_gqa = (
             self.n_head != self.n_kv_head
         )  # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
+
+        # Validate head configuration
+        if self.embedding_dim % self.n_head != 0:
+            raise ValueError(
+                f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). "
+                f"Got head_dim = {self.embedding_dim / self.n_head}"
+            )
+
+        if self.n_head % self.n_kv_head != 0:
+            raise ValueError(
+                f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. "
+                f"Got n_head / n_kv_head = {self.n_head / self.n_kv_head}"
+            )
+
         self.head_dim = self.embedding_dim // self.n_head
 
         self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)

From 50683b45ca8a1e3d4dab1ad407e5e8f6b7235597 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Mon, 26 Jan 2026 17:36:05 +0000
Subject: [PATCH 3/3] Improve validation to follow TextEmbedder pattern and
 clarify error messages

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 torchTextClassifiers/model/components/text_embedder.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index c116610..fb34cdf 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -289,19 +289,19 @@ def __init__(self, config: TextEmbedderConfig):
         )  # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
 
         # Validate head configuration
-        if self.embedding_dim % self.n_head != 0:
+        self.head_dim = self.embedding_dim // self.n_head
+
+        if self.head_dim * self.n_head != self.embedding_dim:
             raise ValueError(
                 f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). "
-                f"Got head_dim = {self.embedding_dim / self.n_head}"
+                f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}"
             )
 
         if self.n_head % self.n_kv_head != 0:
             raise ValueError(
                 f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. "
-                f"Got n_head / n_kv_head = {self.n_head / self.n_kv_head}"
+                f"Got remainder {self.n_head % self.n_kv_head}"
             )
-
-        self.head_dim = self.embedding_dim // self.n_head
 
         self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)
 