From 77baca668ed2a694e592b41c07fb955b2882e4d2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 06:30:18 +0000 Subject: [PATCH] Patch attention forward: BF16 inv RoPE + BMM wo_a + NVFP4 wo_b The original attention forward uses fused_inv_rope_fp8_quant + deepseek_v4_fp8_einsum which requires wo_a to have FP8 weights and weight_scale_inv. Our checkpoint has wo_a in BF16, so the original path crashes (produces empty output). Replace O projection with: 1. _apply_inv_rope_bf16: pure PyTorch inverse RoPE (no FP8) 2. BMM grouped linear for wo_a (BF16) 3. NVFP4 wo_b via CuTeDSL Also fixes activation global scale bug from previous commit: - input_global_scale_inv IS the activation gs, don't re-invert - w13_input_scale_orig (after undoing convert) IS the MoE gs Test: tests/test_o_projection.py validates inv RoPE roundtrip and wo_a BMM correctness. --- tests/test_o_projection.py | 159 ++++++++++++++++++++ vllm/patches/deepseek_v4_attention.py | 199 +++++++++++--------------- 2 files changed, 242 insertions(+), 116 deletions(-) create mode 100644 tests/test_o_projection.py diff --git a/tests/test_o_projection.py b/tests/test_o_projection.py new file mode 100644 index 00000000..ada2bed8 --- /dev/null +++ b/tests/test_o_projection.py @@ -0,0 +1,159 @@ +"""Test BF16 inverse RoPE + wo_a BMM (no GPU needed). + +Validates the O projection path we patched into the attention forward. +""" + +import torch +import math + + +def apply_inv_rope_bf16( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + nope_dim: int = 448, + rope_dim: int = 64, +) -> torch.Tensor: + """Same as the patched version in deepseek_v4_attention.py.""" + if rope_dim == 0 or o.numel() == 0: + return o + half_rope = rope_dim // 2 + + cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype) + sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype) + + o_rope = o[:, :, nope_dim:] + o_even = o_rope[:, :, 0::2] + o_odd = o_rope[:, :, 1::2] + + inv_even = o_even * cos_all + o_odd * sin_all + inv_odd = -o_even * sin_all + o_odd * cos_all + + result = o.clone() + result[:, :, nope_dim:][:, :, 0::2] = inv_even + result[:, :, nope_dim:][:, :, 1::2] = inv_odd + return result + + +def apply_gptj_rope( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + nope_dim: int = 448, + rope_dim: int = 64, +) -> torch.Tensor: + """Apply forward GPT-J style RoPE (for testing roundtrip).""" + half_rope = rope_dim // 2 + cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(x.dtype) + sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(x.dtype) + + x_rope = x[:, :, nope_dim:] + x_even = x_rope[:, :, 0::2] + x_odd = x_rope[:, :, 1::2] + + rot_even = x_even * cos_all - x_odd * sin_all + rot_odd = x_even * sin_all + x_odd * cos_all + + result = x.clone() + result[:, :, nope_dim:][:, :, 0::2] = rot_even + result[:, :, nope_dim:][:, :, 1::2] = rot_odd + return result + + +def test_inv_rope_roundtrip(): + """inv_rope(forward_rope(x)) should recover x.""" + torch.manual_seed(42) + T, H, D = 4, 8, 512 # tokens, heads, head_dim + nope_dim, rope_dim = 448, 64 + max_pos = 100 + + # Build cos_sin_cache for positions 0..max_pos + inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2).float() / rope_dim)) + t = torch.arange(max_pos, dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) # (max_pos, half_rope) + cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) # (max_pos, rope_dim) + + x = torch.randn(T, H, D, dtype=torch.bfloat16) * 0.1 + positions = torch.tensor([0, 5, 10, 50], dtype=torch.int64) + + # Apply forward RoPE, then inverse + rotated = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim) + recovered = apply_inv_rope_bf16(rotated, positions, cos_sin_cache, nope_dim, rope_dim) + + # NoPE portion unchanged + nope_diff = (recovered[:, :, :nope_dim] - x[:, :, :nope_dim]).abs().max().item() + assert nope_diff == 0, f"NoPE should be unchanged, max diff: {nope_diff}" + + # RoPE portion should roundtrip within BF16 precision + rope_diff = (recovered[:, :, nope_dim:] - x[:, :, nope_dim:]).abs().max().item() + assert rope_diff < 0.02, f"RoPE roundtrip error too high: {rope_diff}" + print(f"✅ inv_rope roundtrip: NoPE diff={nope_diff}, RoPE diff={rope_diff:.6f}") + + +def test_wo_a_bmm(): + """wo_a BMM should match einsum 'tgd,grd->tgr'.""" + torch.manual_seed(42) + T = 3 + n_local_groups = 4 + heads_per_group = 2 + head_dim = 512 + o_lora_rank = 128 + n_local_heads = n_local_groups * heads_per_group + + # wo_a weight: (n_groups * o_lora_rank, heads_per_group * head_dim) + wo_a_weight = torch.randn(n_local_groups * o_lora_rank, heads_per_group * head_dim, dtype=torch.bfloat16) + + # Attention output (after inv RoPE): (T, n_local_heads, head_dim) + o_inv = torch.randn(T, n_local_heads, head_dim, dtype=torch.bfloat16) + + # BMM path (our implementation) + hidden_dim = heads_per_group * head_dim + o_grouped = o_inv.view(T, n_local_groups, hidden_dim) + wo_a_w = wo_a_weight.view(n_local_groups, o_lora_rank, hidden_dim) + z_bmm = torch.bmm( + o_grouped.permute(1, 0, 2), + wo_a_w.transpose(1, 2), + ).permute(1, 0, 2) + + # Reference: einsum + o_for_einsum = o_inv.view(T, n_local_groups, hidden_dim).float() + wo_a_for_einsum = wo_a_w.float() + z_einsum = torch.einsum("tgd,grd->tgr", o_for_einsum, wo_a_for_einsum).bfloat16() + + diff = (z_bmm - z_einsum).abs().max().item() + assert diff < 0.01, f"wo_a BMM vs einsum diff: {diff}" + print(f"✅ wo_a BMM matches einsum: max diff={diff:.6f}") + + +def test_inv_rope_at_zero(): + """At position 0, cos=1, sin=0, so inv_rope should be identity on RoPE dims.""" + torch.manual_seed(42) + T, H, D = 2, 4, 512 + nope_dim, rope_dim = 448, 64 + + inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2).float() / rope_dim)) + t = torch.arange(10, dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) # (10, rope_dim) + # At pos 0, cos=1, sin=0 + + x = torch.randn(T, H, D, dtype=torch.bfloat16) * 0.1 + positions = torch.zeros(T, dtype=torch.int64) + + # Forward RoPE at pos 0 should be identity (cos=1, sin=0) + rotated = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim) + diff = (rotated - x).abs().max().item() + assert diff < 1e-5, f"RoPE at pos=0 should be identity, diff={diff}" + + # Inverse RoPE on unrotated input at pos 0 should also be identity + inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim) + diff2 = (inv - x).abs().max().item() + assert diff2 < 1e-5, f"inv RoPE at pos=0 should be identity, diff={diff2}" + print(f"✅ inv_rope at pos=0 is identity (diff={diff2:.8f})") + + +if __name__ == "__main__": + test_inv_rope_roundtrip() + test_wo_a_bmm() + test_inv_rope_at_zero() + print("\n✅ All attention O-projection tests passed") diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index e7b9cf38..e5337ebb 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -2,6 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ DeepseekV4 MLA Attention Layer + +Patched: O projection uses BF16 inverse RoPE + BMM wo_a + NVFP4 wo_b +instead of the original FP8 einsum path. """ from collections.abc import Callable @@ -14,6 +17,7 @@ import torch.nn.functional as F from transformers import DeepseekV2Config, DeepseekV3Config import vllm.envs as envs +from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.model_executor.layers.linear import ( ReplicatedLinear, ) @@ -40,16 +44,20 @@ from vllm.config import ( VllmConfig, get_current_vllm_config, ) -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor -from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig - +from vllm.model_executor.layers.quantization.input_quant_fp8 import ( + QuantFP8, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, +) from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, @@ -81,6 +89,45 @@ logger = init_logger(__name__) PREFILL_CHUNK_SIZE = 4 +# --------------------------------------------------------------------------- +# BF16 inverse RoPE (replaces fused_inv_rope_fp8_quant + FP8 einsum) +# --------------------------------------------------------------------------- + +def _apply_inv_rope_bf16( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + nope_dim: int = 448, + rope_dim: int = 64, +) -> torch.Tensor: + """Apply inverse RoPE to attention output in BF16. + + Pure-PyTorch replacement for fused_inv_rope_fp8_quant. + Only does inverse RoPE (no FP8 quant) since we use NVFP4 for wo_b. + """ + if rope_dim == 0 or o.numel() == 0: + return o + half_rope = rope_dim // 2 + + cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype) + sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype) + + o_rope = o[:, :, nope_dim:] + o_even = o_rope[:, :, 0::2] + o_odd = o_rope[:, :, 1::2] + + # Inverse rotation (conjugate): + # inv[2i] = x[2i] * cos + x[2i+1] * sin + # inv[2i+1] = -x[2i] * sin + x[2i+1] * cos + inv_even = o_even * cos_all + o_odd * sin_all + inv_odd = -o_even * sin_all + o_odd * cos_all + + result = o.clone() + result[:, :, nope_dim:][:, :, 0::2] = inv_even + result[:, :, nope_dim:][:, :, 1::2] = inv_odd + return result + + @dataclass class DeepseekV4MLAModules: """Modules used in DeepseekV4 MLA.""" @@ -182,6 +229,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): self.kv_norm = mla_modules.kv_norm self.wo_a = mla_modules.wo_a + + self._wo_a_act_quant = QuantFP8( + static=False, + group_shape=GroupShape(1, 128), + use_ue8m0=True, + ) + # Bypass packed-for-deepgemm path — we need FP32 scales (not packed + # INT32) so fp8_einsum can handle layout transform internally. + self._wo_a_act_quant.use_deep_gemm_supported = False self.wo_b = mla_modules.wo_b # Pick fp8_einsum recipe based on GPU arch: @@ -291,91 +347,37 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) o = o_padded[:, : self.n_local_heads, :] - # Keep ROCm on the BF16 reference wo_a path util kernel ready. - if current_platform.is_rocm(): - z = rocm_inv_rope_einsum( - self.rotary_emb, - o, - positions, - self.rope_head_dim, - self.n_local_groups, - self.o_lora_rank, - self.wo_a, - ) - return self.wo_b(z.flatten(1)) + # === O Projection (patched for NVFP4 + BF16 wo_a) === + # The original path uses fused_inv_rope_fp8_quant + FP8 einsum for wo_a, + # which requires wo_a to have FP8 weights and weight_scale_inv. + # Our checkpoint has wo_a in BF16 and wo_b in NVFP4. + # We replace with: inverse RoPE (BF16) + BMM wo_a + NVFP4 wo_b. - # Detect if wo_a has FP8 weights (weight_scale_inv attribute). - # NVFP4 checkpoints leave wo_a as BF16 (no quantization scales), - # so we use inverse RoPE in BF16 + regular matmul instead of - # the FP8 einsum path (which crashes on Blackwell SM100). - has_fp8_weights = hasattr(self.wo_a, 'weight_scale_inv') - - if not has_fp8_weights: - # BF16 wo_a path: inverse RoPE in BF16, then per-group BMM - # wo_a is a ColumnParallelLinear with is_bmm=True, meaning it - # operates per o-group. The FP8 path uses einsum "bhr,hdr->bhd" - # where h=n_local_groups. We must do the same grouping here. - o_inv = _apply_inv_rope_bf16( - o, positions, - self.rotary_emb.cos_sin_cache.to(torch.float32), - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - ) - heads_per_group = self.n_local_heads // self.n_local_groups - # o_inv: (num_tokens, n_local_heads, head_dim) - # -> (n_local_groups, num_tokens, heads_per_group * head_dim) - o_inv = o_inv.view( - num_tokens, self.n_local_groups, heads_per_group * self.head_dim - ).permute(1, 0, 2) - # wo_a weight is sharded by TP along output dim. - # Shape: (n_local_groups * o_lora_rank // tp, heads_per_group * head_dim) - # For BMM, we need weight shaped as (n_local_groups, o_lora_rank // tp, heads_per_group * head_dim) - wo_a_w = self.wo_a.weight.view( - self.n_local_groups, -1, heads_per_group * self.head_dim - ) - # BMM: (n_local_groups, num_tokens, in) @ (n_local_groups, in, out) -> (n_local_groups, num_tokens, out) - z = torch.bmm( - o_inv, - wo_a_w.transpose(1, 2), - ) - # -> (num_tokens, n_local_groups, o_lora_rank // tp) - z = z.permute(1, 0, 2) - # All-gather wo_a output across TP ranks, then flatten groups - if self.wo_a.gather_output and self.wo_a.tp_size > 1: - z = tensor_model_parallel_all_gather(z) - z = z.reshape(num_tokens, self.n_local_groups * self.o_lora_rank) - return self.wo_b(z) - - # FP8 wo_a path: fused inverse RoPE + FP8 quant + einsum - o_fp8, o_scale = fused_inv_rope_fp8_quant( + # Step 1: Inverse RoPE (BF16, pure PyTorch) + o_inv = _apply_inv_rope_bf16( o, positions, self.rotary_emb.cos_sin_cache, - n_groups=self.n_local_groups, - heads_per_group=self.n_local_heads // self.n_local_groups, nope_dim=self.nope_head_dim, rope_dim=self.rope_head_dim, - tma_aligned_scales=self._tma_aligned_scales, ) - wo_a_fp8 = self.wo_a.weight - wo_a_scale = self.wo_a.weight_scale_inv - - z = torch.empty( - (num_tokens, self.n_local_groups, self.o_lora_rank), - device=o.device, - dtype=torch.bfloat16, - ) - torch.ops.vllm.deepseek_v4_fp8_einsum( - o_fp8, - o_scale, - wo_a_fp8, - wo_a_scale, - z, - "bhr,hdr->bhd", - list(self._einsum_recipe), + # Step 2: wo_a grouped linear (BF16 BMM) + # o_inv: (T, n_local_heads, head_dim) + # wo_a.weight: (n_local_groups * o_lora_rank, heads_per_group * head_dim) BF16 + heads_per_group = self.n_local_heads // self.n_local_groups + hidden_dim = heads_per_group * self.head_dim + o_grouped = o_inv.view(num_tokens, self.n_local_groups, hidden_dim) + wo_a_w = self.wo_a.weight.view( + self.n_local_groups, self.o_lora_rank, hidden_dim ) + # BMM: (G, T, D) @ (G, D, R) → (G, T, R) → (T, G, R) + z = torch.bmm( + o_grouped.permute(1, 0, 2), + wo_a_w.transpose(1, 2), + ).permute(1, 0, 2) + # Step 3: wo_b (NVFP4 via CuTeDSL) return self.wo_b(z.flatten(1)) def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: @@ -582,41 +584,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) -def _apply_inv_rope_bf16( - o: torch.Tensor, - positions: torch.Tensor, - cos_sin_cache: torch.Tensor, - nope_dim: int, - rope_dim: int, -) -> torch.Tensor: - """Apply inverse RoPE to attention output in BF16. - - Inverse RoPE is just RoPE with sin -> -sin. - Uses GPT-J style (interleaved) rotary embedding. - """ - if rope_dim == 0 or o.numel() == 0: - return o - half_rot = rope_dim // 2 - o_f32 = o.to(torch.float32) - cache = cos_sin_cache.index_select(0, positions.to(torch.long)) - cos = cache[:, :half_rot].to(torch.float32) - sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) - view_shape = (positions.shape[0], 1, half_rot) - cos = cos.view(view_shape) - sin = sin.view(view_shape) - rope = o_f32[..., nope_dim:] - y_even = rope[..., 0::2] - y_odd = rope[..., 1::2] - # Inverse: sin → -sin (swap signs on cross terms) - rope_out = torch.stack( - (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin), - dim=-1, - ).flatten(-2) - o_f32 = o_f32.clone() - o_f32[..., nope_dim:] = rope_out - return o_f32.to(o.dtype) - - +@eager_break_during_capture def deepseek_v4_attention( hidden_states: torch.Tensor, positions: torch.Tensor, @@ -1170,10 +1138,9 @@ class DeepseekV4Indexer(nn.Module): hidden_size, self.n_head, bias=False, - quant_config=quant_config, + quant_config=None, prefix=f"{prefix}.weights_proj", ) - self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = "ue8m0"