Patch attention forward: BF16 inv RoPE + BMM wo_a + NVFP4 wo_b
The original attention forward uses fused_inv_rope_fp8_quant + deepseek_v4_fp8_einsum which requires wo_a to have FP8 weights and weight_scale_inv. Our checkpoint has wo_a in BF16, so the original path crashes (produces empty output). Replace O projection with: 1. _apply_inv_rope_bf16: pure PyTorch inverse RoPE (no FP8) 2. BMM grouped linear for wo_a (BF16) 3. NVFP4 wo_b via CuTeDSL Also fixes activation global scale bug from previous commit: - input_global_scale_inv IS the activation gs, don't re-invert - w13_input_scale_orig (after undoing convert) IS the MoE gs Test: tests/test_o_projection.py validates inv RoPE roundtrip and wo_a BMM correctness.
This commit is contained in:
159
tests/test_o_projection.py
Normal file
159
tests/test_o_projection.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Test BF16 inverse RoPE + wo_a BMM (no GPU needed).
|
||||
|
||||
Validates the O projection path we patched into the attention forward.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
def apply_inv_rope_bf16(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
nope_dim: int = 448,
|
||||
rope_dim: int = 64,
|
||||
) -> torch.Tensor:
|
||||
"""Same as the patched version in deepseek_v4_attention.py."""
|
||||
if rope_dim == 0 or o.numel() == 0:
|
||||
return o
|
||||
half_rope = rope_dim // 2
|
||||
|
||||
cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype)
|
||||
sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype)
|
||||
|
||||
o_rope = o[:, :, nope_dim:]
|
||||
o_even = o_rope[:, :, 0::2]
|
||||
o_odd = o_rope[:, :, 1::2]
|
||||
|
||||
inv_even = o_even * cos_all + o_odd * sin_all
|
||||
inv_odd = -o_even * sin_all + o_odd * cos_all
|
||||
|
||||
result = o.clone()
|
||||
result[:, :, nope_dim:][:, :, 0::2] = inv_even
|
||||
result[:, :, nope_dim:][:, :, 1::2] = inv_odd
|
||||
return result
|
||||
|
||||
|
||||
def apply_gptj_rope(
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
nope_dim: int = 448,
|
||||
rope_dim: int = 64,
|
||||
) -> torch.Tensor:
|
||||
"""Apply forward GPT-J style RoPE (for testing roundtrip)."""
|
||||
half_rope = rope_dim // 2
|
||||
cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(x.dtype)
|
||||
sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(x.dtype)
|
||||
|
||||
x_rope = x[:, :, nope_dim:]
|
||||
x_even = x_rope[:, :, 0::2]
|
||||
x_odd = x_rope[:, :, 1::2]
|
||||
|
||||
rot_even = x_even * cos_all - x_odd * sin_all
|
||||
rot_odd = x_even * sin_all + x_odd * cos_all
|
||||
|
||||
result = x.clone()
|
||||
result[:, :, nope_dim:][:, :, 0::2] = rot_even
|
||||
result[:, :, nope_dim:][:, :, 1::2] = rot_odd
|
||||
return result
|
||||
|
||||
|
||||
def test_inv_rope_roundtrip():
|
||||
"""inv_rope(forward_rope(x)) should recover x."""
|
||||
torch.manual_seed(42)
|
||||
T, H, D = 4, 8, 512 # tokens, heads, head_dim
|
||||
nope_dim, rope_dim = 448, 64
|
||||
max_pos = 100
|
||||
|
||||
# Build cos_sin_cache for positions 0..max_pos
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
|
||||
t = torch.arange(max_pos, dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq) # (max_pos, half_rope)
|
||||
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) # (max_pos, rope_dim)
|
||||
|
||||
x = torch.randn(T, H, D, dtype=torch.bfloat16) * 0.1
|
||||
positions = torch.tensor([0, 5, 10, 50], dtype=torch.int64)
|
||||
|
||||
# Apply forward RoPE, then inverse
|
||||
rotated = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim)
|
||||
recovered = apply_inv_rope_bf16(rotated, positions, cos_sin_cache, nope_dim, rope_dim)
|
||||
|
||||
# NoPE portion unchanged
|
||||
nope_diff = (recovered[:, :, :nope_dim] - x[:, :, :nope_dim]).abs().max().item()
|
||||
assert nope_diff == 0, f"NoPE should be unchanged, max diff: {nope_diff}"
|
||||
|
||||
# RoPE portion should roundtrip within BF16 precision
|
||||
rope_diff = (recovered[:, :, nope_dim:] - x[:, :, nope_dim:]).abs().max().item()
|
||||
assert rope_diff < 0.02, f"RoPE roundtrip error too high: {rope_diff}"
|
||||
print(f"✅ inv_rope roundtrip: NoPE diff={nope_diff}, RoPE diff={rope_diff:.6f}")
|
||||
|
||||
|
||||
def test_wo_a_bmm():
|
||||
"""wo_a BMM should match einsum 'tgd,grd->tgr'."""
|
||||
torch.manual_seed(42)
|
||||
T = 3
|
||||
n_local_groups = 4
|
||||
heads_per_group = 2
|
||||
head_dim = 512
|
||||
o_lora_rank = 128
|
||||
n_local_heads = n_local_groups * heads_per_group
|
||||
|
||||
# wo_a weight: (n_groups * o_lora_rank, heads_per_group * head_dim)
|
||||
wo_a_weight = torch.randn(n_local_groups * o_lora_rank, heads_per_group * head_dim, dtype=torch.bfloat16)
|
||||
|
||||
# Attention output (after inv RoPE): (T, n_local_heads, head_dim)
|
||||
o_inv = torch.randn(T, n_local_heads, head_dim, dtype=torch.bfloat16)
|
||||
|
||||
# BMM path (our implementation)
|
||||
hidden_dim = heads_per_group * head_dim
|
||||
o_grouped = o_inv.view(T, n_local_groups, hidden_dim)
|
||||
wo_a_w = wo_a_weight.view(n_local_groups, o_lora_rank, hidden_dim)
|
||||
z_bmm = torch.bmm(
|
||||
o_grouped.permute(1, 0, 2),
|
||||
wo_a_w.transpose(1, 2),
|
||||
).permute(1, 0, 2)
|
||||
|
||||
# Reference: einsum
|
||||
o_for_einsum = o_inv.view(T, n_local_groups, hidden_dim).float()
|
||||
wo_a_for_einsum = wo_a_w.float()
|
||||
z_einsum = torch.einsum("tgd,grd->tgr", o_for_einsum, wo_a_for_einsum).bfloat16()
|
||||
|
||||
diff = (z_bmm - z_einsum).abs().max().item()
|
||||
assert diff < 0.01, f"wo_a BMM vs einsum diff: {diff}"
|
||||
print(f"✅ wo_a BMM matches einsum: max diff={diff:.6f}")
|
||||
|
||||
|
||||
def test_inv_rope_at_zero():
|
||||
"""At position 0, cos=1, sin=0, so inv_rope should be identity on RoPE dims."""
|
||||
torch.manual_seed(42)
|
||||
T, H, D = 2, 4, 512
|
||||
nope_dim, rope_dim = 448, 64
|
||||
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
|
||||
t = torch.arange(10, dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) # (10, rope_dim)
|
||||
# At pos 0, cos=1, sin=0
|
||||
|
||||
x = torch.randn(T, H, D, dtype=torch.bfloat16) * 0.1
|
||||
positions = torch.zeros(T, dtype=torch.int64)
|
||||
|
||||
# Forward RoPE at pos 0 should be identity (cos=1, sin=0)
|
||||
rotated = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim)
|
||||
diff = (rotated - x).abs().max().item()
|
||||
assert diff < 1e-5, f"RoPE at pos=0 should be identity, diff={diff}"
|
||||
|
||||
# Inverse RoPE on unrotated input at pos 0 should also be identity
|
||||
inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
|
||||
diff2 = (inv - x).abs().max().item()
|
||||
assert diff2 < 1e-5, f"inv RoPE at pos=0 should be identity, diff={diff2}"
|
||||
print(f"✅ inv_rope at pos=0 is identity (diff={diff2:.8f})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_inv_rope_roundtrip()
|
||||
test_wo_a_bmm()
|
||||
test_inv_rope_at_zero()
|
||||
print("\n✅ All attention O-projection tests passed")
|
||||
@@ -2,6 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
DeepseekV4 MLA Attention Layer
|
||||
|
||||
Patched: O projection uses BF16 inverse RoPE + BMM wo_a + NVFP4 wo_b
|
||||
instead of the original FP8 einsum path.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
@@ -14,6 +17,7 @@ import torch.nn.functional as F
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
)
|
||||
@@ -40,16 +44,20 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
||||
QuantFP8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.multi_stream_utils import (
|
||||
execute_in_parallel,
|
||||
@@ -81,6 +89,45 @@ logger = init_logger(__name__)
|
||||
PREFILL_CHUNK_SIZE = 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BF16 inverse RoPE (replaces fused_inv_rope_fp8_quant + FP8 einsum)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _apply_inv_rope_bf16(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
nope_dim: int = 448,
|
||||
rope_dim: int = 64,
|
||||
) -> torch.Tensor:
|
||||
"""Apply inverse RoPE to attention output in BF16.
|
||||
|
||||
Pure-PyTorch replacement for fused_inv_rope_fp8_quant.
|
||||
Only does inverse RoPE (no FP8 quant) since we use NVFP4 for wo_b.
|
||||
"""
|
||||
if rope_dim == 0 or o.numel() == 0:
|
||||
return o
|
||||
half_rope = rope_dim // 2
|
||||
|
||||
cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype)
|
||||
sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype)
|
||||
|
||||
o_rope = o[:, :, nope_dim:]
|
||||
o_even = o_rope[:, :, 0::2]
|
||||
o_odd = o_rope[:, :, 1::2]
|
||||
|
||||
# Inverse rotation (conjugate):
|
||||
# inv[2i] = x[2i] * cos + x[2i+1] * sin
|
||||
# inv[2i+1] = -x[2i] * sin + x[2i+1] * cos
|
||||
inv_even = o_even * cos_all + o_odd * sin_all
|
||||
inv_odd = -o_even * sin_all + o_odd * cos_all
|
||||
|
||||
result = o.clone()
|
||||
result[:, :, nope_dim:][:, :, 0::2] = inv_even
|
||||
result[:, :, nope_dim:][:, :, 1::2] = inv_odd
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV4MLAModules:
|
||||
"""Modules used in DeepseekV4 MLA."""
|
||||
@@ -182,6 +229,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
|
||||
self.kv_norm = mla_modules.kv_norm
|
||||
self.wo_a = mla_modules.wo_a
|
||||
|
||||
self._wo_a_act_quant = QuantFP8(
|
||||
static=False,
|
||||
group_shape=GroupShape(1, 128),
|
||||
use_ue8m0=True,
|
||||
)
|
||||
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
|
||||
# INT32) so fp8_einsum can handle layout transform internally.
|
||||
self._wo_a_act_quant.use_deep_gemm_supported = False
|
||||
self.wo_b = mla_modules.wo_b
|
||||
|
||||
# Pick fp8_einsum recipe based on GPU arch:
|
||||
@@ -291,91 +347,37 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# Keep ROCm on the BF16 reference wo_a path util kernel ready.
|
||||
if current_platform.is_rocm():
|
||||
z = rocm_inv_rope_einsum(
|
||||
self.rotary_emb,
|
||||
o,
|
||||
positions,
|
||||
self.rope_head_dim,
|
||||
self.n_local_groups,
|
||||
self.o_lora_rank,
|
||||
self.wo_a,
|
||||
)
|
||||
return self.wo_b(z.flatten(1))
|
||||
# === O Projection (patched for NVFP4 + BF16 wo_a) ===
|
||||
# The original path uses fused_inv_rope_fp8_quant + FP8 einsum for wo_a,
|
||||
# which requires wo_a to have FP8 weights and weight_scale_inv.
|
||||
# Our checkpoint has wo_a in BF16 and wo_b in NVFP4.
|
||||
# We replace with: inverse RoPE (BF16) + BMM wo_a + NVFP4 wo_b.
|
||||
|
||||
# Detect if wo_a has FP8 weights (weight_scale_inv attribute).
|
||||
# NVFP4 checkpoints leave wo_a as BF16 (no quantization scales),
|
||||
# so we use inverse RoPE in BF16 + regular matmul instead of
|
||||
# the FP8 einsum path (which crashes on Blackwell SM100).
|
||||
has_fp8_weights = hasattr(self.wo_a, 'weight_scale_inv')
|
||||
|
||||
if not has_fp8_weights:
|
||||
# BF16 wo_a path: inverse RoPE in BF16, then per-group BMM
|
||||
# wo_a is a ColumnParallelLinear with is_bmm=True, meaning it
|
||||
# operates per o-group. The FP8 path uses einsum "bhr,hdr->bhd"
|
||||
# where h=n_local_groups. We must do the same grouping here.
|
||||
o_inv = _apply_inv_rope_bf16(
|
||||
o, positions,
|
||||
self.rotary_emb.cos_sin_cache.to(torch.float32),
|
||||
nope_dim=self.nope_head_dim,
|
||||
rope_dim=self.rope_head_dim,
|
||||
)
|
||||
heads_per_group = self.n_local_heads // self.n_local_groups
|
||||
# o_inv: (num_tokens, n_local_heads, head_dim)
|
||||
# -> (n_local_groups, num_tokens, heads_per_group * head_dim)
|
||||
o_inv = o_inv.view(
|
||||
num_tokens, self.n_local_groups, heads_per_group * self.head_dim
|
||||
).permute(1, 0, 2)
|
||||
# wo_a weight is sharded by TP along output dim.
|
||||
# Shape: (n_local_groups * o_lora_rank // tp, heads_per_group * head_dim)
|
||||
# For BMM, we need weight shaped as (n_local_groups, o_lora_rank // tp, heads_per_group * head_dim)
|
||||
wo_a_w = self.wo_a.weight.view(
|
||||
self.n_local_groups, -1, heads_per_group * self.head_dim
|
||||
)
|
||||
# BMM: (n_local_groups, num_tokens, in) @ (n_local_groups, in, out) -> (n_local_groups, num_tokens, out)
|
||||
z = torch.bmm(
|
||||
o_inv,
|
||||
wo_a_w.transpose(1, 2),
|
||||
)
|
||||
# -> (num_tokens, n_local_groups, o_lora_rank // tp)
|
||||
z = z.permute(1, 0, 2)
|
||||
# All-gather wo_a output across TP ranks, then flatten groups
|
||||
if self.wo_a.gather_output and self.wo_a.tp_size > 1:
|
||||
z = tensor_model_parallel_all_gather(z)
|
||||
z = z.reshape(num_tokens, self.n_local_groups * self.o_lora_rank)
|
||||
return self.wo_b(z)
|
||||
|
||||
# FP8 wo_a path: fused inverse RoPE + FP8 quant + einsum
|
||||
o_fp8, o_scale = fused_inv_rope_fp8_quant(
|
||||
# Step 1: Inverse RoPE (BF16, pure PyTorch)
|
||||
o_inv = _apply_inv_rope_bf16(
|
||||
o,
|
||||
positions,
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
n_groups=self.n_local_groups,
|
||||
heads_per_group=self.n_local_heads // self.n_local_groups,
|
||||
nope_dim=self.nope_head_dim,
|
||||
rope_dim=self.rope_head_dim,
|
||||
tma_aligned_scales=self._tma_aligned_scales,
|
||||
)
|
||||
|
||||
wo_a_fp8 = self.wo_a.weight
|
||||
wo_a_scale = self.wo_a.weight_scale_inv
|
||||
|
||||
z = torch.empty(
|
||||
(num_tokens, self.n_local_groups, self.o_lora_rank),
|
||||
device=o.device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
torch.ops.vllm.deepseek_v4_fp8_einsum(
|
||||
o_fp8,
|
||||
o_scale,
|
||||
wo_a_fp8,
|
||||
wo_a_scale,
|
||||
z,
|
||||
"bhr,hdr->bhd",
|
||||
list(self._einsum_recipe),
|
||||
# Step 2: wo_a grouped linear (BF16 BMM)
|
||||
# o_inv: (T, n_local_heads, head_dim)
|
||||
# wo_a.weight: (n_local_groups * o_lora_rank, heads_per_group * head_dim) BF16
|
||||
heads_per_group = self.n_local_heads // self.n_local_groups
|
||||
hidden_dim = heads_per_group * self.head_dim
|
||||
o_grouped = o_inv.view(num_tokens, self.n_local_groups, hidden_dim)
|
||||
wo_a_w = self.wo_a.weight.view(
|
||||
self.n_local_groups, self.o_lora_rank, hidden_dim
|
||||
)
|
||||
# BMM: (G, T, D) @ (G, D, R) → (G, T, R) → (T, G, R)
|
||||
z = torch.bmm(
|
||||
o_grouped.permute(1, 0, 2),
|
||||
wo_a_w.transpose(1, 2),
|
||||
).permute(1, 0, 2)
|
||||
|
||||
# Step 3: wo_b (NVFP4 via CuTeDSL)
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
|
||||
@@ -582,41 +584,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
|
||||
|
||||
def _apply_inv_rope_bf16(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
nope_dim: int,
|
||||
rope_dim: int,
|
||||
) -> torch.Tensor:
|
||||
"""Apply inverse RoPE to attention output in BF16.
|
||||
|
||||
Inverse RoPE is just RoPE with sin -> -sin.
|
||||
Uses GPT-J style (interleaved) rotary embedding.
|
||||
"""
|
||||
if rope_dim == 0 or o.numel() == 0:
|
||||
return o
|
||||
half_rot = rope_dim // 2
|
||||
o_f32 = o.to(torch.float32)
|
||||
cache = cos_sin_cache.index_select(0, positions.to(torch.long))
|
||||
cos = cache[:, :half_rot].to(torch.float32)
|
||||
sin = cache[:, half_rot : 2 * half_rot].to(torch.float32)
|
||||
view_shape = (positions.shape[0], 1, half_rot)
|
||||
cos = cos.view(view_shape)
|
||||
sin = sin.view(view_shape)
|
||||
rope = o_f32[..., nope_dim:]
|
||||
y_even = rope[..., 0::2]
|
||||
y_odd = rope[..., 1::2]
|
||||
# Inverse: sin → -sin (swap signs on cross terms)
|
||||
rope_out = torch.stack(
|
||||
(y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
|
||||
dim=-1,
|
||||
).flatten(-2)
|
||||
o_f32 = o_f32.clone()
|
||||
o_f32[..., nope_dim:] = rope_out
|
||||
return o_f32.to(o.dtype)
|
||||
|
||||
|
||||
@eager_break_during_capture
|
||||
def deepseek_v4_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -1170,10 +1138,9 @@ class DeepseekV4Indexer(nn.Module):
|
||||
hidden_size,
|
||||
self.n_head,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
)
|
||||
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
self.scale_fmt = "ue8m0"
|
||||
|
||||
Reference in New Issue
Block a user