Fix attention patch: source from v0.21.0 stable, not local clone

The local vllm clone has different imports (breakable_cudagraph) that
don't exist in the Docker image. Now sourced from v0.21.0 tag.
This commit is contained in:
2026-05-19 06:44:59 +00:00
parent 284b6a5d57
commit 4c2effa2be

View File

@@ -31,7 +31,11 @@ from vllm.v1.attention.ops.deepseek_v4_ops import (
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_forward_decode_fallback,
rocm_inv_rope_einsum,
rocm_sparse_attn_prefill,
)
if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
@@ -318,7 +322,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
o_padded = torch.empty(
(num_tokens, self.padded_heads, self.head_dim),
@@ -341,31 +344,22 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
# Step 1: Inverse RoPE (BF16, pure PyTorch)
o_inv = _apply_inv_rope_bf16(
o,
positions,
self.rotary_emb.cos_sin_cache,
nope_dim=self.nope_head_dim,
rope_dim=self.rope_head_dim,
o, positions, self.rotary_emb.cos_sin_cache,
nope_dim=self.nope_head_dim, rope_dim=self.rope_head_dim,
)
# Step 2: wo_a grouped linear (BF16 BMM)
# o_inv: (T, n_local_heads, head_dim)
# wo_a.weight: (n_local_groups * o_lora_rank, heads_per_group * head_dim) BF16
hidden_dim = self.wo_a.weight.shape[1]
o_grouped = o_inv.view(num_tokens, self.n_local_groups, hidden_dim)
wo_a_w = self.wo_a.weight.view(
self.n_local_groups, self.o_lora_rank, hidden_dim
)
z = torch.bmm(
o_grouped.permute(1, 0, 2),
wo_a_w.transpose(1, 2),
o_grouped.permute(1, 0, 2), wo_a_w.transpose(1, 2),
).permute(1, 0, 2)
# Step 3: wo_b (NVFP4 via CuTeDSL)
return self.wo_b(z.flatten(1))
)
return self.wo_b(z.flatten(1))
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
aux_streams = self.aux_stream_list
@@ -739,12 +733,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
self.kv_cache = torch.tensor([])
def get_attn_backend(self) -> type[AttentionBackend]:
if current_platform.is_rocm():
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
DeepseekV4ROCMAiterMLASparseBackend,
)
return DeepseekV4ROCMAiterMLASparseBackend
return DeepseekV4FlashMLASparseBackend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
@@ -777,14 +765,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
)
if current_platform.is_rocm():
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
DeepseekV4ROCMAiterMLASparseImpl,
)
DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output)
return
# Get SWA and indexer metadata from forward context
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -867,6 +847,25 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
swa_indices = swa_metadata.decode_swa_indices
swa_lens = swa_metadata.decode_swa_lens
if current_platform.is_rocm():
rocm_forward_decode_fallback(
q=q,
kv_cache=kv_cache,
swa_k_cache=self.swa_cache_layer.kv_cache,
swa_only=swa_only,
topk_indices=topk_indices,
topk_lens=topk_lens,
swa_indices=swa_indices,
swa_lens=swa_lens,
attn_sink=self.attn_sink,
scale=self.scale,
head_dim=self.head_dim,
nope_head_dim=self.nope_head_dim,
rope_head_dim=self.rope_head_dim,
output=output,
)
return
# We treat queries in the same seq as different queries
# and later we only attend by generated indices.
# q arrives pre-padded to self.padded_heads by the outer wrapper.
@@ -1030,15 +1029,28 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
M,
N,
)
flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
if current_platform.is_rocm():
rocm_sparse_attn_prefill(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
topk_length=combined_lens,
scale=self.scale,
head_dim=self.head_dim,
attn_sink=self.attn_sink,
output=output[query_start:query_end],
)
else:
output_chunk, _, _ = flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):