Weight mapper fixes:
- Reorder substr renames: compressor renames first, then .self_attn.compressor. → .attn.mla_attn.compressor., then indexer renames (so indexer keys end up under mla_attn after the compressor rename already fired)
- Add compressor param renames: kv_proj→wkv, gate_proj→wgate, kv_norm→norm, position_bias→ape (checkpoint uses NVFP4 naming, model uses internal names)
- Add indexer param renames: q_b_proj→wq_b, kv_proj→compressor.wkv, gate_proj→compressor.wgate, kv_norm→k_norm, position_bias→compressor.ape, weights_proj stays (structural: compressor.indexer → indexer.compressor)
- Remove broken suffix renames (already fixed in prior commit)

Model architecture fixes:
- Patch deepseek_compressor.py to pass quant_config (was None, but NVFP4 checkpoint has quantized compressor weights with input_scale/weight_scale)
- Patch deepseek_v4_attention.py indexer: weights_proj now uses quant_config (was None, but checkpoint has quantized weights)
- Add indexer.compressor.fused_wkv_wgate stacking in load_weights

Infrastructure:
- Add deepseek_compressor.py to Dockerfile
- Force MoE backend to flashinfer_cutedsl (was auto-selecting FLASHINFER_TRTLLM)
- Update unit test to 50 cases (compressor + indexer + quantization scales)
437 lines
15 KiB
Python
437 lines
15 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, ClassVar, cast
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm.config import VllmConfig, get_current_vllm_config
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
MergedColumnParallelLinear,
|
|
)
|
|
from vllm.platforms import current_platform
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.v1.attention.backend import (
|
|
AttentionBackend,
|
|
AttentionCGSupport,
|
|
AttentionMetadataBuilder,
|
|
CommonAttentionMetadata,
|
|
MultipleOf,
|
|
)
|
|
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
|
|
_fused_kv_compress_norm_rope_insert_indexer_attn,
|
|
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
|
|
_fused_kv_compress_norm_rope_insert_sparse_attn,
|
|
)
|
|
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
|
|
MXFP4_BLOCK_SIZE,
|
|
)
|
|
from vllm.v1.kv_cache_interface import (
|
|
KVCacheSpec,
|
|
MLAAttentionSpec,
|
|
SlidingWindowMLASpec,
|
|
)
|
|
|
|
|
|
class CompressorBackend(AttentionBackend):
    """Attention backend descriptor for the compressor state cache.

    Every query method is static/class-level and returns a constant
    description of the cache layout; no per-instance state is kept.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def get_name() -> str:
        """Return the name identifying this backend."""
        return "CompressorBackend"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the supported block sizes: any multiple of 1."""
        any_size = MultipleOf(1)
        return [any_size]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [512, 1024]

    @staticmethod
    def get_builder_cls() -> type["CompressorMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return CompressorMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the cache shape ``(num_blocks, block_size, head_size)``.

        The cache has no head dimension, so ``num_kv_heads`` must be 1.
        ``cache_dtype_str`` is accepted but not used here.
        """
        assert num_kv_heads == 1
        shape = (num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the identity stride order for the cache dimensions.

        The order has four entries when the layers dimension is included
        and three otherwise.
        """
        ndim = 4 if include_num_layers_dimension else 3
        return tuple(range(ndim))
|
|
|
|
|
|
@dataclass
class CompressorMetadata:
    """Per-step metadata consumed by the compressor kernels.

    Instances are produced by ``CompressorMetadataBuilder.build``.
    """

    # Block table from the common attention metadata; the builder clamps
    # its entries to be >= 0 before storing it here.
    block_table: torch.Tensor
    # Slot mapping forwarded unchanged from the common attention metadata.
    slot_mapping: torch.Tensor
    # Block size taken from the layer's KV cache spec.
    block_size: int

    # For each token, the index of the request it belongs to.
    token_to_req_indices: torch.Tensor | None = None  # [num_tokens]
|
|
|
|
|
|
class CompressorMetadataBuilder(AttentionMetadataBuilder):
    """Builds ``CompressorMetadata`` for the compressor state cache.

    A persistent int32 buffer sized for the scheduler's maximum number of
    batched tokens is allocated once at construction; every ``build`` call
    fills a prefix of it with the token -> request mapping.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        spec = self.kv_cache_spec
        assert isinstance(spec, (SlidingWindowMLASpec, MLAAttentionSpec))
        self.block_size = cast(
            SlidingWindowMLASpec | MLAAttentionSpec, spec
        ).block_size

        # Reused across steps; `build` returns a slice view of this buffer.
        max_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
        self.token_to_req_indices = torch.zeros(
            max_tokens,
            dtype=torch.int32,
            device=self.device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CompressorMetadata:
        """Assemble the compressor metadata for the current batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not used here.
        """
        start_loc = common_attn_metadata.query_start_loc_cpu
        # Per-request query length = difference of consecutive start offsets.
        tokens_per_req = start_loc[1:] - start_loc[:-1]
        request_ids = torch.arange(common_attn_metadata.num_reqs)

        # Expand to one request index per token. The result is staged in
        # pinned host memory before the non-blocking copy into the buffer.
        host_indices = torch.repeat_interleave(
            request_ids, tokens_per_req
        ).pin_memory()
        num_tokens = host_indices.shape[0]
        device_indices = self.token_to_req_indices[:num_tokens]
        device_indices.copy_(host_indices, non_blocking=True)

        # NOTE: clamp_ is in-place, so the shared block table tensor in
        # `common_attn_metadata` is modified as well.
        clamped_block_table = common_attn_metadata.block_table_tensor.clamp_(min=0)
        return CompressorMetadata(
            block_table=clamped_block_table,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_size=self.block_size,
            token_to_req_indices=device_indices,
        )
|
|
|
|
|
|
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
    """Cache-owning layer for the compressor's partial states.

    The layer registers itself in the compilation config's
    ``static_forward_context`` under ``prefix`` and describes its cache
    through ``get_kv_cache_spec``. It performs no computation of its own
    (``forward`` is a no-op).

    Args:
        state_dim: Width of one cached state row (used as the spec head size).
        dtype: Cache dtype; must be ``torch.float32``.
        compress_ratio: Compression ratio; must be 4 or 128.
        prefix: Unique layer name used as the registration key.

    Raises:
        AssertionError: If ``dtype`` is not float32 or ``compress_ratio`` is
            not one of the supported values.
        ValueError: If ``prefix`` is already registered, or (when asserts
            are disabled) ``compress_ratio`` is unsupported.
    """

    def __init__(
        self,
        state_dim: int,
        dtype: torch.dtype,
        compress_ratio: int,
        prefix: str,
    ):
        super().__init__()

        # Validate the arguments BEFORE touching the shared
        # static_forward_context, so a rejected construction never leaves a
        # stale, half-initialised entry registered under `prefix`.
        assert dtype == torch.float32
        assert compress_ratio in [4, 128]

        # Block size is constrained by tensor sharing between compressor states
        # and KV blocks. Since compressor states share the same physical tensor
        # as KV blocks, they must use the same page size.
        # The KV block shape [256//4, head_dim] = [64, 584] determines:
        # - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
        # - C128 compressor block shape [8, 512*2*4] -> block_size = 8
        # TODO(yifan): make block size automatically determined and configurable.
        if compress_ratio == 4:
            block_size = 4
        elif compress_ratio == 128:
            block_size = 8
        else:
            # Only reachable when asserts are stripped (python -O).
            raise ValueError(f"Invalid compress ratio: {compress_ratio}")

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")

        self.state_dim = state_dim
        self.dtype = dtype
        self.prefix = prefix
        # Empty placeholder for the cache tensor.
        self.kv_cache = torch.tensor([])

        # The window spans two compression groups for ratio 4, one otherwise.
        coff = 1 + (compress_ratio == 4)
        self.sliding_window = coff * compress_ratio
        self.block_size = block_size

        # Register last: from here on the instance is fully initialised.
        compilation_config.static_forward_context[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the sliding-window MLA cache spec for this state cache."""
        return SlidingWindowMLASpec(  # only has one vector instead of K + V
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.state_dim,
            dtype=self.dtype,
            sliding_window=self.sliding_window,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
        )

    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class that manages this layer's cache."""
        return CompressorBackend
|
|
|
|
|
|
class DeepseekCompressor(nn.Module):
    """KV compressor that writes compressed entries into a paged KV cache.

    The module owns:
      * ``ape``: an fp32 parameter of shape
        ``(compress_ratio, coff * head_dim)`` that is added to the score
        stream, indexed by ``position % compress_ratio``;
      * ``fused_wkv_wgate``: a merged projection from ``hidden_size`` to two
        outputs of width ``coff * head_dim`` each (not invoked by
        ``forward``, which receives the projected ``kv_score`` directly);
      * ``norm``: an RMSNorm whose weight is handed to the fused kernel;
      * ``state_cache``: a ``CompressorStateCache`` holding partial states.

    ``coff`` is 2 when ``compress_ratio == 4`` (overlap mode) and 1 otherwise.
    The fused kernel and its quantization constants are selected from
    ``head_dim`` (512 or 128) and ``use_fp4_cache``.

    Raises:
        AssertionError: If ``use_fp4_cache`` is set with ``head_dim == 512``.
        ValueError: If ``head_dim`` is neither 512 nor 128.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        compress_ratio: int,
        hidden_size: int,
        head_dim: int,
        rotate: bool = False,
        prefix: str = "",
        k_cache_prefix="",
        use_fp4_cache: bool = False,
    ):
        super().__init__()
        self.compress_ratio = compress_ratio
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.rotate = rotate
        self.prefix = prefix
        self.k_cache_prefix = k_cache_prefix
        self.use_fp4_cache = use_fp4_cache

        hf_config = vllm_config.model_config.hf_config
        self.rope_head_dim = hf_config.qk_rope_head_dim
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.rms_norm_eps = hf_config.rms_norm_eps
        self.device = current_platform.device_type
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len

        self.overlap = compress_ratio == 4
        self.coff = 1 + self.overlap

        # Width of each of the two streams (kv and score).
        stream_width = self.coff * self.head_dim
        state_dtype = torch.float32

        ape_storage = torch.empty(
            (compress_ratio, stream_width),
            dtype=state_dtype,
            device=self.device,
        )
        self.ape = nn.Parameter(ape_storage, requires_grad=False)

        self.fused_wkv_wgate = MergedColumnParallelLinear(
            self.hidden_size,
            [stream_width, stream_width],
            bias=False,
            return_bias=False,
            quant_config=vllm_config.quant_config,
            disable_tp=True,
            prefix=f"{prefix}.fused_wkv_wgate",
        )
        self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)

        self.state_cache = CompressorStateCache(
            state_dim=2 * stream_width,  # kv_state + score_state
            dtype=state_dtype,
            compress_ratio=compress_ratio,
            prefix=f"{prefix}.state_cache",
        )

        # Keep a handle on static_forward_context for the KV cache lookup at
        # forward time: get_current_vllm_config() is only available during
        # __init__, not forward.
        self._static_forward_context = (
            vllm_config.compilation_config.static_forward_context
        )

        # Pick the fused kernel and its launch constants.
        if self.head_dim == 512:
            assert not use_fp4_cache, (
                "MXFP4 cache is only supported for indexer (head=128)"
            )
            fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
            quant_block = 64
            token_stride = self.nope_head_dim + self.rope_head_dim * 2
            scale_dim = self.nope_head_dim // 64 + 1  # 7 real + 1 pad
            num_warps = 4
        elif self.head_dim == 128:
            num_warps = 1
            if use_fp4_cache:
                fused_kernel = (
                    _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
                )
                quant_block = MXFP4_BLOCK_SIZE
                token_stride = self.head_dim // 2
                scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
            else:
                fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
                quant_block = 128
                token_stride = self.head_dim
                scale_dim = 4  # single float32 scale
        else:
            raise ValueError(
                f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
            )

        self._fused_kernel = fused_kernel
        self._quant_block = quant_block
        self._token_stride = token_stride
        self._scale_dim = scale_dim
        self._num_warps = num_warps

    def forward(
        self,
        # [num_tokens, 2 * self.coff * self.head_dim]
        kv_score: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        rotary_emb,
    ) -> None:
        """Save partial states, then compress and write into the KV cache.

        Nothing is returned; the results are written by two kernel launches
        (one program per slot-mapping entry): first into the compressor
        state cache, then into the KV cache registered under
        ``k_cache_prefix``. When the forward context's attention metadata is
        not a dict (dummy profiling run) the call is a no-op.
        """
        # Two equal-width halves: kv first, score second.
        # (Original author's note: input bf16, outputs are fp32.)
        stream_width = self.coff * self.head_dim
        kv, score = kv_score.split([stream_width, stream_width], dim=-1)

        # Fetch the metadata; bail out on the dummy profiling run.
        attn_metadata = get_forward_context().attn_metadata
        if not isinstance(attn_metadata, dict):
            return

        state_metadata = cast(
            CompressorMetadata, attn_metadata[self.state_cache.prefix]
        )
        token_to_req_indices = state_metadata.token_to_req_indices
        slot_mapping = state_metadata.slot_mapping
        block_table = state_metadata.block_table
        block_size = state_metadata.block_size
        grid = (slot_mapping.shape[0],)

        # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
        state_cache = self.state_cache.kv_cache
        # First half of the last dim holds kv_state, second half score_state.
        state_width = state_cache.shape[-1] // 2

        if current_platform.is_rocm():
            pdl_kwargs = {}
        else:
            pdl_kwargs = {"launch_pdl": False}

        # Step 1: store KV and score (with the APE addition fused in) into
        # the state cache.
        # NOTE: PDL stays disabled. Both this kernel and _fused_kernel below
        # consume outputs of the preceding kernel (kv/score from the cublas
        # GEMM; state_cache from this kernel), yet neither emits/waits on PDL
        # grid dependency primitives, so launch_pdl=True caused a
        # read-after-write race and non-deterministic output.
        row_width = kv.shape[-1]
        _save_partial_states_kernel[grid](
            kv,
            kv.stride(0),
            score,
            score.stride(0),
            self.ape,
            self.ape.stride(0),
            positions,
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            slot_mapping,
            block_size,
            HEAD_SIZE=row_width,
            TRITON_BLOCK_SIZE=triton.next_power_of_2(row_width),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            **pdl_kwargs,
        )

        # Step 2 (fused): compress -> RMSNorm -> RoPE -> FP8 quant -> KV
        # cache write.
        # RoPE requirements (kernel applies forward GPT-J style rotation):
        #   - is_neox_style=False (interleaved pairs, NOT split-half)
        #   - cos_sin_cache layout: [max_pos, rope_head_dim] with first half
        #     cos, second half sin (per-pair, length rope_head_dim // 2 each)
        #   - applied to LAST rope_head_dim elements of head_dim
        #   - position used: (positions // compress_ratio) * compress_ratio
        cos_sin_cache = rotary_emb.cos_sin_cache
        k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
        kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache

        self._fused_kernel[grid](
            # state cache
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            # metadata
            token_to_req_indices,
            positions,
            slot_mapping,
            block_table,
            block_table.stride(0),
            block_size,
            # RMSNorm
            self.norm.weight,
            self.rms_norm_eps,
            # RoPE
            cos_sin_cache,
            cos_sin_cache.stride(0),
            # KV cache
            kv_cache,
            k_cache_metadata.slot_mapping,
            kv_cache.shape[1],  # paged KV cache block size (tokens per block)
            # constexprs
            HEAD_SIZE=self.head_dim,
            TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            OVERLAP=self.overlap,
            ROPE_HEAD_DIM=self.rope_head_dim,
            FP8_MAX=448.0,
            QUANT_BLOCK=self._quant_block,
            TOKEN_STRIDE=self._token_stride,
            SCALE_DIM=self._scale_dim,
            KV_BLOCK_STRIDE=kv_cache.stride(0),
            num_warps=self._num_warps,
            **pdl_kwargs,
        )
|
|
|
|
|
|
@triton.jit
def _save_partial_states_kernel(
    kv_ptr,
    kv_stride,
    score_ptr,
    score_stride,
    ape_ptr,
    ape_stride,
    positions_ptr,
    state_cache_ptr,
    state_cache_stride0,
    state_cache_stride1,
    slot_mapping_ptr,
    block_size,
    HEAD_SIZE: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
    # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
    STATE_WIDTH: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
):
    """Write one token's kv row and APE-biased score row to the state cache.

    Each program handles the token at ``program_id(0)``. Its slot id selects
    the destination row (``slot // block_size`` along stride 0,
    ``slot % block_size`` along stride 1). The kv row is stored at offset 0
    and ``score + ape[position % COMPRESS_RATIO]`` at offset ``STATE_WIDTH``.
    """
    pid = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + pid)

    # Padded / invalid tokens carry slot == -1 (vLLM's PAD sentinel). They
    # can show up during CUDA graph replay; writing through a negative slot
    # would be an illegal memory access, so bail out early.
    if slot < 0:
        return

    offsets = tl.arange(0, TRITON_BLOCK_SIZE)
    in_row = offsets < HEAD_SIZE

    # Destination row inside the paged state cache.
    row_ptr = (
        state_cache_ptr
        + (slot // block_size) * state_cache_stride0
        + (slot % block_size) * state_cache_stride1
    )

    # kv_state: straight copy into the first half of the row.
    kv_row = tl.load(kv_ptr + pid * kv_stride + offsets, mask=in_row)
    tl.store(row_ptr + offsets, kv_row, mask=in_row)

    # score_state: fused score += ape[position % compress_ratio], written to
    # the second half of the row.
    pos = tl.load(positions_ptr + pid)
    ape_index = pos % COMPRESS_RATIO
    ape_row = tl.load(ape_ptr + ape_index * ape_stride + offsets, mask=in_row)
    score_row = tl.load(score_ptr + pid * score_stride + offsets, mask=in_row)
    tl.store(
        row_ptr + STATE_WIDTH + offsets,
        score_row + ape_row,
        mask=in_row,
    )
|