diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index a2bf25f1..e7b9cf38 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -2,9 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ DeepseekV4 MLA Attention Layer - -Patched: O projection uses BF16 inverse RoPE + BMM wo_a + NVFP4 wo_b -instead of the original FP8 einsum path. """ from collections.abc import Callable @@ -31,11 +28,7 @@ from vllm.v1.attention.ops.deepseek_v4_ops import ( fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, ) -from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( - rocm_forward_decode_fallback, - rocm_inv_rope_einsum, - rocm_sparse_attn_prefill, -) +from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum if TYPE_CHECKING: from vllm.v1.attention.backends.mla.sparse_swa import ( @@ -47,7 +40,8 @@ from vllm.config import ( VllmConfig, get_current_vllm_config, ) -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer @@ -55,12 +49,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.input_quant_fp8 import ( - QuantFP8, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, -) + from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, @@ -93,34 +82,6 @@ PREFILL_CHUNK_SIZE = 4 @dataclass -# --------------------------------------------------------------------------- -# BF16 inverse RoPE (replaces fused_inv_rope_fp8_quant for the O projection) -# --------------------------------------------------------------------------- - -def _apply_inv_rope_bf16( - o: torch.Tensor, - positions: torch.Tensor, - cos_sin_cache: torch.Tensor, - nope_dim: int = 448, - rope_dim: int = 64, -) -> torch.Tensor: - """Apply inverse RoPE to attention output in BF16.""" - if rope_dim == 0 or o.numel() == 0: - return o - half_rope = rope_dim // 2 - cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype) - sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype) - o_rope = o[:, :, nope_dim:] - o_even = o_rope[:, :, 0::2] - o_odd = o_rope[:, :, 1::2] - inv_even = o_even * cos_all + o_odd * sin_all - inv_odd = -o_even * sin_all + o_odd * cos_all - result = o.clone() - result[:, :, nope_dim:][:, :, 0::2] = inv_even - result[:, :, nope_dim:][:, :, 1::2] = inv_odd - return result - - class DeepseekV4MLAModules: """Modules used in DeepseekV4 MLA.""" @@ -221,15 +182,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): self.kv_norm = mla_modules.kv_norm self.wo_a = mla_modules.wo_a - - self._wo_a_act_quant = QuantFP8( - static=False, - group_shape=GroupShape(1, 128), - use_ue8m0=True, - ) - # Bypass packed-for-deepgemm path — we need FP32 scales (not packed - # INT32) so fp8_einsum can handle layout transform internally. - self._wo_a_act_quant.use_deep_gemm_supported = False self.wo_b = mla_modules.wo_b # Pick fp8_einsum recipe based on GPU arch: @@ -322,6 +274,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: # Pre-allocate attention output with FlashMLA-padded head count. + # The op writes into `o_padded`; we slice to n_local_heads after. num_tokens = hidden_states.shape[0] o_padded = torch.empty( (num_tokens, self.padded_heads, self.head_dim), @@ -338,27 +291,91 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) o = o_padded[:, : self.n_local_heads, :] - # === O Projection (patched for BF16 wo_a + NVFP4 wo_b) === - # Original path uses fused_inv_rope_fp8_quant + FP8 einsum, which - # requires wo_a.weight_scale_inv (doesn't exist for BF16 wo_a). + # Keep ROCm on the BF16 reference wo_a path util kernel ready. + if current_platform.is_rocm(): + z = rocm_inv_rope_einsum( + self.rotary_emb, + o, + positions, + self.rope_head_dim, + self.n_local_groups, + self.o_lora_rank, + self.wo_a, + ) + return self.wo_b(z.flatten(1)) - # Step 1: Inverse RoPE (BF16, pure PyTorch) - o_inv = _apply_inv_rope_bf16( - o, positions, self.rotary_emb.cos_sin_cache, - nope_dim=self.nope_head_dim, rope_dim=self.rope_head_dim, + # Detect if wo_a has FP8 weights (weight_scale_inv attribute). + # NVFP4 checkpoints leave wo_a as BF16 (no quantization scales), + # so we use inverse RoPE in BF16 + regular matmul instead of + # the FP8 einsum path (which crashes on Blackwell SM100). + has_fp8_weights = hasattr(self.wo_a, 'weight_scale_inv') + + if not has_fp8_weights: + # BF16 wo_a path: inverse RoPE in BF16, then per-group BMM + # wo_a is a ColumnParallelLinear with is_bmm=True, meaning it + # operates per o-group. The FP8 path uses einsum "bhr,hdr->bhd" + # where h=n_local_groups. We must do the same grouping here. + o_inv = _apply_inv_rope_bf16( + o, positions, + self.rotary_emb.cos_sin_cache.to(torch.float32), + nope_dim=self.nope_head_dim, + rope_dim=self.rope_head_dim, + ) + heads_per_group = self.n_local_heads // self.n_local_groups + # o_inv: (num_tokens, n_local_heads, head_dim) + # -> (n_local_groups, num_tokens, heads_per_group * head_dim) + o_inv = o_inv.view( + num_tokens, self.n_local_groups, heads_per_group * self.head_dim + ).permute(1, 0, 2) + # wo_a weight is sharded by TP along output dim. + # Shape: (n_local_groups * o_lora_rank // tp, heads_per_group * head_dim) + # For BMM, we need weight shaped as (n_local_groups, o_lora_rank // tp, heads_per_group * head_dim) + wo_a_w = self.wo_a.weight.view( + self.n_local_groups, -1, heads_per_group * self.head_dim + ) + # BMM: (n_local_groups, num_tokens, in) @ (n_local_groups, in, out) -> (n_local_groups, num_tokens, out) + z = torch.bmm( + o_inv, + wo_a_w.transpose(1, 2), + ) + # -> (num_tokens, n_local_groups, o_lora_rank // tp) + z = z.permute(1, 0, 2) + # All-gather wo_a output across TP ranks, then flatten groups + if self.wo_a.gather_output and self.wo_a.tp_size > 1: + z = tensor_model_parallel_all_gather(z) + z = z.reshape(num_tokens, self.n_local_groups * self.o_lora_rank) + return self.wo_b(z) + + # FP8 wo_a path: fused inverse RoPE + FP8 quant + einsum + o_fp8, o_scale = fused_inv_rope_fp8_quant( + o, + positions, + self.rotary_emb.cos_sin_cache, + n_groups=self.n_local_groups, + heads_per_group=self.n_local_heads // self.n_local_groups, + nope_dim=self.nope_head_dim, + rope_dim=self.rope_head_dim, + tma_aligned_scales=self._tma_aligned_scales, ) - # Step 2: wo_a grouped linear (BF16 BMM) - hidden_dim = self.wo_a.weight.shape[1] - o_grouped = o_inv.view(num_tokens, self.n_local_groups, hidden_dim) - wo_a_w = self.wo_a.weight.view( - self.n_local_groups, self.o_lora_rank, hidden_dim - ) - z = torch.bmm( - o_grouped.permute(1, 0, 2), wo_a_w.transpose(1, 2), - ).permute(1, 0, 2) + wo_a_fp8 = self.wo_a.weight + wo_a_scale = self.wo_a.weight_scale_inv + + z = torch.empty( + (num_tokens, self.n_local_groups, self.o_lora_rank), + device=o.device, + dtype=torch.bfloat16, + ) + torch.ops.vllm.deepseek_v4_fp8_einsum( + o_fp8, + o_scale, + wo_a_fp8, + wo_a_scale, + z, + "bhr,hdr->bhd", + list(self._einsum_recipe), + ) - # Step 3: wo_b (NVFP4 via CuTeDSL) return self.wo_b(z.flatten(1)) def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: @@ -565,6 +582,41 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) +def _apply_inv_rope_bf16( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + nope_dim: int, + rope_dim: int, +) -> torch.Tensor: + """Apply inverse RoPE to attention output in BF16. + + Inverse RoPE is just RoPE with sin -> -sin. + Uses GPT-J style (interleaved) rotary embedding. + """ + if rope_dim == 0 or o.numel() == 0: + return o + half_rot = rope_dim // 2 + o_f32 = o.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0], 1, half_rot) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = o_f32[..., nope_dim:] + y_even = rope[..., 0::2] + y_odd = rope[..., 1::2] + # Inverse: sin → -sin (swap signs on cross terms) + rope_out = torch.stack( + (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin), + dim=-1, + ).flatten(-2) + o_f32 = o_f32.clone() + o_f32[..., nope_dim:] = rope_out + return o_f32.to(o.dtype) + + def deepseek_v4_attention( hidden_states: torch.Tensor, positions: torch.Tensor, @@ -733,6 +785,12 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): self.kv_cache = torch.tensor([]) def get_attn_backend(self) -> type[AttentionBackend]: + if current_platform.is_rocm(): + from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( + DeepseekV4ROCMAiterMLASparseBackend, + ) + + return DeepseekV4ROCMAiterMLASparseBackend return DeepseekV4FlashMLASparseBackend def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: @@ -765,6 +823,14 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" ) + if current_platform.is_rocm(): + from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( + DeepseekV4ROCMAiterMLASparseImpl, + ) + + DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output) + return + # Get SWA and indexer metadata from forward context forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -847,25 +913,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens - if current_platform.is_rocm(): - rocm_forward_decode_fallback( - q=q, - kv_cache=kv_cache, - swa_k_cache=self.swa_cache_layer.kv_cache, - swa_only=swa_only, - topk_indices=topk_indices, - topk_lens=topk_lens, - swa_indices=swa_indices, - swa_lens=swa_lens, - attn_sink=self.attn_sink, - scale=self.scale, - head_dim=self.head_dim, - nope_head_dim=self.nope_head_dim, - rope_head_dim=self.rope_head_dim, - output=output, - ) - return - # We treat queries in the same seq as different queries # and later we only attend by generated indices. # q arrives pre-padded to self.padded_heads by the outer wrapper. @@ -1029,28 +1076,15 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): M, N, ) - - if current_platform.is_rocm(): - rocm_sparse_attn_prefill( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - topk_length=combined_lens, - scale=self.scale, - head_dim=self.head_dim, - attn_sink=self.attn_sink, - output=output[query_start:query_end], - ) - else: - output_chunk, _, _ = flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], - ) + flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): @@ -1136,7 +1170,7 @@ class DeepseekV4Indexer(nn.Module): hidden_size, self.n_head, bias=False, - quant_config=None, + quant_config=quant_config, prefix=f"{prefix}.weights_proj", ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6)