Fix attention patch: source from v0.21.0 stable, not local clone
The local vllm clone has different imports (breakable_cudagraph) that don't exist in the Docker image. Now sourced from v0.21.0 tag.
This commit is contained in:
@@ -31,7 +31,11 @@ from vllm.v1.attention.ops.deepseek_v4_ops import (
|
||||
fused_inv_rope_fp8_quant,
|
||||
fused_q_kv_rmsnorm,
|
||||
)
|
||||
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
|
||||
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
|
||||
rocm_forward_decode_fallback,
|
||||
rocm_inv_rope_einsum,
|
||||
rocm_sparse_attn_prefill,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.attention.backends.mla.sparse_swa import (
|
||||
@@ -318,7 +322,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
llama_4_scaling: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Pre-allocate attention output with FlashMLA-padded head count.
|
||||
# The op writes into `o_padded`; we slice to n_local_heads after.
|
||||
num_tokens = hidden_states.shape[0]
|
||||
o_padded = torch.empty(
|
||||
(num_tokens, self.padded_heads, self.head_dim),
|
||||
@@ -341,31 +344,22 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
|
||||
# Step 1: Inverse RoPE (BF16, pure PyTorch)
|
||||
o_inv = _apply_inv_rope_bf16(
|
||||
o,
|
||||
positions,
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
nope_dim=self.nope_head_dim,
|
||||
rope_dim=self.rope_head_dim,
|
||||
o, positions, self.rotary_emb.cos_sin_cache,
|
||||
nope_dim=self.nope_head_dim, rope_dim=self.rope_head_dim,
|
||||
)
|
||||
|
||||
# Step 2: wo_a grouped linear (BF16 BMM)
|
||||
# o_inv: (T, n_local_heads, head_dim)
|
||||
# wo_a.weight: (n_local_groups * o_lora_rank, heads_per_group * head_dim) BF16
|
||||
hidden_dim = self.wo_a.weight.shape[1]
|
||||
o_grouped = o_inv.view(num_tokens, self.n_local_groups, hidden_dim)
|
||||
wo_a_w = self.wo_a.weight.view(
|
||||
self.n_local_groups, self.o_lora_rank, hidden_dim
|
||||
)
|
||||
z = torch.bmm(
|
||||
o_grouped.permute(1, 0, 2),
|
||||
wo_a_w.transpose(1, 2),
|
||||
o_grouped.permute(1, 0, 2), wo_a_w.transpose(1, 2),
|
||||
).permute(1, 0, 2)
|
||||
|
||||
# Step 3: wo_b (NVFP4 via CuTeDSL)
|
||||
return self.wo_b(z.flatten(1))
|
||||
)
|
||||
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
|
||||
aux_streams = self.aux_stream_list
|
||||
@@ -739,12 +733,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.kv_cache = torch.tensor([])
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
if current_platform.is_rocm():
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
|
||||
DeepseekV4ROCMAiterMLASparseBackend,
|
||||
)
|
||||
|
||||
return DeepseekV4ROCMAiterMLASparseBackend
|
||||
return DeepseekV4FlashMLASparseBackend
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
|
||||
@@ -777,14 +765,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
|
||||
)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
|
||||
DeepseekV4ROCMAiterMLASparseImpl,
|
||||
)
|
||||
|
||||
DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output)
|
||||
return
|
||||
|
||||
# Get SWA and indexer metadata from forward context
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@@ -867,6 +847,25 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
swa_indices = swa_metadata.decode_swa_indices
|
||||
swa_lens = swa_metadata.decode_swa_lens
|
||||
|
||||
if current_platform.is_rocm():
|
||||
rocm_forward_decode_fallback(
|
||||
q=q,
|
||||
kv_cache=kv_cache,
|
||||
swa_k_cache=self.swa_cache_layer.kv_cache,
|
||||
swa_only=swa_only,
|
||||
topk_indices=topk_indices,
|
||||
topk_lens=topk_lens,
|
||||
swa_indices=swa_indices,
|
||||
swa_lens=swa_lens,
|
||||
attn_sink=self.attn_sink,
|
||||
scale=self.scale,
|
||||
head_dim=self.head_dim,
|
||||
nope_head_dim=self.nope_head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
output=output,
|
||||
)
|
||||
return
|
||||
|
||||
# We treat queries in the same seq as different queries
|
||||
# and later we only attend by generated indices.
|
||||
# q arrives pre-padded to self.padded_heads by the outer wrapper.
|
||||
@@ -1030,15 +1029,28 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
M,
|
||||
N,
|
||||
)
|
||||
flash_mla_sparse_fwd(
|
||||
q=q[query_start:query_end],
|
||||
kv=kv.view(-1, 1, q.shape[-1]),
|
||||
indices=combined_indices.unsqueeze(1),
|
||||
sm_scale=self.scale,
|
||||
attn_sink=self.attn_sink,
|
||||
topk_length=combined_lens,
|
||||
out=output[query_start:query_end],
|
||||
)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
rocm_sparse_attn_prefill(
|
||||
q=q[query_start:query_end],
|
||||
kv=kv.view(-1, 1, q.shape[-1]),
|
||||
indices=combined_indices.unsqueeze(1),
|
||||
topk_length=combined_lens,
|
||||
scale=self.scale,
|
||||
head_dim=self.head_dim,
|
||||
attn_sink=self.attn_sink,
|
||||
output=output[query_start:query_end],
|
||||
)
|
||||
else:
|
||||
output_chunk, _, _ = flash_mla_sparse_fwd(
|
||||
q=q[query_start:query_end],
|
||||
kv=kv.view(-1, 1, q.shape[-1]),
|
||||
indices=combined_indices.unsqueeze(1),
|
||||
sm_scale=self.scale,
|
||||
attn_sink=self.attn_sink,
|
||||
topk_length=combined_lens,
|
||||
out=output[query_start:query_end],
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
|
||||
|
||||
Reference in New Issue
Block a user