diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py
index 2d381a99b..7269d1918 100644
--- a/tests/kernels/attention/test_attention.py
+++ b/tests/kernels/attention/test_attention.py
@@ -10,6 +10,7 @@ import torch
 from tests.kernels.allclose_default import get_default_atol, get_default_rtol
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
+from vllm.attention.layer import Attention, MultiHeadAttention
 from vllm.platforms import current_platform
 from vllm.utils import get_max_shared_memory_bytes
 
@@ -506,3 +507,18 @@ def test_multi_query_kv_attention_with_alibi(
         device,
         use_alibi=True,
     )
+
+
+@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
+def test_num_heads_not_divisible_by_num_kv_heads(attention_cls: type) -> None:
+    head_size = 64
+    scale = float(1.0 / (head_size**0.5))
+    num_heads = 16
+    num_kv_heads = 5
+    with pytest.raises(AssertionError):
+        _ = attention_cls(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+        )
diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index 71415f493..fe9738d80 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -65,7 +65,6 @@ class BlocksparseParams:
         assert self.block_size > 0
         assert self.local_blocks >= 0
         assert self.vert_stride >= 1
-        assert self.num_heads % self.num_kv_heads == 0
 
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
@@ -329,9 +328,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         self.head_size = head_size
         self.scale = float(scale)
         self.alibi_slopes = alibi_slopes
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         self.local_blocks = self.blocksparse_params.local_blocks
diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py
index 55f57f37b..f62a43b44 100644
--- a/vllm/attention/backends/dual_chunk_flash_attn.py
+++ b/vllm/attention/backends/dual_chunk_flash_attn.py
@@ -307,7 +307,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
                                if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if sliding_window is not None:
             # NOTE(woosuk): flash-attn's sliding window does not work with
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 47c25d136..bf8e37380 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
             logits_soft_cap = 0
         self.logits_soft_cap = logits_soft_cap
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ff7310478..a987dc538 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -957,7 +957,6 @@ class FlashInferImpl(AttentionImpl):
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         if attn_type != AttentionType.DECODER:
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 115e5ba1a..bf778a1e5 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
             alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                                dtype=torch.bfloat16)
             self.alibi_slopes = alibi_slopes_tensor
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         if self.prefill_impl == 'fsdpa':
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index 21f61cf70..410ada3b0 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         self.sliding_window = sliding_window
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.need_mask = (self.sliding_window is not None)
         if logits_soft_cap is None:
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index c5c080297..c90066695 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 8f1da84cd..1e2c21f4e 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -528,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                                if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         self.paged_attn_module = _get_paged_attn_module()
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 9d7e735dd..3e1336a5a 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         self.sliding_window = sliding_window
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.need_mask = (self.alibi_slopes is not None
                           or self.sliding_window is not None)
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index dfdc8ee64..b583240c7 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         self.sliding_window = sliding_window
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         supported_head_sizes = PagedAttention.get_supported_head_sizes()
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 3bbe276e0..6d9c6f51b 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -80,6 +80,9 @@ class Attention(nn.Module):
             calculate_kv_scales = False
         if num_kv_heads is None:
             num_kv_heads = num_heads
+        assert num_heads % num_kv_heads == 0, \
+            f"num_heads ({num_heads}) is not " \
+            f"divisible by num_kv_heads ({num_kv_heads})"
 
         # The default k/v_scale is set to 1.0. This is ignored
         # when kv-cache is not fp8, and should be used with
@@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
+        assert self.num_heads % self.num_kv_heads == 0, \
+            f"num_heads ({self.num_heads}) is not " \
+            f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         dtype = torch.get_default_dtype()
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 630ac1322..8b7745ced 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -545,7 +545,6 @@ class FlashAttentionImpl(AttentionImpl):
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 12547b99e..b2f54f37a 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -532,7 +532,6 @@ class FlashInferImpl(AttentionImpl):
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         if attn_type != AttentionType.DECODER:
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index c8cb1481c..a572b8947 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -376,7 +376,6 @@ class FlexAttentionImpl(AttentionImpl):
             raise NotImplementedError(
                 "FlexAttention does not support logits soft cap yet.")
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         if kv_sharing_target_layer_name is not None:
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 62c72f43f..7a6d8c0f8 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -131,7 +131,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 6b67d9932..9782ec087 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -114,7 +114,6 @@ class TritonAttentionImpl(AttentionImpl):
 
         self.use_irope = use_irope
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
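
As a quick illustration (not part of the diff above), the divisibility check now lives in `vllm/attention/layer.py` and fails fast at construction time with a descriptive message instead of a bare backend-level assert. The sketch below mirrors the new test and only passes the constructor arguments that test exercises; any other `MultiHeadAttention` arguments are assumed to keep their defaults.

```python
import pytest

from vllm.attention.layer import MultiHeadAttention

# 16 query heads cannot be evenly grouped over 5 KV heads, so the assertion
# added in MultiHeadAttention.__init__ fires during construction.
head_size = 64
with pytest.raises(AssertionError, match="divisible by num_kv_heads"):
    MultiHeadAttention(
        num_heads=16,
        head_size=head_size,
        scale=head_size**-0.5,
        num_kv_heads=5,
    )
```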