diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 6875112..fb34cdf 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -287,7 +287,21 @@ def __init__(self, config: TextEmbedderConfig): self.enable_gqa = ( self.n_head != self.n_kv_head ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired + + # Validate head configuration self.head_dim = self.embedding_dim // self.n_head + + if self.head_dim * self.n_head != self.embedding_dim: + raise ValueError( + f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). " + f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}" + ) + + if self.n_head % self.n_kv_head != 0: + raise ValueError( + f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. " + f"Got remainder {self.n_head % self.n_kv_head}" + ) self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)