[Kernel] Raise verbose error and consolidate num_heads/num_kv_heads divisibility check (#19339)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>

Author: 22quinn
Date: 2025-06-14 22:43:48 -07:00
Committed by: GitHub
Parent: ee1531bc38
Commit: 0b73736a0d
17 changed files with 24 additions and 19 deletions
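
The hunks below delete the same inline assert from five attention backend implementations; the divisibility requirement is instead enforced once, in shared code, with an error message that names the offending values. The consolidated helper itself is not among the hunks shown in this excerpt, so the following is only a minimal sketch of what such a check looks like (function name and message wording are assumptions):

# Hypothetical sketch only: the consolidated check added by this commit
# lives in shared attention code not shown in this excerpt; the helper
# name and message wording here are assumptions, not the actual code.
def check_heads_divisible(num_heads: int, num_kv_heads: int) -> None:
    # Replace the bare `assert` with a verbose error that reports
    # both values, so misconfigurations are easy to diagnose.
    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads ({num_heads}) is not divisible by "
            f"num_kv_heads ({num_kv_heads})")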

@@ -545,7 +545,6 @@ class FlashAttentionImpl(AttentionImpl):
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()

@@ -532,7 +532,6 @@ class FlashInferImpl(AttentionImpl):
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if attn_type != AttentionType.DECODER:

@@ -376,7 +376,6 @@ class FlexAttentionImpl(AttentionImpl):
             raise NotImplementedError(
                 "FlexAttention does not support logits soft cap yet.")
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if kv_sharing_target_layer_name is not None:

@@ -131,7 +131,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")

@@ -114,7 +114,6 @@ class TritonAttentionImpl(AttentionImpl):
         self.use_irope = use_irope
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
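
With the inline asserts gone, an invalid head configuration is reported once, with both values in the message, instead of surfacing as a bare AssertionError from whichever backend happens to be selected.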