From 4c2effa2be1e9ac4fd4edb7e72d48aa6a471da89 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 06:44:59 +0000 Subject: [PATCH] Fix attention patch: source from v0.21.0 stable, not local clone The local vllm clone has different imports (breakable_cudagraph) that don't exist in the Docker image. Now sourced from v0.21.0 tag. --- vllm/patches/deepseek_v4_attention.py | 86 +++++++++++++++------------ 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index e4846da9..a2bf25f1 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -31,7 +31,11 @@ from vllm.v1.attention.ops.deepseek_v4_ops import ( fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, ) -from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum +from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + rocm_forward_decode_fallback, + rocm_inv_rope_einsum, + rocm_sparse_attn_prefill, +) if TYPE_CHECKING: from vllm.v1.attention.backends.mla.sparse_swa import ( @@ -318,7 +322,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: # Pre-allocate attention output with FlashMLA-padded head count. - # The op writes into `o_padded`; we slice to n_local_heads after. num_tokens = hidden_states.shape[0] o_padded = torch.empty( (num_tokens, self.padded_heads, self.head_dim), @@ -341,31 +344,22 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): # Step 1: Inverse RoPE (BF16, pure PyTorch) o_inv = _apply_inv_rope_bf16( - o, - positions, - self.rotary_emb.cos_sin_cache, - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, + o, positions, self.rotary_emb.cos_sin_cache, + nope_dim=self.nope_head_dim, rope_dim=self.rope_head_dim, ) # Step 2: wo_a grouped linear (BF16 BMM) - # o_inv: (T, n_local_heads, head_dim) - # wo_a.weight: (n_local_groups * o_lora_rank, heads_per_group * head_dim) BF16 hidden_dim = self.wo_a.weight.shape[1] o_grouped = o_inv.view(num_tokens, self.n_local_groups, hidden_dim) wo_a_w = self.wo_a.weight.view( self.n_local_groups, self.o_lora_rank, hidden_dim ) z = torch.bmm( - o_grouped.permute(1, 0, 2), - wo_a_w.transpose(1, 2), + o_grouped.permute(1, 0, 2), wo_a_w.transpose(1, 2), ).permute(1, 0, 2) # Step 3: wo_b (NVFP4 via CuTeDSL) return self.wo_b(z.flatten(1)) - ) - - return self.wo_b(z.flatten(1)) def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: aux_streams = self.aux_stream_list @@ -739,12 +733,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): self.kv_cache = torch.tensor([]) def get_attn_backend(self) -> type[AttentionBackend]: - if current_platform.is_rocm(): - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( - DeepseekV4ROCMAiterMLASparseBackend, - ) - - return DeepseekV4ROCMAiterMLASparseBackend return DeepseekV4FlashMLASparseBackend def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: @@ -777,14 +765,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" ) - if current_platform.is_rocm(): - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( - DeepseekV4ROCMAiterMLASparseImpl, - ) - - DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output) - return - # Get SWA and indexer metadata from forward context forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -867,6 +847,25 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens + if current_platform.is_rocm(): + rocm_forward_decode_fallback( + q=q, + kv_cache=kv_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + attn_sink=self.attn_sink, + scale=self.scale, + head_dim=self.head_dim, + nope_head_dim=self.nope_head_dim, + rope_head_dim=self.rope_head_dim, + output=output, + ) + return + # We treat queries in the same seq as different queries # and later we only attend by generated indices. # q arrives pre-padded to self.padded_heads by the outer wrapper. @@ -1030,15 +1029,28 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): M, N, ) - flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], - ) + + if current_platform.is_rocm(): + rocm_sparse_attn_prefill( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + topk_length=combined_lens, + scale=self.scale, + head_dim=self.head_dim, + attn_sink=self.attn_sink, + output=output[query_start:query_end], + ) + else: + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):