Revert deepseek_v4_attention.py to ffc2264 — don't nuke existing patches

The file at ffc2264 already had our BF16 wo_a path (_apply_inv_rope_bf16 +
BMM + all-gather) with FP8 fallback. I was replacing it from the wrong
vllm source, losing all prior work. Restored to the known-good version.
This commit is contained in:
2026-05-19 06:52:40 +00:00
parent 4c2effa2be
commit 62abf41b03

View File

@@ -2,9 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DeepseekV4 MLA Attention Layer
Patched: O projection uses BF16 inverse RoPE + BMM wo_a + NVFP4 wo_b
instead of the original FP8 einsum path.
"""
from collections.abc import Callable
@@ -31,11 +28,7 @@ from vllm.v1.attention.ops.deepseek_v4_ops import (
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_forward_decode_fallback,
rocm_inv_rope_einsum,
rocm_sparse_attn_prefill,
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
@@ -47,7 +40,8 @@ from vllm.config import (
VllmConfig,
get_current_vllm_config,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
@@ -55,12 +49,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import (
execute_in_parallel,
@@ -93,34 +82,6 @@ PREFILL_CHUNK_SIZE = 4
@dataclass
# ---------------------------------------------------------------------------
# BF16 inverse RoPE (replaces fused_inv_rope_fp8_quant for the O projection)
# ---------------------------------------------------------------------------
def _apply_inv_rope_bf16(
o: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
nope_dim: int = 448,
rope_dim: int = 64,
) -> torch.Tensor:
"""Apply inverse RoPE to attention output in BF16."""
if rope_dim == 0 or o.numel() == 0:
return o
half_rope = rope_dim // 2
cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype)
sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype)
o_rope = o[:, :, nope_dim:]
o_even = o_rope[:, :, 0::2]
o_odd = o_rope[:, :, 1::2]
inv_even = o_even * cos_all + o_odd * sin_all
inv_odd = -o_even * sin_all + o_odd * cos_all
result = o.clone()
result[:, :, nope_dim:][:, :, 0::2] = inv_even
result[:, :, nope_dim:][:, :, 1::2] = inv_odd
return result
class DeepseekV4MLAModules:
"""Modules used in DeepseekV4 MLA."""
@@ -221,15 +182,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
self.kv_norm = mla_modules.kv_norm
self.wo_a = mla_modules.wo_a
self._wo_a_act_quant = QuantFP8(
static=False,
group_shape=GroupShape(1, 128),
use_ue8m0=True,
)
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
# INT32) so fp8_einsum can handle layout transform internally.
self._wo_a_act_quant.use_deep_gemm_supported = False
self.wo_b = mla_modules.wo_b
# Pick fp8_einsum recipe based on GPU arch:
@@ -322,6 +274,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
o_padded = torch.empty(
(num_tokens, self.padded_heads, self.head_dim),
@@ -338,27 +291,91 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
o = o_padded[:, : self.n_local_heads, :]
# === O Projection (patched for BF16 wo_a + NVFP4 wo_b) ===
# Original path uses fused_inv_rope_fp8_quant + FP8 einsum, which
# requires wo_a.weight_scale_inv (doesn't exist for BF16 wo_a).
# Keep ROCm on the BF16 reference wo_a path util kernel ready.
if current_platform.is_rocm():
z = rocm_inv_rope_einsum(
self.rotary_emb,
o,
positions,
self.rope_head_dim,
self.n_local_groups,
self.o_lora_rank,
self.wo_a,
)
return self.wo_b(z.flatten(1))
# Step 1: Inverse RoPE (BF16, pure PyTorch)
o_inv = _apply_inv_rope_bf16(
o, positions, self.rotary_emb.cos_sin_cache,
nope_dim=self.nope_head_dim, rope_dim=self.rope_head_dim,
# Detect if wo_a has FP8 weights (weight_scale_inv attribute).
# NVFP4 checkpoints leave wo_a as BF16 (no quantization scales),
# so we use inverse RoPE in BF16 + regular matmul instead of
# the FP8 einsum path (which crashes on Blackwell SM100).
has_fp8_weights = hasattr(self.wo_a, 'weight_scale_inv')
if not has_fp8_weights:
# BF16 wo_a path: inverse RoPE in BF16, then per-group BMM
# wo_a is a ColumnParallelLinear with is_bmm=True, meaning it
# operates per o-group. The FP8 path uses einsum "bhr,hdr->bhd"
# where h=n_local_groups. We must do the same grouping here.
o_inv = _apply_inv_rope_bf16(
o, positions,
self.rotary_emb.cos_sin_cache.to(torch.float32),
nope_dim=self.nope_head_dim,
rope_dim=self.rope_head_dim,
)
heads_per_group = self.n_local_heads // self.n_local_groups
# o_inv: (num_tokens, n_local_heads, head_dim)
# -> (n_local_groups, num_tokens, heads_per_group * head_dim)
o_inv = o_inv.view(
num_tokens, self.n_local_groups, heads_per_group * self.head_dim
).permute(1, 0, 2)
# wo_a weight is sharded by TP along output dim.
# Shape: (n_local_groups * o_lora_rank // tp, heads_per_group * head_dim)
# For BMM, we need weight shaped as (n_local_groups, o_lora_rank // tp, heads_per_group * head_dim)
wo_a_w = self.wo_a.weight.view(
self.n_local_groups, -1, heads_per_group * self.head_dim
)
# BMM: (n_local_groups, num_tokens, in) @ (n_local_groups, in, out) -> (n_local_groups, num_tokens, out)
z = torch.bmm(
o_inv,
wo_a_w.transpose(1, 2),
)
# -> (num_tokens, n_local_groups, o_lora_rank // tp)
z = z.permute(1, 0, 2)
# All-gather wo_a output across TP ranks, then flatten groups
if self.wo_a.gather_output and self.wo_a.tp_size > 1:
z = tensor_model_parallel_all_gather(z)
z = z.reshape(num_tokens, self.n_local_groups * self.o_lora_rank)
return self.wo_b(z)
# FP8 wo_a path: fused inverse RoPE + FP8 quant + einsum
o_fp8, o_scale = fused_inv_rope_fp8_quant(
o,
positions,
self.rotary_emb.cos_sin_cache,
n_groups=self.n_local_groups,
heads_per_group=self.n_local_heads // self.n_local_groups,
nope_dim=self.nope_head_dim,
rope_dim=self.rope_head_dim,
tma_aligned_scales=self._tma_aligned_scales,
)
# Step 2: wo_a grouped linear (BF16 BMM)
hidden_dim = self.wo_a.weight.shape[1]
o_grouped = o_inv.view(num_tokens, self.n_local_groups, hidden_dim)
wo_a_w = self.wo_a.weight.view(
self.n_local_groups, self.o_lora_rank, hidden_dim
)
z = torch.bmm(
o_grouped.permute(1, 0, 2), wo_a_w.transpose(1, 2),
).permute(1, 0, 2)
wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv
z = torch.empty(
(num_tokens, self.n_local_groups, self.o_lora_rank),
device=o.device,
dtype=torch.bfloat16,
)
torch.ops.vllm.deepseek_v4_fp8_einsum(
o_fp8,
o_scale,
wo_a_fp8,
wo_a_scale,
z,
"bhr,hdr->bhd",
list(self._einsum_recipe),
)
# Step 3: wo_b (NVFP4 via CuTeDSL)
return self.wo_b(z.flatten(1))
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
@@ -565,6 +582,41 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
def _apply_inv_rope_bf16(
o: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
nope_dim: int,
rope_dim: int,
) -> torch.Tensor:
"""Apply inverse RoPE to attention output in BF16.
Inverse RoPE is just RoPE with sin -> -sin.
Uses GPT-J style (interleaved) rotary embedding.
"""
if rope_dim == 0 or o.numel() == 0:
return o
half_rot = rope_dim // 2
o_f32 = o.to(torch.float32)
cache = cos_sin_cache.index_select(0, positions.to(torch.long))
cos = cache[:, :half_rot].to(torch.float32)
sin = cache[:, half_rot : 2 * half_rot].to(torch.float32)
view_shape = (positions.shape[0], 1, half_rot)
cos = cos.view(view_shape)
sin = sin.view(view_shape)
rope = o_f32[..., nope_dim:]
y_even = rope[..., 0::2]
y_odd = rope[..., 1::2]
# Inverse: sin → -sin (swap signs on cross terms)
rope_out = torch.stack(
(y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
dim=-1,
).flatten(-2)
o_f32 = o_f32.clone()
o_f32[..., nope_dim:] = rope_out
return o_f32.to(o.dtype)
def deepseek_v4_attention(
hidden_states: torch.Tensor,
positions: torch.Tensor,
@@ -733,6 +785,12 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
self.kv_cache = torch.tensor([])
def get_attn_backend(self) -> type[AttentionBackend]:
if current_platform.is_rocm():
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
DeepseekV4ROCMAiterMLASparseBackend,
)
return DeepseekV4ROCMAiterMLASparseBackend
return DeepseekV4FlashMLASparseBackend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
@@ -765,6 +823,14 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
)
if current_platform.is_rocm():
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
DeepseekV4ROCMAiterMLASparseImpl,
)
DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output)
return
# Get SWA and indexer metadata from forward context
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -847,25 +913,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
swa_indices = swa_metadata.decode_swa_indices
swa_lens = swa_metadata.decode_swa_lens
if current_platform.is_rocm():
rocm_forward_decode_fallback(
q=q,
kv_cache=kv_cache,
swa_k_cache=self.swa_cache_layer.kv_cache,
swa_only=swa_only,
topk_indices=topk_indices,
topk_lens=topk_lens,
swa_indices=swa_indices,
swa_lens=swa_lens,
attn_sink=self.attn_sink,
scale=self.scale,
head_dim=self.head_dim,
nope_head_dim=self.nope_head_dim,
rope_head_dim=self.rope_head_dim,
output=output,
)
return
# We treat queries in the same seq as different queries
# and later we only attend by generated indices.
# q arrives pre-padded to self.padded_heads by the outer wrapper.
@@ -1029,28 +1076,15 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
M,
N,
)
if current_platform.is_rocm():
rocm_sparse_attn_prefill(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
topk_length=combined_lens,
scale=self.scale,
head_dim=self.head_dim,
attn_sink=self.attn_sink,
output=output[query_start:query_end],
)
else:
output_chunk, _, _ = flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
@@ -1136,7 +1170,7 @@ class DeepseekV4Indexer(nn.Module):
hidden_size,
self.n_head,
bias=False,
quant_config=None,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)