diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 35c10ec0b..f26d89f40 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp): def enabled(cls) -> bool: return True - def maybe_reshape_qkv_to_4d( + def view_qkv_to_4d( self, query: torch.Tensor, key: torch.Tensor, @@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp): key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) - if (num_repeat := self.num_queries_per_kv) > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_repeat, dim=2) - value = torch.repeat_interleave(value, num_repeat, dim=2) - return query, key, value def _forward_sdpa( @@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp): kv_len = key.size(1) is_reshaped = query.dim() != 4 - query, key, value = self.maybe_reshape_qkv_to_4d( - query, key, value, bsz, q_len, kv_len - ) + query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len) output = vit_torch_sdpa_wrapper( q=query, @@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp): v=value, scale=self.scale, cu_seqlens=cu_seqlens, + enable_gqa=self.num_heads > self.num_kv_heads, ) if is_reshaped: output = output.reshape(bsz, q_len, -1) @@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp): kv_len = key.size(1) is_reshaped = query.dim() != 4 - query, key, value = self.maybe_reshape_qkv_to_4d( - query, key, value, bsz, q_len, kv_len - ) + query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len) output = vit_flash_attn_wrapper( q=query, diff --git a/vllm/model_executor/models/molmo2.py b/vllm/model_executor/models/molmo2.py index 9d996a93b..30f639c8b 100644 --- a/vllm/model_executor/models/molmo2.py +++ b/vllm/model_executor/models/molmo2.py @@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module): key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim) value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim) - if self.num_heads != self.num_kv_heads: - key = torch.repeat_interleave( - key, - self.num_heads // self.num_kv_heads, - dim=2, - ) - value = torch.repeat_interleave( - value, - self.num_heads // self.num_kv_heads, - dim=2, - ) - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention( @@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module): value, attn_mask=attn_mask, is_causal=False, + enable_gqa=self.num_heads > self.num_kv_heads, ).transpose(1, 2) return out.reshape(bsz, q_len, -1) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 3eb9b4782..a2f2c6aeb 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -398,10 +398,6 @@ class CPUAttentionBackendImpl(AttentionImpl): key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) - value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) - causal_attn = attn_type == AttentionType.DECODER sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore @@ -418,6 +414,7 @@ class CPUAttentionBackendImpl(AttentionImpl): dropout_p=0.0, is_causal=causal_attn and mask is None, scale=self.scale, + enable_gqa=self.num_heads > self.num_kv_heads, ) .squeeze(0) .movedim(query.dim() - 2, 0) diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index f077a61c9..32fcb3511 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -115,13 +115,16 @@ def apply_sdpa( k: torch.Tensor, v: torch.Tensor, scale: float | None = None, + enable_gqa: bool = False, ) -> torch.Tensor: """ Input shape: (batch_size x seq_len x num_heads x head_size) """ q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) - output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=scale) + output = F.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, scale=scale, enable_gqa=enable_gqa + ) output = einops.rearrange(output, "b h s d -> b s h d ") return output @@ -134,6 +137,7 @@ def torch_sdpa_wrapper( v: torch.Tensor, scale: float | None = None, cu_seqlens: torch.Tensor | None = None, + enable_gqa: bool = False, ) -> torch.Tensor: # Never remove the contiguous logic for ROCm # Without it, hallucinations occur with the backend @@ -143,7 +147,7 @@ def torch_sdpa_wrapper( v = v.contiguous() if cu_seqlens is None: - return apply_sdpa(q, k, v, scale=scale) + return apply_sdpa(q, k, v, scale=scale, enable_gqa=enable_gqa) outputs = [] @@ -152,7 +156,7 @@ def torch_sdpa_wrapper( k_chunks = torch.split(k, lens, dim=1) v_chunks = torch.split(v, lens, dim=1) for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): - output_i = apply_sdpa(q_i, k_i, v_i, scale=scale) + output_i = apply_sdpa(q_i, k_i, v_i, scale=scale, enable_gqa=enable_gqa) outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) return context_layer @@ -164,6 +168,7 @@ def torch_sdpa_wrapper_fake( v: torch.Tensor, scale: float | None, cu_seqlens: torch.Tensor | None, + enable_gqa: bool = False, ) -> torch.Tensor: return torch.empty_like(q) @@ -181,5 +186,8 @@ def vit_torch_sdpa_wrapper( v: torch.Tensor, scale: float | None = None, cu_seqlens: torch.Tensor | None = None, + enable_gqa: bool = False, ) -> torch.Tensor: - return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens) + return torch.ops.vllm.torch_sdpa_wrapper( + q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa + )