Files
nvfp4-megamoe-kernel/vllm/patches/layers/deepseek_compressor.py
biondizzle f74447bfd0 Proper NVFP4 integration: quantized compressor/indexer + mapper fixes
Weight mapper fixes:
- Reorder substring renames: compressor renames first, then
  .self_attn.compressor. → .attn.mla_attn.compressor., then indexer renames
  (so indexer keys end up under mla_attn, because the compressor rename has
  already fired by then)
- Add compressor param renames: kv_proj→wkv, gate_proj→wgate, kv_norm→norm,
  position_bias→ape (checkpoint uses NVFP4 naming, model uses internal names)
- Add indexer param renames: q_b_proj→wq_b, kv_proj→compressor.wkv,
  gate_proj→compressor.wgate, kv_norm→k_norm, position_bias→compressor.ape;
  weights_proj keeps its name (only its structural path changes:
  compressor.indexer → indexer.compressor)
- Remove broken suffix renames (already fixed in prior commit)

Model architecture fixes:
- Patch deepseek_compressor.py to pass quant_config (was None, but NVFP4
  checkpoint has quantized compressor weights with input_scale/weight_scale)
- Patch deepseek_v4_attention.py indexer: weights_proj now uses quant_config
  (was None, but checkpoint has quantized weights)
- Add indexer.compressor.fused_wkv_wgate stacking in load_weights

Infrastructure:
- Add deepseek_compressor.py to Dockerfile
- Force MoE backend to flashinfer_cutedsl (was auto-selecting FLASHINFER_TRTLLM)
- Update unit test to 50 cases (compressor + indexer + quantization scales)
2026-05-18 23:20:13 +00:00

437 lines
15 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, cast
import torch
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
MXFP4_BLOCK_SIZE,
)
from vllm.v1.kv_cache_interface import (
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
)
class CompressorBackend(AttentionBackend):
    """Attention backend describing the compressor state-cache layout.

    Reports the cache shape/stride order used for compressor states and
    hands out CompressorMetadataBuilder to produce per-step metadata.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def get_name() -> str:
        return "CompressorBackend"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        # Any block size that is a multiple of 1 is accepted.
        return [MultipleOf(1)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [512, 1024]

    @staticmethod
    def get_builder_cls() -> type["CompressorMetadataBuilder"]:
        return CompressorMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return ``(num_blocks, block_size, head_size)``.

        Only a single KV head is supported, so the head axis is dropped.
        """
        assert num_kv_heads == 1
        return num_blocks, block_size, head_size

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the identity permutation (physical order == logical order)."""
        rank = 4 if include_num_layers_dimension else 3
        return tuple(range(rank))
@dataclass
class CompressorMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor
block_size: int
token_to_req_indices: torch.Tensor | None = None # [num_tokens]
class CompressorMetadataBuilder(AttentionMetadataBuilder):
    """Build CompressorMetadata for compressor state-cache layers.

    The token -> request index map lives in a device buffer that is allocated
    once in ``__init__`` and sliced in every ``build`` call, presumably so its
    address stays fixed for CUDA-graph replay (``_cudagraph_support`` is
    ALWAYS).
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
        spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec)
        self.block_size = spec.block_size
        # Sized for the largest batch the scheduler can produce.
        max_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
        self.token_to_req_indices = torch.zeros(
            max_tokens, dtype=torch.int32, device=self.device
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CompressorMetadata:
        """Return metadata for the current step.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not used.
        """
        cm = common_attn_metadata
        starts = cm.query_start_loc_cpu
        # Per-request query lengths from the cumulative start offsets.
        per_req_tokens = starts[1:] - starts[:-1]
        # Request id repeated once per token (e.g. lengths [2, 1] -> [0, 0, 1]),
        # staged in pinned host memory for the non-blocking copy below.
        req_ids = torch.arange(cm.num_reqs)
        host_map = torch.repeat_interleave(req_ids, per_req_tokens).pin_memory()
        device_map = self.token_to_req_indices[: host_map.shape[0]]
        device_map.copy_(host_map, non_blocking=True)
        # NOTE(review): clamp_ is in place, so it also rewrites the shared
        # common_attn_metadata.block_table_tensor; presumably deliberate to keep
        # the tensor address stable for CUDA graphs — confirm other consumers
        # do not rely on negative entries.
        table = cm.block_table_tensor.clamp_(min=0)
        return CompressorMetadata(
            block_table=table,
            slot_mapping=cm.slot_mapping,
            block_size=self.block_size,
            token_to_req_indices=device_map,
        )
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
    """KV-cache-backed storage for a compressor's partial (kv, score) states.

    Each cache slot holds a single fp32 vector of ``state_dim`` values. The
    layer registers itself in the compilation config's static_forward_context
    under ``prefix`` and reports a SlidingWindowMLASpec for cache allocation.

    Args:
        state_dim: Width of one state vector (kv_state + score_state).
        dtype: Element dtype of the state; must be torch.float32.
        compress_ratio: Compression ratio; must be 4 or 128.
        prefix: Unique layer name used as the static_forward_context key.

    Raises:
        ValueError: On a non-float32 dtype, an unsupported compress_ratio, or
            a prefix that is already registered.
    """

    def __init__(
        self,
        state_dim: int,
        dtype: torch.dtype,
        compress_ratio: int,
        prefix: str,
    ):
        super().__init__()
        # Validate before registering in static_forward_context so a rejected
        # configuration never leaves a half-built layer registered. Explicit
        # raises (instead of assert) keep these checks active under `python -O`.
        if dtype != torch.float32:
            raise ValueError(
                f"Compressor state cache requires torch.float32, got {dtype}"
            )
        # Block size is constrained by tensor sharing between compressor states
        # and KV blocks. Since compressor states share the same physical tensor
        # as KV blocks, they must use the same page size.
        # The KV block shape [256//4, head_dim] = [64, 584] determines:
        # - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
        # - C128 compressor block shape [8, 512*2*4] -> block_size = 8
        # TODO(yifan): make block size automatically determined and configurable.
        if compress_ratio == 4:
            block_size = 4
        elif compress_ratio == 128:
            block_size = 8
        else:
            raise ValueError(f"Invalid compress ratio: {compress_ratio}")

        self.state_dim = state_dim
        self.dtype = dtype
        self.prefix = prefix
        # Empty placeholder; DeepseekCompressor.forward reads .kv_cache, so it
        # is presumably rebound to the allocated cache tensor by vLLM.
        self.kv_cache = torch.tensor([])
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # Ratio 4 is the overlapping configuration (coff = 2), so its window
        # covers two compression windows; ratio 128 covers one.
        coff = 1 + (compress_ratio == 4)
        self.sliding_window = coff * compress_ratio
        self.block_size = block_size

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the state cache: one vector per slot, one KV head."""
        return SlidingWindowMLASpec(  # only has one vector instead of K + V
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.state_dim,
            dtype=self.dtype,
            sliding_window=self.sliding_window,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
        )

    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        return CompressorBackend
class DeepseekCompressor(nn.Module):
    """Compress per-token (kv, score) projections into a quantized KV cache.

    ``forward`` does not call ``fused_wkv_wgate`` itself. It receives an
    already-projected ``kv_score`` (presumably produced by the caller with that
    projection). It stores kv and score + ape in a CompressorStateCache, then
    launches a fused kernel that compresses, applies RMSNorm and RoPE,
    quantizes, and writes into the KV cache of the layer registered under
    ``k_cache_prefix``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        compress_ratio: int,
        hidden_size: int,
        head_dim: int,
        rotate: bool = False,
        prefix: str = "",
        k_cache_prefix="",
        use_fp4_cache: bool = False,
    ):
        """Create the projection, norm, positional bias and state cache.

        Args:
            vllm_config: Supplies hf_config (qk_rope_head_dim, rms_norm_eps),
                scheduler/model limits, quant_config and compilation_config.
            compress_ratio: Compression ratio; CompressorStateCache only
                accepts 4 or 128. Ratio 4 enables ``overlap`` (coff = 2).
            hidden_size: Input width of ``fused_wkv_wgate``.
            head_dim: 512 (sparse-attention kernel) or 128 (indexer kernels).
            rotate: Stored as ``self.rotate``; not read within this class.
            prefix: Module prefix; the state cache registers as
                f"{prefix}.state_cache".
            k_cache_prefix: Key into attn_metadata and static_forward_context
                for the layer whose kv_cache receives the compressed output.
            use_fp4_cache: Select the MXFP4 indexer kernel (head_dim 128 only).

        Raises:
            AssertionError: If head_dim == 512 and use_fp4_cache is set.
            ValueError: If head_dim is neither 512 nor 128.
        """
        super().__init__()
        self.compress_ratio = compress_ratio
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.rotate = rotate
        self.prefix = prefix
        self.k_cache_prefix = k_cache_prefix
        self.use_fp4_cache = use_fp4_cache
        config = vllm_config.model_config.hf_config
        self.rope_head_dim = config.qk_rope_head_dim
        # RoPE is applied to the last rope_head_dim elements of head_dim (see
        # forward); the remaining leading elements are the "nope" part.
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.rms_norm_eps = config.rms_norm_eps
        self.device = current_platform.device_type
        # Stored on the module; not read elsewhere in this class.
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len
        # overlap is enabled only for ratio 4; it doubles the kv/score width
        # via coff and is forwarded to the fused kernel as OVERLAP.
        self.overlap = compress_ratio == 4
        self.coff = 1 + self.overlap
        state_dtype = torch.float32
        # Additive score bias: one row per position within a compression
        # window (row = position % compress_ratio in the save kernel).
        self.ape = nn.Parameter(
            torch.empty(
                (compress_ratio, self.coff * self.head_dim),
                dtype=state_dtype,
                device=self.device,
            ),
            requires_grad=False,
        )
        # quant_config is passed through so quantized checkpoints (e.g. NVFP4
        # with input_scale/weight_scale) load into this projection.
        quant_config = vllm_config.quant_config
        # One GEMM producing [kv | score], each coff * head_dim wide.
        # disable_tp=True: presumably not sharded across TP ranks — confirm.
        self.fused_wkv_wgate = MergedColumnParallelLinear(
            self.hidden_size,
            [self.coff * self.head_dim, self.coff * self.head_dim],
            bias=False,
            return_bias=False,
            quant_config=quant_config,
            disable_tp=True,
            prefix=f"{prefix}.fused_wkv_wgate",
        )
        self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)
        # Registers itself in static_forward_context on construction.
        self.state_cache = CompressorStateCache(
            state_dim=2 * self.coff * self.head_dim,  # kv_state + score_state
            dtype=state_dtype,
            compress_ratio=compress_ratio,
            prefix=f"{prefix}.state_cache",
        )
        # Save reference to static_forward_context for forward-time KV cache lookup.
        # get_current_vllm_config() is only available during __init__, not forward.
        self._static_forward_context = (
            vllm_config.compilation_config.static_forward_context
        )
        # Select the fused compress+quant+cache kernel and its layout
        # constants from head_dim (and use_fp4_cache for the indexer).
        if self.head_dim == 512:
            assert not use_fp4_cache, (
                "MXFP4 cache is only supported for indexer (head=128)"
            )
            self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
            self._quant_block = 64
            # nope elements plus rope_head_dim * 2 per token.
            self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
            # "7 real + 1 pad" assumes qk_rope_head_dim == 64 (nope == 448).
            self._scale_dim = self.nope_head_dim // 64 + 1  # 7 real + 1 pad
            self._num_warps = 4
        elif self.head_dim == 128:
            if use_fp4_cache:
                self._fused_kernel = (
                    _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
                )
                self._quant_block = MXFP4_BLOCK_SIZE
                # Two FP4 values per byte.
                self._token_stride = self.head_dim // 2
                self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
            else:
                self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
                self._quant_block = 128
                self._token_stride = self.head_dim
                self._scale_dim = 4  # single float32 scale
            self._num_warps = 1
        else:
            raise ValueError(
                f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
            )

    def forward(
        self,
        # [num_tokens, 2 * self.coff * self.head_dim]
        kv_score: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        rotary_emb,
    ) -> None:
        """Update the state cache and write compressed entries to the KV cache.

        Args:
            kv_score: Concatenated [kv | score] projections.
            positions: Token positions.
            rotary_emb: Object exposing ``cos_sin_cache`` (layout requirements
                listed below).

        Returns None; writes in place into ``self.state_cache.kv_cache`` and
        the kv_cache of the ``k_cache_prefix`` layer. Returns early, with no
        writes, when the forward context's attn_metadata is not a dict
        (dummy/profiling run).
        """
        # Each of shape [num_tokens, coff * self.head_dim]
        # input bf16, output are fp32
        kv, score = kv_score.split(
            [self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
        )
        # Get the metadata and handle dummy profiling run.
        attn_metadata = get_forward_context().attn_metadata
        if not isinstance(attn_metadata, dict):
            return
        state_metadata = cast(
            CompressorMetadata, attn_metadata[self.state_cache.prefix]
        )
        token_to_req_indices = state_metadata.token_to_req_indices
        slot_mapping = state_metadata.slot_mapping
        # One kernel program per slot_mapping entry.
        num_actual = slot_mapping.shape[0]
        block_table = state_metadata.block_table
        block_size = state_metadata.block_size
        # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
        state_cache = self.state_cache.kv_cache
        # kv_state stored in first half, score_state stored in second half
        state_width = state_cache.shape[-1] // 2
        # launch_pdl is passed (as False) only on non-ROCm platforms.
        pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False}
        # Store the KV and score (with fused APE addition) in the state.
        # NOTE: PDL is disabled — both this kernel and _fused_kernel below
        # depend on preceding kernel outputs (kv/score from the cublas GEMM;
        # state_cache from this kernel) but neither emits/waits on PDL grid
        # dependency primitives, so launch_pdl=True caused a read-after-write
        # race and non-deterministic output.
        _save_partial_states_kernel[(num_actual,)](
            kv,
            kv.stride(0),
            score,
            score.stride(0),
            self.ape,
            self.ape.stride(0),
            positions,
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            slot_mapping,
            block_size,
            HEAD_SIZE=kv.shape[-1],
            TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            **pdl_kwargs,
        )
        # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
        # RoPE requirements (kernel applies forward GPT-J style rotation):
        # - is_neox_style=False (interleaved pairs, NOT split-half)
        # - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
        #   second half sin (per-pair, length rope_head_dim // 2 each)
        # - applied to LAST rope_head_dim elements of head_dim
        # - position used: (positions // compress_ratio) * compress_ratio
        cos_sin_cache = rotary_emb.cos_sin_cache
        k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
        kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
        # Positional argument order must match the selected _fused_kernel.
        self._fused_kernel[(num_actual,)](
            # state cache
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            # metadata
            token_to_req_indices,
            positions,
            slot_mapping,
            block_table,
            block_table.stride(0),
            block_size,
            # RMSNorm
            self.norm.weight,
            self.rms_norm_eps,
            # RoPE
            cos_sin_cache,
            cos_sin_cache.stride(0),
            # KV cache
            kv_cache,
            k_cache_metadata.slot_mapping,
            kv_cache.shape[1],  # paged KV cache block size (tokens per block)
            # constexprs
            HEAD_SIZE=self.head_dim,
            TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            OVERLAP=self.overlap,
            ROPE_HEAD_DIM=self.rope_head_dim,
            FP8_MAX=448.0,
            QUANT_BLOCK=self._quant_block,
            TOKEN_STRIDE=self._token_stride,
            SCALE_DIM=self._scale_dim,
            KV_BLOCK_STRIDE=kv_cache.stride(0),
            num_warps=self._num_warps,
            **pdl_kwargs,
        )
# Copy one token's kv row, and its score row plus the positional bias (ape),
# into that token's slot of the compressor state cache.
# Launched with one program per token. A state-cache row packs
# [kv_state | score_state], each STATE_WIDTH wide; only the first HEAD_SIZE
# columns of each half are written. Tokens with a negative slot are skipped.
@triton.jit
def _save_partial_states_kernel(
    kv_ptr,
    kv_stride,
    score_ptr,
    score_stride,
    ape_ptr,
    ape_stride,
    positions_ptr,
    state_cache_ptr,
    state_cache_stride0,
    state_cache_stride1,
    slot_mapping_ptr,
    block_size,
    HEAD_SIZE: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
    # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
    STATE_WIDTH: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
):
    tok = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + tok)
    # slot == -1 is vLLM's PAD sentinel (e.g. padding tokens during CUDA graph
    # replay); it has no destination row and writing to it would be an
    # illegal memory access.
    if slot < 0:
        return

    # Start of this slot's row: page index * page stride + row-in-page offset.
    row_ptr = (
        state_cache_ptr
        + (slot // block_size) * state_cache_stride0
        + (slot % block_size) * state_cache_stride1
    )
    cols = tl.arange(0, TRITON_BLOCK_SIZE)
    in_bounds = cols < HEAD_SIZE

    # First half of the row: kv copied unchanged.
    kv_row = tl.load(kv_ptr + tok * kv_stride + cols, mask=in_bounds)
    tl.store(row_ptr + cols, kv_row, mask=in_bounds)

    # Second half: score + ape[position % COMPRESS_RATIO].
    pos = tl.load(positions_ptr + tok)
    bias_row_ptr = ape_ptr + (pos % COMPRESS_RATIO) * ape_stride
    bias = tl.load(bias_row_ptr + cols, mask=in_bounds)
    score_row = tl.load(score_ptr + tok * score_stride + cols, mask=in_bounds)
    tl.store(row_ptr + STATE_WIDTH + cols, score_row + bias, mask=in_bounds)