From 02b9c1ac20106227f71cef12097ff30b5a79bef3 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 23:04:36 +0000 Subject: [PATCH] nuke vllm because this keep confusing people --- Dockerfile | 69 - vllm/cutedsl_quant_method.py | 157 -- vllm/kernels/linear/nvfp4/cutedsl.py | 133 -- vllm/nvfp4_cutedsl.py | 536 ----- vllm/patches/deepseek_v4.py | 1783 ----------------- vllm/patches/deepseek_v4_attention.py | 1537 -------------- vllm/patches/fused_moe/experts/cutedsl_moe.py | 308 --- vllm/patches/fused_moe/oracle/nvfp4.py | 535 ----- vllm/patches/kernel.py | 216 -- vllm/patches/layers/csa_attention.py | 356 ---- vllm/patches/layers/deepseek_compressor.py | 455 ----- vllm/patches/layers/mhc.py | 195 -- vllm/patches/patch_compressor_cache.py | 48 - vllm/patches/patch_debug_layers.py | 39 - vllm/patches/patch_indexer_cache.py | 58 - vllm/patches/patch_kv_cache_utils.py | 135 -- vllm/patches/patch_swa_cache.py | 71 - vllm/patches/register_cutedsl_kernel.py | 46 - 18 files changed, 6677 deletions(-) delete mode 100644 vllm/cutedsl_quant_method.py delete mode 100644 vllm/kernels/linear/nvfp4/cutedsl.py delete mode 100644 vllm/nvfp4_cutedsl.py delete mode 100644 vllm/patches/deepseek_v4.py delete mode 100644 vllm/patches/deepseek_v4_attention.py delete mode 100644 vllm/patches/fused_moe/experts/cutedsl_moe.py delete mode 100644 vllm/patches/fused_moe/oracle/nvfp4.py delete mode 100644 vllm/patches/kernel.py delete mode 100644 vllm/patches/layers/csa_attention.py delete mode 100644 vllm/patches/layers/deepseek_compressor.py delete mode 100644 vllm/patches/layers/mhc.py delete mode 100644 vllm/patches/patch_compressor_cache.py delete mode 100644 vllm/patches/patch_debug_layers.py delete mode 100644 vllm/patches/patch_indexer_cache.py delete mode 100644 vllm/patches/patch_kv_cache_utils.py delete mode 100644 vllm/patches/patch_swa_cache.py delete mode 100644 vllm/patches/register_cutedsl_kernel.py diff --git a/Dockerfile b/Dockerfile index c6f00da2..659fdfec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,72 +26,3 @@ RUN cd /root/nvfp4-megamoe-kernel && pip install -e . COPY cutedsl/ /root/nvfp4-megamoe-kernel/cutedsl/ ENV PYTHONPATH="/root/nvfp4-megamoe-kernel:${PYTHONPATH}" - -# Patch vLLM — overwrite model files and register architecture -ARG VLLM_MODELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models -ARG VLLM_LAYERS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers -ARG VLLM_QUANT_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization -ARG VLLM_FUSED_MOE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe -ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader - -# Core model patches -COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py -COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py -COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py - -# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell) -# The nightly image has all MHC in layers/mhc.py (imports tilelang at top level). -# Our replacement is pure PyTorch — no tilelang dependency at all. -COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py - -# CSA/HCA attention kernel (replaces FlashMLA on Blackwell) -COPY vllm/patches/layers/csa_attention.py ${VLLM_LAYERS_DIR}/csa_attention.py - -# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel) -ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4 -COPY vllm/kernels/linear/nvfp4/cutedsl.py ${VLLM_NVFP4_DIR}/cutedsl.py - -# Patch KV cache utils to handle DeepseekV4 SWA page sizes > MLA page sizes -# (SWA layers have larger page sizes than compressed MLA layers on Blackwell) -ARG VLLM_CORE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/core -COPY vllm/patches/patch_kv_cache_utils.py /tmp/patch_kv_cache_utils.py -RUN python3 /tmp/patch_kv_cache_utils.py ${VLLM_CORE_DIR}/kv_cache_utils.py && rm /tmp/patch_kv_cache_utils.py - -# Patch SWA cache and Indexer cache for Blackwell (no FlashMLA alignment) -ARG VLLM_SPARSE_SWA_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla -ARG VLLM_LAYERS_DIR2=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers -COPY vllm/patches/patch_swa_cache.py /tmp/patch_swa_cache.py -RUN python3 /tmp/patch_swa_cache.py ${VLLM_SPARSE_SWA_DIR}/sparse_swa.py && rm /tmp/patch_swa_cache.py -COPY vllm/patches/patch_indexer_cache.py /tmp/patch_indexer_cache.py -RUN python3 /tmp/patch_indexer_cache.py ${VLLM_LAYERS_DIR2}/deepseek_v4_attention.py && rm /tmp/patch_indexer_cache.py -COPY vllm/patches/patch_compressor_cache.py /tmp/patch_compressor_cache.py -RUN python3 /tmp/patch_compressor_cache.py ${VLLM_LAYERS_DIR2}/deepseek_compressor.py && rm /tmp/patch_compressor_cache.py - -# Debug: print layer name mismatch -ARG VLLM_WORKER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/worker -COPY vllm/patches/patch_debug_layers.py /tmp/patch_debug_layers.py -RUN python3 /tmp/patch_debug_layers.py ${VLLM_WORKER_DIR}/gpu_model_runner.py && rm /tmp/patch_debug_layers.py - -# Register CuTeDSL kernel in vLLM's linear kernel selection -ARG VLLM_LINEAR_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear -COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py -RUN python3 /tmp/register_cutedsl_kernel.py ${VLLM_LINEAR_DIR}/__init__.py && rm /tmp/register_cutedsl_kernel.py - -# Config patches (add cutedsl to MoEBackend) -ARG VLLM_CONFIG_DIR=/usr/local/lib/python3.12/dist-packages/vllm/config -COPY vllm/patches/kernel.py ${VLLM_CONFIG_DIR}/kernel.py - -# NVFP4 MoE backend registration -COPY vllm/patches/fused_moe/oracle/nvfp4.py ${VLLM_FUSED_MOE_DIR}/oracle/nvfp4.py -COPY vllm/patches/fused_moe/experts/cutedsl_moe.py ${VLLM_FUSED_MOE_DIR}/experts/cutedsl_moe.py - -# Register DeepseekV4ForCausalLM model architecture (if not already in upstream) -RUN grep -q '"DeepseekV4ForCausalLM"' ${VLLM_MODELS_DIR}/registry.py || \ - sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \ - ${VLLM_MODELS_DIR}/registry.py - -# Verify -RUN python3 -c "import torch; print(f'PyTorch {torch.__version__} OK')" && \ - python3 -c "import vllm; print('vLLM OK')" && \ - python3 -c "import nvfp4_megamoe_kernel; print('NVFP4 kernel OK')" && \ - python3 -c "import cutlass; print('CuTeDSL OK')" diff --git a/vllm/cutedsl_quant_method.py b/vllm/cutedsl_quant_method.py deleted file mode 100644 index 2035afc9..00000000 --- a/vllm/cutedsl_quant_method.py +++ /dev/null @@ -1,157 +0,0 @@ -"""CuTeDSL NVFP4 Quantization Method for vLLM - -Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM. -After process_weights_after_loading, the module's quant_method is swapped -to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL -via torch.library.custom_op (opaque to torch.compile). -""" - -import torch - -from vllm.model_executor.layers.linear import LinearMethodBase -from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm - - -class CuTeDSLNvfp4Method(LinearMethodBase): - """Pre-processing quant method that sets up CuTeDSL runners. - - Installed on NVFP4 linear layers before process_weights_after_loading. - When vLLM calls process_weights_after_loading, this method: - 1. Reads NVFP4 weights (uint8, float8 block scales, float32 global scales) - 2. Converts to CuTeDSL format - 3. Creates CuTeDSLNvfp4Linear runner - 4. Stores runner on the module - 5. Frees original weight/scale params - 6. Replaces quant_method with CuTeDSLNvfp4LinearMethod - """ - - def __init__(self, is_fused: bool = False): - """ - Args: - is_fused: True for MergedColumnParallelLinear with two sub-projections - (e.g., fused_wqa_wkv with q_a + kv, or gate_up_proj with gate + up). - Handles dual weight_scale_2 the same way as MoE L1: - normalize to max(gs1, gs2), fold ratio into block scales. - """ - self.is_fused = is_fused - - def create_weights(self, layer, input_size_per_partition, - output_partition_sizes, input_size, output_size, - params_dtype, **extra_weight_attrs): - # We don't create weights — ModelOptNvFp4LinearMethod already did that. - # This method is only installed after weight loading. - pass - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear - - w_uint8 = layer.weight.data # (out, in//2) uint8 packed E2M1 - device = w_uint8.device - out_features = w_uint8.shape[0] - in_features = w_uint8.shape[1] * 2 # 2 FP4 values per uint8 - - # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N) - w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - - # Block scales: (N, K_sf) → (K_sf, N) - sf = layer.weight_scale.data - if sf.dtype != torch.float8_e4m3fn: - sf = sf.to(torch.float8_e4m3fn) - sf = sf.permute(1, 0).contiguous() - - # Global scale - weight_scale_2 = layer.weight_scale_2.data - if self.is_fused and weight_scale_2.numel() == 2: - # Dual global scales (fused_wqa_wkv: q_a + kv, gate_up: gate + up) - gs1 = weight_scale_2[0].item() - gs2 = weight_scale_2[1].item() - gs = max(gs1, gs2) - - # Fold ratio into block scales via float32 round-trip - if gs1 != gs2: - sf_f32 = sf.float() - # After permute to (K_sf, N): first sub-projection's output - # columns, then second sub-projection's output columns - logical_widths = getattr(layer, 'logical_widths', None) - if logical_widths is not None and len(logical_widths) == 2: - split_point = logical_widths[0] - else: - split_point = out_features // 2 - sf_f32[:, :split_point] *= (gs1 / gs) - sf_f32[:, split_point:] *= (gs2 / gs) - sf = sf_f32.to(torch.float8_e4m3fn) - else: - gs = weight_scale_2.max().item() - - # Create CuTeDSL runner - runner = CuTeDSLNvfp4Linear( - in_features=in_features, - out_features=out_features, - device=device, - ) - runner.fp4 = [w_fp4] - runner.sf = [sf] - runner.gs = [gs] - runner.finalize_weights() - - # Register runner in global registry (for torch.library.custom_op) - layer._cutedsl_runner_id = register_runner(runner) - layer._cutedsl_out_features = out_features - - # Warmup: compute activation global scale from sample data. - # The checkpoint's input_scale is a calibration-time value that does NOT - # match what quantize_activation_nvfp4 expects at runtime. Using it - # produces garbage output (empty EOS tokens). The correct approach is - # a warmup forward pass that measures the actual activation distribution. - # Use only 1 token to minimize GPU memory overhead during weight loading. - with torch.no_grad(): - sample = torch.randn(1, in_features, - dtype=torch.bfloat16, device=device) * 2.0 - runner.compute_activation_global_scale(sample) - del sample - torch.cuda.empty_cache() - - # Replace weight with dummy BF16 (needed by vLLM module introspection) - # Replace weight with a GPU dummy (some vLLM code paths like - # torch.mm(compressor.weight.T) expect weight on GPU). - layer.weight = torch.nn.Parameter( - torch.zeros(out_features, in_features, dtype=torch.bfloat16, - device=device), - requires_grad=False, - ) - - # Free original NVFP4 params - for attr in ("weight_scale", "weight_scale_2", "input_scale", - "input_global_scale", "input_global_scale_inv", - "weight_global_scale", "alpha", "weight_scale_inv"): - if hasattr(layer, attr): - try: - delattr(layer, attr) - except Exception: - pass - - # Swap quant method to the forward-only one - layer.quant_method = CuTeDSLNvfp4LinearMethod() - - def apply(self, layer, x, bias=None): - raise NotImplementedError( - "CuTeDSLNvfp4Method should be replaced by " - "CuTeDSLNvfp4LinearMethod after process_weights_after_loading" - ) - - -class CuTeDSLNvfp4LinearMethod(LinearMethodBase): - """Forward path: BF16 input → CuTeDSL NVFP4 GEMM → BF16 output.""" - - def create_weights(self, layer, input_size_per_partition, - output_partition_sizes, input_size, output_size, - params_dtype, **extra_weight_attrs): - pass - - def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor: - result = nvfp4_linear_gemm( - x, layer._cutedsl_runner_id, layer._cutedsl_out_features, - ) - if bias is not None: - result = result + bias - return result diff --git a/vllm/kernels/linear/nvfp4/cutedsl.py b/vllm/kernels/linear/nvfp4/cutedsl.py deleted file mode 100644 index 7a9828e0..00000000 --- a/vllm/kernels/linear/nvfp4/cutedsl.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""CuTeDSL NVFP4 Linear Kernel for vLLM. - -Registers as an NvFp4LinearKernel so that vLLM kernel selection -(init_nvfp4_linear_kernel) picks it up on Blackwell GPUs. -Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM. - -Uses torch.library.custom_op to make Dynamo (torch.compile) treat the -GEMM as opaque. The runner's _run_impl is already cudagraph-safe. -""" - -import torch - -from vllm.logger import init_logger -from vllm.platforms import current_platform - -from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig -from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm - -logger = init_logger(__name__) - - -class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): - """NVFP4 GEMM via the CuTeDSL framework (Blackwell SM100+).""" - - @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: - cap = compute_capability or current_platform.get_device_capability() - if cap is not None and cap.major >= 10: - return True, None - return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)" - - @classmethod - def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]: - return True, None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """Convert NVFP4 weights into CuTeDSL kernel format.""" - from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear - - w_uint8 = layer.weight.data - device = w_uint8.device - out_features = w_uint8.shape[0] - in_features = w_uint8.shape[1] * 2 - - w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - - sf = layer.weight_scale.data - if sf.dtype != torch.float8_e4m3fn: - sf = sf.to(torch.float8_e4m3fn) - sf = sf.permute(1, 0).contiguous() - - gs = layer.weight_global_scale.data.item() - - if layer.weight_global_scale.numel() == 2: - gs0 = layer.weight_global_scale[0].item() - gs1 = layer.weight_global_scale[1].item() - gs = max(gs0, gs1) - if gs0 != gs1: - sf_f32 = sf.float() - logical_widths = getattr(layer, 'logical_widths', None) - if logical_widths is not None and len(logical_widths) == 2: - split_point = logical_widths[0] - else: - split_point = out_features // 2 - sf_f32[:, :split_point] *= (gs0 / gs) - sf_f32[:, split_point:] *= (gs1 / gs) - sf = sf_f32.to(torch.float8_e4m3fn) - - runner = CuTeDSLNvfp4Linear( - in_features=in_features, - out_features=out_features, - device=str(device), - ) - runner.fp4 = [w_fp4] - runner.sf = [sf] - runner.gs = [gs] - runner.finalize_weights() - - # Warmup: compute activation global scale from sample data. - # The checkpoint's input_scale is a calibration-time value that does NOT - # match what quantize_activation_nvfp4 expects at runtime. Using it - # produces garbage output (empty EOS tokens). The correct approach is - # a warmup forward pass that measures the actual activation distribution. - # Use only 1 token to minimize GPU memory overhead during weight loading. - with torch.no_grad(): - sample = torch.randn( - 1, in_features, - dtype=torch.bfloat16, device=str(device), - ) * 2.0 - runner.compute_activation_global_scale(sample) - del sample - torch.cuda.empty_cache() - - # Register the runner and store the ID (not the runner itself) - layer._cutedsl_runner_id = register_runner(runner) - layer._cutedsl_out_features = out_features - - # Replace weight with a GPU dummy (some vLLM code paths like - # torch.mm(compressor.weight.T) expect weight on GPU). - layer.weight = torch.nn.Parameter( - torch.zeros(out_features, in_features, dtype=torch.bfloat16, - device=device), - requires_grad=False, - ) - - for attr in ("weight_scale", "weight_global_scale", - "input_global_scale", "input_global_scale_inv", - "alpha", "weights_padding_cols", "weight_scale_2", - "input_scale"): - if hasattr(layer, attr): - try: - delattr(layer, attr) - except Exception: - pass - - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - result = nvfp4_linear_gemm( - x, - layer._cutedsl_runner_id, - layer._cutedsl_out_features, - ) - if bias is not None: - result = result + bias - return result diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py deleted file mode 100644 index cd1bc3ca..00000000 --- a/vllm/nvfp4_cutedsl.py +++ /dev/null @@ -1,536 +0,0 @@ -""" -vLLM integration for the CuTeDSL NVFP4 MoE kernel. - -CUDA-graph-compatible design: -- All intermediate buffers pre-allocated at max_num_tokens * top_k size -- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs -- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers -- Extra slots (beyond real tokens) are zero and contribute nothing to output -- Fixed-shape tensors throughout the forward pass - -vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192). -During capture, num_tokens equals the budget — all shapes are fixed. -During replay, inputs are padded to the budget size. Our runner always -processes max_slots = budget * top_k rows; padding rows are zeros. - -Dynamo compatibility: uses torch.library.custom_op via cutedsl.custom_ops -so torch.compile (fullgraph mode) treats the GEMM as an opaque black box. -The runner's _run_impl is already cudagraph-safe. -""" -import torch - -from cutedsl.bridge import ( - quantize_activation_nvfp4, - quantize_weight_to_nvfp4, - quantize_to_nvfp4, - make_b_k_major, - assemble_scales_3d_side, - run_nvfp4_grouped_gemm, -) -from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm -from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( - ceil_div as cutedsl_ceil_div, - pad_and_swizzle_single, -) - - -class CuTeDSLMoERunner: - """Manages NVFP4 MoE execution via the CuTeDSL kernel. - - CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs, - no dynamic shapes. Always computes at max_num_tokens * top_k capacity. - """ - - def __init__(self, num_experts, hidden_size, intermediate_size, - max_num_tokens=8192, top_k=8, device="cuda", - experts_start_idx=0): - self.num_experts = num_experts - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.max_num_tokens = max_num_tokens - self.top_k = top_k - self.device = device - self.experts_start_idx = experts_start_idx - self._swiglu_limit = None # Set via set_swiglu_limit() - - # Weight storage (set before _ensure_stacked) - self.l1_fp4 = None - self.l1_sf = None - self.l1_gs = None - self.l2_fp4 = None - self.l2_sf = None - self.l2_gs = None - - # Stacked weight tensors (set in _ensure_stacked) - self._l1_mat_b = None - self._l2_mat_b = None - self._l1_scale_b = None - self._l2_scale_b = None - self._l1_gsb = None - self._l2_gsb = None - - # Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688) - # Overridden in finalize_weights with checkpoint input_scale or warmup value - self._l1_activation_global_scale = 1.0 / (6.0 * 448.0) - self._l2_activation_global_scale = 1.0 / (6.0 * 448.0) - - # Pre-allocated cudagraph buffers (set in _allocate_buffers) - self._token_indices = None - self._expert_id_range = None - self._expert_offsets_buf = None - self._per_expert_scale_bufs_l1 = None - self._per_expert_scale_bufs_l2 = None - self._padded_x_sf_buf_l1 = None - self._padded_x_sf_buf_l2 = None - self._l1_gsa_buf = None - self._l2_gsa_buf = None - self._output_buf = None - self._row_indices_buf = None - self._padded_hidden_buf = None - self._padded_activated_buf = None # unused, using shared - self._padded_expert_offsets_buf = None - self._max_chunks_per_expert = cutedsl_ceil_div( - self.max_num_tokens * self.top_k, self.num_experts * 128 - ) - self._buffers_allocated = False - - def set_swiglu_limit(self, limit: float | None): - """Set the swiglu_limit for activation clamping.""" - self._swiglu_limit = limit - - def _fill_token_indices(self): - """Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times). - - Builds on CPU first, then copies to GPU, to ensure correctness - regardless of CuTeDSL JIT GPU memory corruption. - """ - src = torch.arange(self.max_num_tokens, dtype=torch.int32) - cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1) - self._token_indices.copy_(cpu_indices) - - def _allocate_buffers(self): - """Pre-allocate scale buffers at max size for cudagraph compatibility.""" - # Per-expert scale buffers: separate L1/L2 since K_sf differs - K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16) - padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4 - K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16) - padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4 - - self._per_expert_scale_bufs_l1 = [ - torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn) - for _ in range(self.num_experts) - ] - self._per_expert_scale_bufs_l2 = [ - torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn) - for _ in range(self.num_experts) - ] - - # Initialize shared buffers dict (if not already) - device_key = str(self.device) - if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'): - CuTeDSLMoERunner._shared_padded_bufs = {} - if device_key not in CuTeDSLMoERunner._shared_padded_bufs: - CuTeDSLMoERunner._shared_padded_bufs[device_key] = {} - - # Padded x_sf buffers: SHARED across all runners (not per-layer) - max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128 - if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]: - CuTeDSLMoERunner._shared_padded_bufs[device_key].update({ - 'xsf_l1': torch.zeros( - max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device - ).to(torch.float8_e4m3fn), - 'xsf_l2': torch.zeros( - max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device - ).to(torch.float8_e4m3fn), - 'output': torch.zeros( - self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device - ), - }) - self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1'] - self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2'] - self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output'] - - # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture) - self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) - self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) - - # Row indices for scale assembly (max_num_tokens * top_k slots) - self._row_indices_buf = torch.arange( - self.max_num_tokens * self.top_k, device=self.device - ) - - # Padded hidden/activated: SHARED across all runners (not per-layer) - max_rows_per_expert = self._max_chunks_per_expert * 128 - padded_max_slots = self.num_experts * max_rows_per_expert - if 'hidden' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]: - CuTeDSLMoERunner._shared_padded_bufs[device_key].update({ - 'hidden': torch.zeros( - padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device - ), - 'hidden_fp4': torch.zeros( - padded_max_slots, self.hidden_size // 2, dtype=torch.uint8, device=self.device - ).view(torch.float4_e2m1fn_x2), - 'activated': torch.zeros( - padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device - ), - 'activated_fp4': torch.zeros( - padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device - ).view(torch.float4_e2m1fn_x2), - }) - self._shared_bufs = CuTeDSLMoERunner._shared_padded_bufs[device_key] - - # Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed) - self._padded_expert_offsets_buf = torch.zeros( - self.num_experts + 1, dtype=torch.int32, device=self.device - ) - max_rows_per_expert = self._max_chunks_per_expert * 128 - self._padded_expert_offsets_buf[1:] = torch.arange( - 1, self.num_experts + 1, dtype=torch.int32, device=self.device - ) * max_rows_per_expert - - self._buffers_allocated = True - - def _ensure_stacked(self): - if self._l1_mat_b is not None: - return - - # Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation) - self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) - self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) - self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) - self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) - self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device) - self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device) - self.l1_fp4 = None - self.l1_sf = None - self.l1_gs = None - self.l2_fp4 = None - self.l2_sf = None - self.l2_gs = None - - # Allocate buffers AFTER JIT compilation - # (CuTeDSL's cute.compile corrupts GPU memory during JIT; - # tensors allocated before/during compilation may be zeroed) - # - # _token_indices: GPU tensor for cudagraph compatibility. - # CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking - # (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm - # is triggered on the first run() call; we refill _token_indices after - # that first call via the _needs_token_refill flag. - self._token_indices = torch.zeros( - self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device - ) - self._fill_token_indices() - self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run - - self._expert_id_range = torch.arange( - self.num_experts, dtype=torch.int32 - ).to(self.device) - self._expert_offsets_buf = torch.zeros( - self.num_experts + 1, dtype=torch.int32, device=self.device - ) - self._allocate_buffers() - - def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs): - self.l1_fp4 = l1_fp4 - self.l1_sf = l1_sf - self.l1_gs = l1_gs - self.l2_fp4 = l2_fp4 - self.l2_sf = l2_sf - self.l2_gs = l2_gs - self._l1_mat_b = None - - def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16): - self.l1_fp4, self.l1_sf, self.l1_gs = [], [], [] - self.l2_fp4, self.l2_sf, self.l2_gs = [], [], [] - for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16): - l1_w_t = l1_w.T - w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w_t) - self.l1_fp4.append(w_fp4) - self.l1_sf.append(w_sf) - self.l1_gs.append(w_gs) - l2_w_t = l2_w.T - w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l2_w_t) - self.l2_fp4.append(w_fp4) - self.l2_sf.append(w_sf) - self.l2_gs.append(w_gs) - self._l1_mat_b = None - - def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets, - padded_expert_offsets, - padded_x_sf_buf, per_expert_bufs): - """Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs). - - Phase 1: Scatter x_sf into padded per-expert sections (GPU-only). - Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops). - - The buffer is 128-row aligned per expert (from padded_expert_offsets), - so the full-buffer swizzle produces the correct layout. The GEMM reads - scale_a using padded_expert_offsets, matching the scatter layout. - """ - K_sf = x_sf.shape[1] - padded_x_sf = padded_x_sf_buf - padded_x_sf.zero_() - - # Phase 1: Scatter x_sf into padded per-expert sections (GPU-only) - total_rows = x_sf.shape[0] - row_indices = self._row_indices_buf[:total_rows] - expert_assign = torch.searchsorted( - expert_offsets[1:], row_indices, right=True - ).clamp(max=self.num_experts - 1) - local_row = row_indices - expert_offsets[expert_assign] - dst_rows = padded_expert_offsets[expert_assign] + local_row - padded_x_sf[dst_rows, :K_sf] = x_sf - - # Phase 2: Full-buffer swizzle (no CPU sync, no Python loops) - # padded_x_sf is 128-row aligned per expert and 4-col aligned. - # to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3) - # → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten - rows = padded_x_sf.shape[0] - cols = padded_x_sf.shape[1] - R = rows // 128 - C = cols // 4 - blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3) - rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - swizzled = rearranged.flatten().view(torch.float8_e4m3fn) - return swizzled.reshape(rows, cols) - - def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids): - """Compute activation global scales from a warmup forward pass. - - Called BEFORE cudagraph capture. Uses the SAME padded GEMM path as run() - to ensure kernel JIT happens with the same layout, and L2 gs is computed - from actual L1 output (not an approximation). - """ - self._ensure_stacked() - device = hidden_states_sample.device - num_tokens = hidden_states_sample.shape[0] - top_k = topk_ids.shape[1] - - with torch.no_grad(): - # Build slot mapping (same as run()) - flat_ids = topk_ids.reshape(-1) - num_slots = num_tokens * top_k - token_indices = self._token_indices[:num_slots] - sort_idx = flat_ids.argsort(stable=True) - sorted_ids = flat_ids[sort_idx] - sorted_token_ids = token_indices[sort_idx] - slot_hidden = hidden_states_sample[sorted_token_ids] - - # L1: get exact gs from quantize_to_nvfp4 - _, _, l1_gs = quantize_to_nvfp4(slot_hidden) - - # Quantize slot_hidden for GEMM - slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs) - - expert_id_range = self._expert_id_range - tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() - expert_offsets = self._expert_offsets_buf - expert_offsets.zero_() - expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) - - padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 - padded_expert_offsets = self._padded_expert_offsets_buf - padded_expert_offsets.zero_() - padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0) - - # Compute padded_dst (same as run()) - row_indices = self._row_indices_buf[:num_slots] - expert_assign = torch.searchsorted( - expert_offsets[1:], row_indices, right=True - ).clamp(max=self.num_experts - 1) - local_row = row_indices - expert_offsets[expert_assign] - padded_dst = padded_expert_offsets[expert_assign] + local_row - - # Scatter x_fp4 into padded layout - padded_x_fp4 = self._shared_bufs['hidden_fp4'] - padded_x_fp4.view(torch.uint8).zero_() - padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8) - - l1_scale_a = self._assemble_scales_cudagraph_safe( - slot_x_sf, expert_offsets[:self.num_experts + 1], - padded_expert_offsets, - self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 - ) - l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device) - - l1_out = run_nvfp4_grouped_gemm( - mat_a=padded_x_fp4, mat_b=self._l1_mat_b, - scale_a=l1_scale_a, scale_b=self._l1_scale_b, - expert_offsets=padded_expert_offsets[1:self.num_experts + 1], - global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, - ) - - # Extract real token outputs - l1_out_real = l1_out[padded_dst] - - # L2: get exact gs from SiLU(gate)*up - gate = l1_out_real[:, :self.intermediate_size] - up = l1_out_real[:, self.intermediate_size:] - gate_silu = torch.nn.functional.silu(gate) - if self._swiglu_limit is not None: - gate_silu = gate_silu.clamp(max=self._swiglu_limit) - up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit) - activated = gate_silu * up - _, _, l2_gs = quantize_to_nvfp4(activated) - - self._l1_activation_global_scale = l1_gs - self._l2_activation_global_scale = l2_gs - - - - def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None): - """Forward: route tokens to experts, GEMM, combine. - - Uses torch.library.custom_op (nvfp4::moe_gemm) so torch.compile - treats this as an opaque op. The custom op calls _run_impl internally. - """ - if not hasattr(self, '_runner_id'): - self._runner_id = register_runner(self) - return nvfp4_moe_gemm( - hidden_states, topk_weights, topk_ids, - self._runner_id, self.hidden_size, - ) - - def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None): - """Run the NVFP4 MoE forward pass. - - Handles global→local expert ID remapping for expert parallelism. - Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes. - - Each expert's slots are padded to multiples of 128 for the GEMM. - expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...]. - scale_a is produced at those same offsets. - """ - num_tokens = hidden_states.shape[0] - top_k = topk_ids.shape[1] - device = hidden_states.device - - self._ensure_stacked() - - # -- Remap global expert IDs to local IDs -- - local_ids = topk_ids - self.experts_start_idx - local_mask = (local_ids >= 0) & (local_ids < self.num_experts) - safe_ids = local_ids.clamp(0, self.num_experts - 1) - safe_weights = topk_weights * local_mask.float() - - # -- Build slot mapping -- - flat_ids = safe_ids.reshape(-1) - flat_weights = safe_weights.reshape(-1) - num_slots = num_tokens * top_k - token_indices = self._token_indices[:num_slots] - - sort_idx = flat_ids.argsort(stable=True) - sorted_ids = flat_ids[sort_idx] - sorted_weights = flat_weights[sort_idx] - sorted_token_ids = token_indices[sort_idx] - - # Expert offsets (real token counts) - expert_id_range = self._expert_id_range - tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() - expert_offsets = self._expert_offsets_buf - expert_offsets.zero_() - expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) - - # Pad each expert to 128-row alignment (GPU-only computation) - padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 - padded_expert_offsets = self._padded_expert_offsets_buf - padded_expert_offsets.zero_() - padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0) - total_padded_slots = padded_expert_offsets[self.num_experts] - - # -- Gather hidden states into slot order, compute padded_dst -- - slot_hidden = hidden_states[sorted_token_ids] - row_indices = self._row_indices_buf[:num_slots] - expert_assign = torch.searchsorted( - expert_offsets[1:], row_indices, right=True - ).clamp(max=self.num_experts - 1) - local_row = row_indices - expert_offsets[expert_assign] - padded_dst = padded_expert_offsets[expert_assign] + local_row - - # === L1: gate + up === - # Quantize slot_hidden (sorted tokens), NOT padded_hidden. - # padded_hidden is padded with zeros; quantizing it produces - # x_sf rows at padded positions, but x_sf[:num_slots] would - # only get scales for the first num_slots PADDED rows (expert 0), - # not the scattered token positions. Quantizing slot_hidden - # gives x_sf with num_slots rows (one per token), which the - # scale assembly correctly scatters into padded layout. - slot_x_fp4, slot_x_sf = quantize_activation_nvfp4( - slot_hidden, self._l1_activation_global_scale - ) - # Scatter x_fp4 into padded layout for the GEMM - # Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put) - padded_x_fp4 = self._shared_bufs['hidden_fp4'] - padded_x_fp4.view(torch.uint8).zero_() - padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8) - - l1_scale_a = self._assemble_scales_cudagraph_safe( - slot_x_sf, expert_offsets[:self.num_experts + 1], - padded_expert_offsets, - self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 - ) - l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale) - - l1_out = run_nvfp4_grouped_gemm( - mat_a=padded_x_fp4, mat_b=self._l1_mat_b, - scale_a=l1_scale_a, scale_b=self._l1_scale_b, - expert_offsets=padded_expert_offsets[1:self.num_experts + 1], - global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, - ) - - # Extract real token outputs from padded GEMM output - l1_out_real = l1_out[padded_dst] - - # === SiLU(gate) * up (with swiglu_limit clamp) === - gate = l1_out_real[:, :self.intermediate_size] - up = l1_out_real[:, self.intermediate_size:] - gate_silu = torch.nn.functional.silu(gate) - # Apply DeepSeek-V4 swiglu_limit: clamp both silu(gate) and up - if self._swiglu_limit is not None: - gate_silu = gate_silu.clamp(max=self._swiglu_limit) - up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit) - activated = gate_silu * up - - # === L2: down === - # Quantize activated (per-token), scatter into padded FP4 buffer - slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4( - activated, self._l2_activation_global_scale - ) - padded_activated_fp4 = self._shared_bufs['activated_fp4'] - padded_activated_fp4.view(torch.uint8).zero_() - padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8) - - l2_scale_a = self._assemble_scales_cudagraph_safe( - slot_l2_x_sf, expert_offsets[:self.num_experts + 1], - padded_expert_offsets, - self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2 - ) - l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale) - - l2_out = run_nvfp4_grouped_gemm( - mat_a=padded_activated_fp4, mat_b=self._l2_mat_b, - scale_a=l2_scale_a, scale_b=self._l2_scale_b, - expert_offsets=padded_expert_offsets[1:self.num_experts + 1], - global_scale_a=l2_gsa, global_scale_b=self._l2_gsb, - ) - - l2_out_real = l2_out[padded_dst] - - # === Scatter -> final output === - y = self._output_buf[:num_tokens] - y.zero_() - weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype) - y.scatter_add_( - 0, - sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size), - weighted_out, - ) - - # Refill _token_indices after GEMM JIT on first call - # (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM) - if self._needs_token_refill: - self._fill_token_indices() - self._needs_token_refill = False - - return y diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py deleted file mode 100644 index 0b1df975..00000000 --- a/vllm/patches/deepseek_v4.py +++ /dev/null @@ -1,1783 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import typing -from collections.abc import Callable, Iterable -from itertools import islice - -import regex as re -import torch -import torch.nn as nn - -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import ( - get_ep_group, - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp -from vllm.model_executor.layers.deepseek_v4_attention import ( - DeepseekV4Indexer, - DeepseekV4MLAModules, - DeepseekV4MultiHeadLatentAttentionWrapper, -) -from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear -from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod -from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( - fused_topk_bias, -) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import ( - QuantizationConfig, - QuantizationMethods, -) -from vllm.model_executor.layers.quantization.fp8 import Fp8Config -from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped, -) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors -from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op - -from .utils import ( - AutoWeightsLoader, - PPMissingLayer, - WeightsMapper, - extract_layer_index, - is_pp_missing_parameter, - make_layers, - maybe_prefix, -) - -_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8") - - -class DeepseekV4MLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - swiglu_limit: float | None = None, - quant_config: QuantizationConfig | None = None, - reduce_results: bool = True, - is_sequence_parallel: bool = False, - prefix: str = "", - ) -> None: - super().__init__() - - # If is_sequence_parallel, the input and output tensors are sharded - # across the ranks within the tp_group. In this case the weights are - # replicated and no collective ops are needed. - # Otherwise we use standard TP with an allreduce at the end. - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - disable_tp=is_sequence_parallel, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - disable_tp=is_sequence_parallel, - prefix=f"{prefix}.down_proj", - ) - if hidden_act != "silu": - raise ValueError( - f"Unsupported activation: {hidden_act}. Only silu is supported for now." - ) - if swiglu_limit is not None: - self.act_fn = SiluAndMulWithClamp(swiglu_limit) - else: - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class DeepseekV4FP8Config(Fp8Config): - """FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch. - - DeepSeek V4 checkpoints always use FP8 block quantization for - linear/attention layers. The MoE expert weights vary by checkpoint: - - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts - with ue8m0 (e8m0fnu) FP8 linear scales. - - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block - experts with float32 FP8 linear scales. - - The dispatch and the linear scale dtype are both keyed off - ``expert_dtype`` from the model's hf_config; missing values default - to ``"fp4"`` so existing FP4 checkpoints stay unchanged. - - NOTE: ``expert_dtype`` is resolved lazily because this config is - constructed during VllmConfig setup, before ``set_current_vllm_config`` - is active. Reading hf_config eagerly in ``__init__`` would always see - the default ``"fp4"`` and silently misroute Flash-Base checkpoints. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._resolved_expert_dtype: str | None = None - # ``is_scale_e8m0`` is a property that resolves on first read, - # by which time the current vllm_config has been set. - - @property - def expert_dtype(self) -> str: - if self._resolved_expert_dtype is None: - try: - hf_config = get_current_vllm_config().model_config.hf_config - except Exception: - # vllm_config not yet set; defer the decision until a - # later call lands inside set_current_vllm_config. - return "fp4" - expert_dtype = getattr(hf_config, "expert_dtype", "fp4") - if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES: - raise ValueError( - f"Unsupported DeepSeek V4 expert_dtype={expert_dtype!r}; " - f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}." - ) - self._resolved_expert_dtype = expert_dtype - from vllm.logger import init_logger - - init_logger(__name__).info_once( - "DeepSeek V4 expert_dtype resolved to %r", expert_dtype - ) - return self._resolved_expert_dtype - - @property - def is_scale_e8m0(self) -> bool: - # FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert - # checkpoints (Flash-Base) store them as float32. - return self.expert_dtype == "fp4" - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "deepseek_v4_fp8" - - @classmethod - def override_quantization_method( - cls, hf_quant_cfg, user_quant, hf_config=None - ) -> QuantizationMethods | None: - if not ( - isinstance(hf_quant_cfg, dict) - and hf_quant_cfg.get("quant_method") in ("fp8", "deepseek_v4_fp8") - ): - return None - model_type = getattr(hf_config, "model_type", None) - if model_type == "deepseek_v4" or user_quant == "deepseek_v4_fp8": - return "deepseek_v4_fp8" - return None - - def get_quant_method(self, layer, prefix): - if isinstance(layer, FusedMoE): - if is_layer_skipped( - prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping, - ): - return UnquantizedFusedMoEMethod(layer.moe_config) - if self.expert_dtype == "fp4": - return Mxfp4MoEMethod(layer.moe_config) - # expert_dtype == "fp8": fall through to Fp8Config which - # returns Fp8MoEMethod with block-wise float32 scales. - return super().get_quant_method(layer, prefix) - - def is_mxfp4_quant(self, prefix, layer): - return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4" - - -@triton.jit -def _deepseek_v4_stage_mega_moe_inputs_kernel( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_stride_m: tl.constexpr, - hidden_stride_k: tl.constexpr, - x_stride_m: tl.constexpr, - x_stride_k: tl.constexpr, - x_sf_stride_m: tl.constexpr, - x_sf_stride_k: tl.constexpr, - topk_ids_stride_m: tl.constexpr, - topk_ids_stride_k: tl.constexpr, - topk_weights_stride_m: tl.constexpr, - topk_weights_stride_k: tl.constexpr, - topk_idx_stride_m: tl.constexpr, - topk_idx_stride_k: tl.constexpr, - topk_weights_out_stride_m: tl.constexpr, - topk_weights_out_stride_k: tl.constexpr, - hidden_size: tl.constexpr, - top_k: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_K: tl.constexpr, - BLOCK_TOPK: tl.constexpr, -) -> None: - token_id = tl.program_id(0) - k_block_id = tl.program_id(1) - - k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - k_mask = k_offsets < hidden_size - hidden = tl.load( - hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, - mask=k_mask, - other=0.0, - ).to(tl.float32) - - num_groups: tl.constexpr = BLOCK_K // GROUP_K - hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) - amax = tl.max(hidden_groups, axis=1) - amax = tl.maximum(amax, 1.0e-4) - - scale = amax / 448.0 - scale_bits = scale.to(tl.uint32, bitcast=True) - scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( - tl.uint32 - ) - scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) - rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) - - hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups * (1.0 / rounded_scale)[:, None] - scaled = tl.reshape(scaled, [BLOCK_K]) - fp8 = scaled.to(tl.float8e4nv) - tl.store( - x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, - fp8, - mask=k_mask, - ) - - scale_offsets = tl.arange(0, num_groups) - packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) - tl.store( - x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, - packed_scale, - ) - - if k_block_id == 0: - topk_offsets = tl.arange(0, BLOCK_TOPK) - topk_mask = topk_offsets < top_k - - ids = tl.load( - topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, - mask=topk_mask, - other=0, - ).to(tl.int64) - tl.store( - topk_idx_out - + token_id * topk_idx_stride_m - + topk_offsets * topk_idx_stride_k, - ids, - mask=topk_mask, - ) - - weights = tl.load( - topk_weights - + token_id * topk_weights_stride_m - + topk_offsets * topk_weights_stride_k, - mask=topk_mask, - other=0.0, - ) - tl.store( - topk_weights_out - + token_id * topk_weights_out_stride_m - + topk_offsets * topk_weights_out_stride_k, - weights, - mask=topk_mask, - ) - - -def _stage_deepseek_v4_mega_moe_inputs( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - x_fp8: torch.Tensor, - x_sf: torch.Tensor, - topk_idx_out: torch.Tensor, - topk_weights_out: torch.Tensor, -) -> None: - num_tokens, hidden_size = hidden_states.shape - if num_tokens == 0: - return - if hidden_size % 128 != 0: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires hidden_size to be " - "a multiple of 128." - ) - top_k = topk_ids.shape[1] - if topk_weights.shape != topk_ids.shape: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires topk_weights and " - "topk_ids to have the same shape." - ) - - block_k = 128 - grid = (num_tokens, triton.cdiv(hidden_size, block_k)) - block_topk = triton.next_power_of_2(top_k) - _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_states.stride(0), - hidden_states.stride(1), - x_fp8.stride(0), - x_fp8.stride(1), - x_sf.stride(0), - x_sf.stride(1), - topk_ids.stride(0), - topk_ids.stride(1), - topk_weights.stride(0), - topk_weights.stride(1), - topk_idx_out.stride(0), - topk_idx_out.stride(1), - topk_weights_out.stride(0), - topk_weights_out.stride(1), - hidden_size, - top_k, - BLOCK_K=block_k, - GROUP_K=32, - BLOCK_TOPK=block_topk, - num_warps=4, - ) - - -def make_deepseek_v4_expert_params_mapping( - num_experts: int, -) -> list[tuple[str, str, int, str]]: - return [ - ( - "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", - expert_id, - shard_id, - ) - for expert_id in range(num_experts) - for shard_id, weight_name in [ - ("w1", "w1"), - ("w2", "w2"), - ("w3", "w3"), - ] - ] - - -class DeepseekV4MegaMoEExperts(nn.Module): - _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {} - - def __init__( - self, - vllm_config: VllmConfig, - *, - num_experts: int, - num_local_experts: int, - experts_start_idx: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - prefix: str = "", - ): - super().__init__() - self.prefix = prefix - self.num_experts = num_experts - self.num_local_experts = num_local_experts - self.experts_start_idx = experts_start_idx - self.experts_end_idx = experts_start_idx + num_local_experts - self.top_k = top_k - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens - - weight_attrs = {"weight_loader": self.weight_loader} - self.w13_weight = nn.Parameter( - torch.zeros( - num_local_experts, - 2 * intermediate_size, - hidden_size // 2, - dtype=torch.uint8, - ), - requires_grad=False, - ) - set_weight_attrs(self.w13_weight, weight_attrs) - - self.w13_weight_scale = nn.Parameter( - torch.zeros( - num_local_experts, - 2 * intermediate_size, - hidden_size // 32, - dtype=torch.uint8, - ), - requires_grad=False, - ) - set_weight_attrs(self.w13_weight_scale, weight_attrs) - self.w13_weight_scale.quant_method = "block" - - self.w2_weight = nn.Parameter( - torch.zeros( - num_local_experts, - hidden_size, - intermediate_size // 2, - dtype=torch.uint8, - ), - requires_grad=False, - ) - set_weight_attrs(self.w2_weight, weight_attrs) - - self.w2_weight_scale = nn.Parameter( - torch.zeros( - num_local_experts, - hidden_size, - intermediate_size // 32, - dtype=torch.uint8, - ), - requires_grad=False, - ) - set_weight_attrs(self.w2_weight_scale, weight_attrs) - self.w2_weight_scale.quant_method = "block" - - self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None - self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None - - # Register in the static forward context so the custom-op wrapper - # can look up this module by name from within a torch.compile graph. - compilation_config = vllm_config.compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - def _map_global_expert_id(self, expert_id: int) -> int: - if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx: - return -1 - return expert_id - self.experts_start_idx - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - return_success: bool = False, - ) -> bool | None: - local_expert_id = self._map_global_expert_id(expert_id) - if local_expert_id == -1: - return False if return_success else None - - expert_data = param.data[local_expert_id] - if shard_id in ("w1", "w3"): - if "w13_" not in weight_name: - return False if return_success else None - shard_offset = 0 if shard_id == "w1" else self.intermediate_size - expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size) - elif shard_id == "w2": - if "w2_" not in weight_name: - return False if return_success else None - else: - raise ValueError(f"Unsupported expert shard id: {shard_id}") - - if expert_data.shape != loaded_weight.shape: - raise ValueError( - f"DeepSeek V4 MegaMoE expert weight shape mismatch for " - f"{weight_name}: parameter shard {tuple(expert_data.shape)} " - f"vs checkpoint {tuple(loaded_weight.shape)}" - ) - expert_data.copy_(loaded_weight) - return True if return_success else None - - @staticmethod - def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor: - return (sf.to(torch.int32) << 23).view(torch.float32) - - def _check_runtime_supported(self) -> None: - if not torch.cuda.is_available(): - raise NotImplementedError("DeepSeek V4 MegaMoE requires CUDA.") - device = self.w13_weight.device - if device.type != "cuda": - raise NotImplementedError( - "DeepSeek V4 MegaMoE expert weights must be loaded on CUDA." - ) - if torch.cuda.get_device_capability(device)[0] != 10: - raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.") - if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0: - raise ValueError( - "DeepGEMM MegaMoE requires hidden and intermediate sizes " - "to be multiples of 128." - ) - - def finalize_weights(self) -> None: - if self._transformed_l1_weights is not None: - return - - self._check_runtime_supported() - import vllm.third_party.deep_gemm as deep_gemm - - w13_scale = deep_gemm.transform_sf_into_required_layout( - self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(), - 2 * self.intermediate_size, - self.hidden_size, - (1, 32), - self.num_local_experts, - ) - w2_scale = deep_gemm.transform_sf_into_required_layout( - self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(), - self.hidden_size, - self.intermediate_size, - (1, 32), - self.num_local_experts, - ) - self._transformed_l1_weights, self._transformed_l2_weights = ( - deep_gemm.transform_weights_for_mega_moe( - (self.w13_weight.data.view(torch.int8).contiguous(), w13_scale), - (self.w2_weight.data.view(torch.int8).contiguous(), w2_scale), - ) - ) - # Drop the original loader-side parameters: the MegaMoE kernels only - # consume the transformed views above. transform_weights_for_mega_moe - # allocates a fresh tensor for the L1 weight (see _interleave_l1_weights) - # and fresh SF tensors for L1/L2; the L2 weight is the only tensor that - # aliases the original storage, and _transformed_l2_weights still holds - # it, so the storage stays live after we drop the Parameter. - self.w13_weight = None - self.w13_weight_scale = None - self.w2_weight = None - self.w2_weight_scale = None - - def get_symm_buffer(self): - import vllm.third_party.deep_gemm as deep_gemm - - group = get_ep_group().device_group - device = torch.accelerator.current_device_index() - key = ( - id(group), - device, - self.num_experts, - self.max_num_tokens, - self.top_k, - self.hidden_size, - self.intermediate_size, - ) - symm_buffer = self._symm_buffer_cache.get(key) - if symm_buffer is None: - symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe( - group, - self.num_experts, - self.max_num_tokens, - self.top_k, - self.hidden_size, - self.intermediate_size, - ) - self._symm_buffer_cache[key] = symm_buffer - return symm_buffer - - def forward( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - *, - activation_clamp: float | None, - fast_math: bool = True, - ) -> torch.Tensor: - if hidden_states.shape[0] > self.max_num_tokens: - raise ValueError( - f"DeepSeek V4 MegaMoE got {hidden_states.shape[0]} tokens, " - f"but the symmetric buffer was sized for {self.max_num_tokens}." - ) - y = torch.empty_like(hidden_states, dtype=torch.bfloat16) - torch.ops.vllm.deepseek_v4_mega_moe_experts( - hidden_states, - topk_weights, - topk_ids, - y, - self.prefix, - activation_clamp, - fast_math, - ) - return y - - def _run_mega_moe( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - y: torch.Tensor, - activation_clamp: float | None, - fast_math: bool, - ) -> None: - import vllm.third_party.deep_gemm as deep_gemm - - symm_buffer = self.get_symm_buffer() - num_tokens = hidden_states.shape[0] - _stage_deepseek_v4_mega_moe_inputs( - hidden_states, - topk_weights, - topk_ids, - symm_buffer.x[:num_tokens], - symm_buffer.x_sf[:num_tokens], - symm_buffer.topk_idx[:num_tokens], - symm_buffer.topk_weights[:num_tokens], - ) - - # This method must have been already called during the weight loading phase. - # We call it again here to cover the dummy weight loading case. - self.finalize_weights() - - assert self._transformed_l1_weights is not None - assert self._transformed_l2_weights is not None - deep_gemm.fp8_fp4_mega_moe( - y, - self._transformed_l1_weights, - self._transformed_l2_weights, - symm_buffer, - activation_clamp=activation_clamp, - fast_math=fast_math, - ) - - -DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] - - -def _deepseek_v4_mega_moe_experts_op( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - out: torch.Tensor, - layer_name: str, - activation_clamp: float | None, - fast_math: bool, -) -> None: - self = get_forward_context().no_compile_layers[layer_name] - self._run_mega_moe( - hidden_states, - topk_weights, - topk_ids, - out, - activation_clamp, - fast_math, - ) - - -def _deepseek_v4_mega_moe_experts_op_fake( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - out: torch.Tensor, - layer_name: str, - activation_clamp: float | None, - fast_math: bool, -) -> None: - return None - - -direct_register_custom_op( - op_name="deepseek_v4_mega_moe_experts", - op_func=_deepseek_v4_mega_moe_experts_op, - mutates_args=["out"], - fake_impl=_deepseek_v4_mega_moe_experts_op_fake, -) - - -class DeepseekV4MoE(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): - super().__init__() - - self.tp_size = get_tensor_model_parallel_world_size() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.prefix = prefix - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) - if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: - raise NotImplementedError( - "DeepSeek V4 MegaMoE currently requires expert parallel. " - "Enable it with --enable-expert-parallel, or pick a different " - "moe backend." - ) - - self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) - self.hidden_size = config.hidden_size - - self.n_routed_experts = config.n_routed_experts - self.n_activated_experts = config.num_experts_per_tok - self.moe_intermediate_size = config.moe_intermediate_size - self.swiglu_limit = config.swiglu_limit - self.renormalize = config.norm_topk_prob - self.scoring_func = getattr(config, "scoring_func", "sqrtsoftplus") - if self.use_mega_moe and self.scoring_func != "sqrtsoftplus": - raise NotImplementedError( - "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." - ) - if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4": - raise NotImplementedError( - "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype=" - f"{config.expert_dtype!r}. Drop --kernel-config moe_backend=" - "deep_gemm_mega_moe for this checkpoint." - ) - - self.gate = GateLinear( - config.hidden_size, - config.n_routed_experts, - out_dtype=torch.float32, - bias=False, - prefix=f"{prefix}.gate", - ) - self.gate.e_score_correction_bias = None - self.gate.tid2eid = None - is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers - self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32 - - if is_hash_moe: - # hash MoE doesn't use e_score_correction_bias - # Use randint instead of empty to avoid garbage values causing - # invalid memory access in dummy mode (--load-format="dummy") - self.gate.tid2eid = nn.Parameter( - torch.randint( - 0, - config.n_routed_experts, - (config.vocab_size, config.num_experts_per_tok), - dtype=self.hash_indices_dtype, - ), - requires_grad=False, - ) - elif getattr(config, "topk_method", None) == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32), - requires_grad=False, - ) - - if config.n_shared_experts is None: - self.shared_experts = None - else: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - - self.shared_experts = DeepseekV4MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - swiglu_limit=self.swiglu_limit, - quant_config=quant_config, - reduce_results=self.use_mega_moe, - prefix=f"{prefix}.shared_experts", - ) - - if self.use_mega_moe: - self._init_mega_moe_experts(vllm_config, config, prefix) - else: - self._init_fused_moe_experts(config, quant_config, prefix) - - def _init_mega_moe_experts( - self, - vllm_config: VllmConfig, - config, - prefix: str, - ) -> None: - self.ep_group = get_ep_group() - self.ep_size = self.ep_group.world_size - self.ep_rank = self.ep_group.rank_in_group - assert config.n_routed_experts % self.ep_size == 0 - - self.n_local_experts = config.n_routed_experts // self.ep_size - self.experts_start_idx = self.ep_rank * self.n_local_experts - self.experts_end_idx = self.experts_start_idx + self.n_local_experts - - self.experts = DeepseekV4MegaMoEExperts( - vllm_config, - num_experts=config.n_routed_experts, - num_local_experts=self.n_local_experts, - experts_start_idx=self.experts_start_idx, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - prefix=f"{prefix}.experts", - ) - - def _init_fused_moe_experts( - self, - config, - quant_config, - prefix: str, - ) -> None: - self.tp_rank = get_tensor_model_parallel_rank() - assert config.n_routed_experts % self.tp_size == 0 - - self.n_local_experts = config.n_routed_experts // self.tp_size - self.experts_start_idx = self.tp_rank * self.n_local_experts - self.experts_end_idx = self.experts_start_idx + self.n_local_experts - - self.experts = FusedMoE( - shared_experts=self.shared_experts, - gate=self.gate, - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.gate.e_score_correction_bias, - hash_indices_table=self.gate.tid2eid, - swiglu_limit=self.swiglu_limit, - router_logits_dtype=torch.float32, - ) - - def forward( - self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None - ) -> torch.Tensor: - if self.gate.tid2eid is not None and input_ids is None: - raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.") - - if not self.use_mega_moe: - return self._forward_fused_moe(hidden_states, input_ids) - - org_shape = hidden_states.shape - router_logits, _ = self.gate(hidden_states) - topk_weights, topk_ids = fused_topk_bias( - hidden_states=hidden_states, - gating_output=router_logits, - scoring_func=self.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias.data - if self.gate.e_score_correction_bias is not None - else None, - topk=self.n_activated_experts, - renormalize=self.renormalize, - indices_type=self.hash_indices_dtype, - input_tokens=input_ids, - hash_indices_table=self.gate.tid2eid, - routed_scaling_factor=self.routed_scaling_factor, - ) - activation_clamp = ( - float(self.swiglu_limit) if self.swiglu_limit is not None else None - ) - final_hidden_states = self.experts( - hidden_states, - topk_weights, - topk_ids, - activation_clamp=activation_clamp, - ) - - if self.shared_experts is not None: - shared_output = self.shared_experts(hidden_states) - final_hidden_states += shared_output - - return final_hidden_states.view(org_shape) - - def _forward_fused_moe( - self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None - ) -> torch.Tensor: - org_shape = hidden_states.shape - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=hidden_states, - input_ids=input_ids, - ) - else: - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - input_ids=input_ids, - ) - - return final_hidden_states.view(org_shape) - - def finalize_mega_moe_weights(self) -> None: - if self.use_mega_moe: - self.experts.finalize_weights() - - -class DeepseekV4Attention(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream_list: list[torch.cuda.Stream] | None = None, - ): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - layer_id = extract_layer_index(prefix) - - self.layer_id = layer_id - self.hidden_size = config.hidden_size - self.n_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - assert self.n_heads % tp_size == 0 - - self.n_local_heads = self.n_heads // tp_size - self.q_lora_rank = config.q_lora_rank - self.o_lora_rank = config.o_lora_rank - self.head_dim = config.head_dim - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = self.head_dim - self.rope_head_dim - self.n_groups = config.o_groups - self.n_local_groups = self.n_groups // tp_size - self.window_size = config.sliding_window - # NOTE(zyongye) Compress ratio can't be 0 - # we do this for because MTP layer is not included - # in the compress ratio list - if layer_id < config.num_hidden_layers: - self.compress_ratio = max(1, config.compress_ratios[layer_id]) - else: - self.compress_ratio = 1 - self.eps = config.rms_norm_eps - self.max_position_embeddings = config.max_position_embeddings - - # Padded to min 64 heads for FlashMLA, initialized to -inf - # (no sink effect). Weight loading fills the first n_local_heads slots. - padded_heads = max(self.n_local_heads, 64) - self.attn_sink = nn.Parameter( - torch.full((padded_heads,), -float("inf"), dtype=torch.float32), - requires_grad=False, - ) - - self.fused_wqa_wkv = MergedColumnParallelLinear( - self.hidden_size, - [self.q_lora_rank, self.head_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.fused_wqa_wkv", - disable_tp=True, # fused ReplicatedLinear - ) - self.q_norm = RMSNorm(self.q_lora_rank, self.eps) - self.wq_b = ColumnParallelLinear( - self.q_lora_rank, - self.n_heads * self.head_dim, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wq_b", - ) - - self.kv_norm = RMSNorm(self.head_dim, self.eps) - # wo_a is BF16 in the NVFP4 checkpoint (no quantization scales). - # Pass quant_config=None so it loads as a plain BF16 linear layer. - self.wo_a = ColumnParallelLinear( - self.n_heads * self.head_dim // self.n_groups, - self.n_groups * self.o_lora_rank, - bias=False, - quant_config=None, - return_bias=False, - prefix=f"{prefix}.wo_a", - ) - self.wo_a.is_bmm = True - self.wo_a.bmm_batch_size = self.n_local_groups - self.wo_b = RowParallelLinear( - self.n_groups * self.o_lora_rank, - self.hidden_size, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_b", - ) - self.softmax_scale = self.head_dim**-0.5 - self.scale_fmt = config.quantization_config["scale_fmt"] - - self.rope_parameters = config.rope_scaling - - # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it) - rope_parameters = config.rope_parameters - rope_parameters["rope_theta"] = ( - config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta - ) - if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = ( - "deepseek_yarn" - if config.rope_parameters.get("apply_yarn_scaling", True) - else "deepseek_llama_scaling" - ) - rope_parameters["mscale"] = 0 # Disable mscale - rope_parameters["mscale_all_dim"] = 0 # Disable mscale - rope_parameters["is_deepseek_v4"] = True - rope_parameters["rope_dim"] = self.rope_head_dim - self.rotary_emb = get_rope( - self.head_dim, - max_position=self.max_position_embeddings, - rope_parameters=rope_parameters, - is_neox_style=False, - ) - - self.indexer = None - if self.compress_ratio == 4: - # Only C4A uses sparse attention and hence has indexer. - self.indexer = DeepseekV4Indexer( - vllm_config, - config=config, - hidden_size=self.hidden_size, - q_lora_rank=self.q_lora_rank, - quant_config=quant_config, - cache_config=vllm_config.cache_config, - topk_indices_buffer=topk_indices_buffer, - compress_ratio=self.compress_ratio, - prefix=f"{prefix}.indexer", - ) - - mla_modules = DeepseekV4MLAModules( - vllm_config=vllm_config, - fused_wqa_wkv=self.fused_wqa_wkv, - q_norm=self.q_norm, - wq_b=self.wq_b, - kv_norm=self.kv_norm, - wo_a=self.wo_a, - wo_b=self.wo_b, - attn_sink=self.attn_sink, - rotary_emb=self.rotary_emb, - indexer=self.indexer, - indexer_rotary_emb=self.rotary_emb, - topk_indices_buffer=topk_indices_buffer, - aux_stream_list=aux_stream_list, - ) - self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper( - hidden_size=self.hidden_size, - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.softmax_scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - v_head_dim=self.head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.head_dim, - o_lora_rank=self.o_lora_rank, - mla_modules=mla_modules, - window_size=self.window_size, - compress_ratio=self.compress_ratio, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - prefix=prefix, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - llama_4_scaling: torch.Tensor | None, - ): - return self.mla_attn(positions, hidden_states, llama_4_scaling) - - -class DeepseekV4DecoderLayer(nn.Module): - def __init__( - self, - vllm_config, - prefix, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream_list: list[torch.cuda.Stream] | None = None, - ): - super().__init__() - - # Lazy import to avoid top-level tilelang dependency. - # Registers both torch.ops.vllm.mhc_pre and mhc_post - import vllm.model_executor.layers.mhc # noqa: F401 - - config = vllm_config.model_config.hf_config - self.hidden_size = config.hidden_size - - self.rms_norm_eps = config.rms_norm_eps - self.attn = DeepseekV4Attention( - vllm_config, - prefix=f"{prefix}.attn", - topk_indices_buffer=topk_indices_buffer, - aux_stream_list=aux_stream_list, - ) - self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") - - self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) - self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) - self.hc_mult = config.hc_mult - self.hc_sinkhorn_iters = config.hc_sinkhorn_iters - self.hc_eps = config.hc_eps - self.hc_post_alpha = 2.0 - mix_hc = (2 + self.hc_mult) * self.hc_mult - hc_dim = self.hc_mult * self.hidden_size - self.hc_attn_fn = nn.Parameter( - torch.empty( - (mix_hc, hc_dim), - dtype=torch.float32, - ), - requires_grad=False, - ) - self.hc_ffn_fn = nn.Parameter( - torch.empty( - (mix_hc, hc_dim), - dtype=torch.float32, - ), - requires_grad=False, - ) - self.hc_attn_base = nn.Parameter( - torch.empty( - mix_hc, - dtype=torch.float32, - ), - requires_grad=False, - ) - self.hc_ffn_base = nn.Parameter( - torch.empty( - mix_hc, - dtype=torch.float32, - ), - requires_grad=False, - ) - self.hc_attn_scale = nn.Parameter( - torch.empty( - 3, - dtype=torch.float32, - ), - requires_grad=False, - ) - self.hc_ffn_scale = nn.Parameter( - torch.empty( - 3, - dtype=torch.float32, - ), - requires_grad=False, - ) - - def hc_pre( - self, - x: torch.Tensor, - hc_fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - ): - post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre( - residual=x, - fn=hc_fn, - hc_scale=hc_scale, - hc_base=hc_base, - rms_eps=self.rms_norm_eps, - hc_pre_eps=self.hc_eps, - hc_sinkhorn_eps=self.hc_eps, - hc_post_mult_value=self.hc_post_alpha, - sinkhorn_repeat=self.hc_sinkhorn_iters, - ) - return layer_input, post_mix, res_mix - - def hc_post( - self, - x: torch.Tensor, - residual: torch.Tensor, - post: torch.Tensor, - comb: torch.Tensor, - ): - return torch.ops.vllm.mhc_post(x, residual, post, comb) - - def forward( - self, - x: torch.Tensor, - positions: torch.Tensor, - input_ids: torch.Tensor | None, - post_mix: torch.Tensor | None, - res_mix: torch.Tensor | None, - residual: torch.Tensor | None, - ) -> torch.Tensor: - if residual is None: - # Run standalone hc_pre on first layer - residual = x - x, post_mix, res_mix = self.hc_pre( - x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base - ) - else: - residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( - x, - residual, - post_mix, - res_mix, - self.hc_attn_fn, - self.hc_attn_scale, - self.hc_attn_base, - self.rms_norm_eps, - self.hc_eps, - self.hc_eps, - self.hc_post_alpha, - self.hc_sinkhorn_iters, - ) - - x = self.attn_norm(x) - x = self.attn(positions, x, None) - - residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( - x, - residual, - post_mix, - res_mix, - self.hc_ffn_fn, - self.hc_ffn_scale, - self.hc_ffn_base, - self.rms_norm_eps, - self.hc_eps, - self.hc_eps, - self.hc_post_alpha, - self.hc_sinkhorn_iters, - ) - - x = self.ffn_norm(x) - x = self.ffn(x, input_ids) - return x, residual, post_mix, res_mix - - -@support_torch_compile -class DeepseekV4Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) - if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: - raise NotImplementedError( - "DeepSeek V4 MegaMoE currently requires expert parallel. " - "Enable it with --enable-expert-parallel, or pick a different " - "moe backend." - ) - self.vocab_size = config.vocab_size - self.hc_eps = config.hc_eps - self.hc_mult = config.hc_mult - self.hc_dim = self.hc_mult * config.hidden_size - self.rms_norm_eps = config.rms_norm_eps - - # Three aux streams: one per non-default input GEMM in - # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute - # (compressor kv_score, indexer.weights_proj, indexer.compressor - # kv_score). fused_wqa_wkv stays on the default stream. - # Disable them on ROCm because of hang issues. - aux_stream_list = ( - None - if current_platform.is_rocm() - else [torch.cuda.Stream() for _ in range(3)] - ) - - self.device = current_platform.device_type - # Reserved topk indices buffer for all Indexer layers to reuse. - self.topk_indices_buffer = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - config.index_topk, - dtype=torch.int32, - device=self.device, - ) - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: DeepseekV4DecoderLayer( - vllm_config, - prefix=prefix, - topk_indices_buffer=self.topk_indices_buffer, - aux_stream_list=aux_stream_list, - ), - prefix=f"{prefix}.layers", - ) - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) - else: - self.norm = PPMissingLayer() - - self.hc_head_fn = nn.Parameter( - torch.empty( - self.hc_mult, - self.hc_dim, - dtype=torch.float32, - ), - requires_grad=False, - ) - self.hc_head_base = nn.Parameter( - torch.empty( - self.hc_mult, - dtype=torch.float32, - ), - requires_grad=False, - ) - self.hc_head_scale = nn.Parameter( - torch.empty(1, dtype=torch.float32), - requires_grad=False, - ) - - # Pre-hc_head residual stream buffer for the MTP draft. Stable - # address (outside the cudagraph pool) so the copy_ in forward() - # refreshes it correctly across captured shapes. - # refreshes it correctly across captured shapes. Only allocated on - # the last PP rank — that's where MTP target hidden states are - # produced. - if get_pp_group().is_last_rank: - self._mtp_hidden_buffer = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - self.hc_dim, - dtype=vllm_config.model_config.dtype, - device=self.device, - ) - else: - self._mtp_hidden_buffer = None - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def make_empty_intermediate_tensors( - self, - batch_size: int, - dtype: torch.dtype, - device: torch.device, - ) -> IntermediateTensors: - # PP intermediate tensors carry the multi-stream hidden_states - # of shape (num_tokens, hc_mult, hidden_size) — V4 expands the - # token embedding to hc_mult streams before the first decoder - # layer and keeps that shape until hc_head() collapses it. - return IntermediateTensors( - { - "hidden_states": torch.zeros( - (batch_size, self.hc_mult, self.config.hidden_size), - dtype=dtype, - device=device, - ), - } - ) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_input_ids(input_ids) - hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - - if self.use_mega_moe: - input_ids = input_ids.to(torch.int64) - - residual, post_mix, res_mix = None, None, None - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual, post_mix, res_mix = layer( - hidden_states, - positions, - input_ids, - post_mix, - res_mix, - residual, - ) - else: - hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - - # Stash pre-hc_head residual for the MTP draft (captured copy_). - num_tokens = hidden_states.shape[0] - self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) - - hidden_states = hc_head( - hidden_states, - self.hc_head_fn, - self.hc_head_scale, - self.hc_head_base, - self.rms_norm_eps, - self.hc_eps, - ) - hidden_states = self.norm(hidden_states) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "w1", 0), - ("gate_up_proj", "w3", 1), - ("attn.fused_wqa_wkv", "attn.wq_a", 0), - ("attn.fused_wqa_wkv", "attn.wkv", 1), - ("compressor.fused_wkv_wgate", "compressor.wkv", 0), - ("compressor.fused_wkv_wgate", "compressor.wgate", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - # TP for attention - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - n_head = self.config.num_attention_heads - n_local_head = n_head // tp_size - head_rank_start = n_local_head * tp_rank - head_rank_end = n_local_head * (tp_rank + 1) - - # Pre-compute expert mapping ONCE. - expert_mapping = self.get_expert_mapping() - - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if ".experts." in name: - continue - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - break - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - loaded_params.add(name) - break - else: - if ".experts." in name: - # E8M0 scales are stored as float8_e8m0fnu in - # checkpoints but the MoE param is uint8. copy_() - # would do a numeric conversion (e.g. 2^-7 → 0), - # destroying the raw exponent bytes. - if ( - "weight_scale" in name - and loaded_weight.dtype == torch.float8_e8m0fnu - ): - loaded_weight = loaded_weight.view(torch.uint8) - for mapping in expert_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name_mapped = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name_mapped, self): - continue - param = params_dict[name_mapped] - # We should ask the weight loader to return success or not - # here since otherwise we may skip experts with other - # available replicas. - weight_loader = typing.cast( - Callable[..., bool], param.weight_loader - ) - success = weight_loader( - param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True, - ) - if success: - name = name_mapped - break - loaded_params.add(name_mapped) - continue - elif "attn_sink" in name: - if is_pp_missing_parameter(name, self): - continue - narrow_weight = loaded_weight[head_rank_start:head_rank_end] - n = narrow_weight.shape[0] - params_dict[name][:n].copy_(narrow_weight) - loaded_params.add(name) - continue - else: - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) - continue - - return loaded_params - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - first_layer = next(iter(islice(self.layers, self.start_layer, self.end_layer))) - if first_layer.ffn.use_mega_moe: - return make_deepseek_v4_expert_params_mapping(self.config.n_routed_experts) - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( - self, - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=self.config.n_routed_experts, - ) - - def finalize_mega_moe_weights(self) -> None: - for layer in islice(self.layers, self.start_layer, self.end_layer): - layer.ffn.finalize_mega_moe_weights() - - -@torch.compile(backend=current_platform.simple_compile_backend) -def hc_head( - hidden_states: torch.Tensor, - hc_fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_norm_eps: float, - hc_eps: float, -) -> torch.Tensor: - hc_mult, hidden_size = hidden_states.shape[-2:] - outer_shape = hidden_states.shape[:-2] - hs_flat = hidden_states.view(-1, hc_mult, hidden_size) - num_tokens = hs_flat.shape[0] - out = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device - ) - torch.ops.vllm.hc_head_fused_kernel( - hs_flat, - hc_fn, - hc_scale, - hc_base, - out, - hidden_size, - rms_norm_eps, - hc_eps, - hc_mult, - ) - return out.view(*outer_shape, hidden_size) - - -def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: - if expert_dtype == "fp4": - # MXFP4 experts use Mxfp4MoEMethod, which registers scales as - # ``w{1,2,3}_weight_scale`` (no _inv suffix). FP8 linear and - # shared experts use Fp8LinearMethod's block scales, which - # register as ``weight_scale_inv``. - scale_regex = { - re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale", - re.compile(r"\.scale$"): ".weight_scale_inv", - } - else: - # FP8 experts use Fp8MoEMethod (block_quant=True), which registers - # scales as ``w{13,2}_weight_scale_inv``. Map all ``.scale`` keys - # there. - scale_regex = { - re.compile(r"\.scale$"): ".weight_scale_inv", - } - return WeightsMapper( - orig_to_new_prefix={ - "layers.": "model.layers.", - "embed.": "model.embed.", - "norm.": "model.norm.", - "hc_head": "model.hc_head", - "mtp.": "model.mtp.", - }, - orig_to_new_regex=scale_regex, - orig_to_new_suffix={ - "head.weight": "lm_head.weight", - "embed.weight": "embed_tokens.weight", - ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", - }, - orig_to_new_substr={ - ".attn.compressor.": ".attn.mla_attn.compressor.", - ".shared_experts.w2": ".shared_experts.down_proj", - }, - ) - - -def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper: - """Weight mapper for NVFP4 (ModelOpt) DeepSeek-V4 checkpoints. - - NVFP4 checkpoints use different key naming than the default MXFP4 format: - - ``self_attn`` prefix instead of ``attn`` - - ``mlp`` prefix instead of ``ffn`` - - Expert weights: gate_proj/up_proj/down_proj (not w1/w3/w2) - - Scales already have .weight_scale / .weight_scale_2 / .input_scale suffixes - - Compressor uses kv_proj/gate_proj (not wkv/wgate) - - o_a_proj is BF16 (no quantization scales) - """ - expert_rename_regex = { - re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.", - re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.", - re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.", - } - return WeightsMapper( - orig_to_new_prefix={ - "layers.": "model.layers.", - "embed.": "model.embed.", - "norm.": "model.norm.", - # hc_head NOT mapped here — checkpoint already has model.hc_head.* - # and model params are flat (hc_head_fn, not hc_head.fn) - "mtp.": "model.mtp.", - }, - orig_to_new_regex=expert_rename_regex, - orig_to_new_suffix={ - # NVFP4 checkpoint already uses lm_head.weight and - # model.embed_tokens.weight — no suffix renames needed. - ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", - }, - orig_to_new_substr={ - # Indexer params (MUST come before general compressor renames) - ".self_attn.compressor.indexer.q_b_proj.": - ".attn.indexer.wq_b.", - ".self_attn.compressor.indexer.weights_proj.": - ".attn.indexer.weights_proj.", - ".self_attn.compressor.indexer.kv_norm.": - ".attn.indexer.k_norm.", - ".self_attn.compressor.indexer.kv_proj.": - ".attn.indexer.compressor.wkv.", - ".self_attn.compressor.indexer.gate_proj.": - ".attn.indexer.compressor.wgate.", - ".self_attn.compressor.indexer.position_bias": - ".attn.indexer.compressor.ape", - # Compressor renames (non-indexer) - "compressor.kv_proj.": "compressor.wkv.", - "compressor.gate_proj.": "compressor.wgate.", - "compressor.kv_norm.": "compressor.norm.", - "compressor.position_bias": "compressor.ape", - ".self_attn.compressor.": ".attn.mla_attn.compressor.", - # Attention projections - ".self_attn.q_a_proj.": ".attn.wq_a.", - ".self_attn.kv_proj.": ".attn.wkv.", - ".self_attn.q_b_proj.": ".attn.wq_b.", - ".self_attn.o_a_proj.": ".attn.wo_a.", - ".self_attn.o_b_proj.": ".attn.wo_b.", - ".self_attn.q_a_norm.": ".attn.q_norm.", - ".self_attn.kv_norm.": ".attn.kv_norm.", - ".self_attn.sinks": ".attn.attn_sink", - # Shared experts - ".mlp.shared_experts.gate_proj.": - ".ffn.shared_experts.w1.", - ".mlp.shared_experts.up_proj.": - ".ffn.shared_experts.w3.", - ".mlp.shared_experts.down_proj.": - ".ffn.shared_experts.down_proj.", - # General renames - ".mlp.": ".ffn.", - ".self_attn.": ".attn.", - "input_layernorm.": "attn_norm.", - "post_attention_layernorm.": "ffn_norm.", - # HC params - ".attn_hc.fn": ".hc_attn_fn", - ".attn_hc.base": ".hc_attn_base", - ".attn_hc.scale": ".hc_attn_scale", - ".ffn_hc.fn": ".hc_ffn_fn", - ".ffn_hc.base": ".hc_ffn_base", - ".ffn_hc.scale": ".hc_ffn_scale", - "hc_head.hc_fn": "hc_head_fn", - "hc_head.hc_base": "hc_head_base", - "hc_head.hc_scale": "hc_head_scale", - }, - ) - - -class DeepseekV4ForCausalLM(nn.Module, SupportsPP): - model_cls = DeepseekV4Model - - # Default mapper assumes the original FP4-expert checkpoint layout. - # Overridden per-instance in __init__ when expert_dtype != "fp4". - hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - self.config = config - expert_dtype = getattr(config, "expert_dtype", "fp4") - quant_config = vllm_config.quant_config - # Select weight mapper based on quantization method. - # NVFP4 (modelopt_fp4) checkpoints use different key naming - # than the default MXFP4 format. - if (quant_config is not None - and quant_config.get_name() == "modelopt_fp4"): - self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper() - elif expert_dtype != "fp4": - self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper( - expert_dtype) - - self.model = self.model_cls( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def get_mtp_target_hidden_states(self) -> torch.Tensor | None: - """Pre-hc_head residual stream buffer (max_num_batched_tokens, - hc_mult * hidden_size) for the MTP draft model. Populated by - forward(); valid after each target step.""" - return getattr(self.model, "_mtp_hidden_buffer", None) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) - loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - self.model.finalize_mega_moe_weights() - return loaded_params - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py deleted file mode 100644 index 7b064240..00000000 --- a/vllm/patches/deepseek_v4_attention.py +++ /dev/null @@ -1,1537 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -DeepseekV4 MLA Attention Layer -""" - -from collections.abc import Callable -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import DeepseekV2Config, DeepseekV3Config - -import vllm.envs as envs -from vllm.model_executor.layers.linear import ( - ReplicatedLinear, -) -from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer -from vllm.utils.deep_gemm import fp8_einsum -from vllm.utils.torch_utils import direct_register_custom_op -from vllm.v1.attention.ops.deepseek_v4_ops import ( - combine_topk_swa_indices, - compute_global_topk_indices_and_lens, - dequantize_and_gather_k_cache, - fused_indexer_q_rope_quant, - fused_inv_rope_fp8_quant, - fused_q_kv_rmsnorm, -) -from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum - -if TYPE_CHECKING: - from vllm.v1.attention.backends.mla.sparse_swa import ( - DeepseekSparseSWAMetadata, - ) - -from vllm.config import ( - CacheConfig, - VllmConfig, - get_current_vllm_config, -) -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.custom_op import PluggableLayer -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor -from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm -from vllm.model_executor.layers.quantization import QuantizationConfig - -from vllm.platforms import current_platform -from vllm.utils.multi_stream_utils import ( - execute_in_parallel, - maybe_execute_in_parallel, -) -from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata -from vllm.v1.attention.backends.mla.flashmla_sparse import ( - DeepseekV4FlashMLASparseBackend, - FlashMLASparseBackend, - FlashMLASparseMetadata, -) -from vllm.v1.attention.backends.mla.indexer import ( - DeepseekV4IndexerBackend, - get_max_prefill_buffer_size, -) -from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache -from vllm.v1.attention.ops.flashmla import ( - flash_mla_sparse_fwd, - flash_mla_with_kvcache, -) -from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from vllm.v1.worker.workspace import current_workspace_manager - -logger = init_logger(__name__) - -# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather -# workspace allocated at _forward_prefill (and the matching profile-time -# reservation in attention_impl's dummy-run branch). -PREFILL_CHUNK_SIZE = 4 - - -@dataclass -class DeepseekV4MLAModules: - """Modules used in DeepseekV4 MLA.""" - - vllm_config: VllmConfig - fused_wqa_wkv: torch.nn.Module - q_norm: torch.nn.Module - wq_b: torch.nn.Module - kv_norm: torch.nn.Module - wo_a: torch.nn.Module - wo_b: torch.nn.Module - attn_sink: torch.nn.Module - rotary_emb: torch.nn.Module - indexer: torch.nn.Module | None - indexer_rotary_emb: torch.nn.Module - topk_indices_buffer: torch.Tensor | None - aux_stream_list: list[torch.cuda.Stream] | None = None - - -# --8<-- [start:multi_head_latent_attention] -@PluggableLayer.register("deepseek_v4_multi_head_latent_attention") -class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): - """Pluggable MLA layer which allows OOT backends to add - custom implementations of the outer MLA layer (including rope & o_proj). - Note that currently oot platforms can still use CustomOp.register_oot to - replace MLA layer entirely, although we use PluggableLayer to register - this layer now. - - This class takes positions and hidden_states as input. - The input tensors can either contain prefill tokens or decode tokens. - The class does the following: - - 1. MLA Preprocess. - 2. Perform multi-head attention to prefill tokens and - multi-query attention to decode tokens separately. - 3. Return the output tensor. - """ - - # --8<-- [end:multi_head_latent_attention] - - def __init__( - self, - hidden_size: int, - num_heads: int, - head_dim: int, - scale: float, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: int | None, - kv_lora_rank: int, - o_lora_rank: int | None, - mla_modules: DeepseekV4MLAModules, - window_size: int, - compress_ratio: int | None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.n_local_heads = num_heads - self.head_dim = head_dim - self.scale = scale - - # FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the - # next supported size. Must match DeepseekV4MLAAttention.padded_heads. - if num_heads <= 64: - self.padded_heads = 64 - elif num_heads <= 128: - self.padded_heads = 128 - else: - raise ValueError( - f"DeepseekV4 attention does not support {num_heads} heads " - "(must be <= 128)." - ) - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.window_size = window_size - self.compress_ratio = compress_ratio if compress_ratio is not None else 1 - self.prefix = prefix - - # Extract config from vllm_config - config = mla_modules.vllm_config.model_config.hf_config - tp_size = get_tensor_model_parallel_world_size() - - # DeepseekV4-specific attributes (num_heads is already TP-adjusted) - self.eps = config.rms_norm_eps - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = head_dim - self.rope_head_dim - self.n_local_groups = config.o_groups // tp_size - self.o_lora_rank = config.o_lora_rank - - # Store projection modules - self.fused_wqa_wkv = mla_modules.fused_wqa_wkv - self.q_norm = mla_modules.q_norm - self.wq_b = mla_modules.wq_b - - self.kv_norm = mla_modules.kv_norm - self.wo_a = mla_modules.wo_a - self.wo_b = mla_modules.wo_b - - # Pick fp8_einsum recipe based on GPU arch: - # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 - # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 - cap = current_platform.get_device_capability() - assert cap is not None, "DeepseekV4 attention requires a CUDA device" - self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) - self._tma_aligned_scales = cap.major >= 10 - - self.rotary_emb = mla_modules.rotary_emb - self.indexer_rotary_emb = mla_modules.indexer_rotary_emb - self.topk_indices_buffer = mla_modules.topk_indices_buffer - - self.indexer = mla_modules.indexer - - # Per-head RMS normalization for Q (no learnable weights) - self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) - - # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic - head_bytes = ( - self.nope_head_dim # 448 fp8 NoPE - + self.rope_head_dim * 2 # 64 bf16 RoPE - + self.nope_head_dim // 64 # 7B scale factors - + 1 # 1B pad - ) - - # Will be None on ROCm for now. - self.aux_stream_list = mla_modules.aux_stream_list - # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; - # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins - # before post-GEMM starts. - self.ln_events = [torch.cuda.Event() for _ in range(4)] - - assert cache_config is not None, "DeepseekV4 attention requires cache_config" - self.swa_cache_layer = DeepseekV4SWACache( - head_dim=self.head_dim, - window_size=self.window_size, - dtype=torch.uint8, - prefix=f"{prefix}.swa_cache", - cache_config=cache_config, - ) - - self.mla_attn = DeepseekV4MLAAttention( - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - compress_ratio=self.compress_ratio, - window_size=self.window_size, - head_bytes=head_bytes, - swa_cache_layer=self.swa_cache_layer, - attn_sink=mla_modules.attn_sink, # already padded with -inf - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - indexer=self.indexer, - topk_indices_buffer=self.topk_indices_buffer, - ) - # Register this layer in the compilation config's static forward context - # This allows the custom op to retrieve the layer during execution - compilation_config = mla_modules.vllm_config.compilation_config - # HACK - self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention" - if self.layer_name in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {self.layer_name}") - compilation_config.static_forward_context[self.layer_name] = self - - # Create the compressor for layers with compress_ratio > 1; after - # creating the DeepseekV4MLAAttention layer to get its cache. - self.compressor = None - if self.compress_ratio > 1: - self.compressor = DeepseekCompressor( - vllm_config=mla_modules.vllm_config, - compress_ratio=self.compress_ratio, - hidden_size=self.hidden_size, - head_dim=self.head_dim, - rotate=True, - prefix=f"{prefix}.compressor", - k_cache_prefix=self.mla_attn.prefix, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - llama_4_scaling: torch.Tensor | None = None, - ) -> torch.Tensor: - # Pre-allocate attention output with FlashMLA-padded head count. - # The op writes into `o_padded`; we slice to n_local_heads after. - num_tokens = hidden_states.shape[0] - o_padded = torch.empty( - (num_tokens, self.padded_heads, self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # Attention (inside custom op for torch.compile boundary) - torch.ops.vllm.deepseek_v4_attention( - hidden_states, - positions, - o_padded, - self.layer_name, - ) - o = o_padded[:, : self.n_local_heads, :] - - # Keep ROCm on the BF16 reference wo_a path util kernel ready. - if current_platform.is_rocm(): - z = rocm_inv_rope_einsum( - self.rotary_emb, - o, - positions, - self.rope_head_dim, - self.n_local_groups, - self.o_lora_rank, - self.wo_a, - ) - return self.wo_b(z.flatten(1)) - - # Detect if wo_a has FP8 weights (weight_scale_inv attribute). - # NVFP4 checkpoints leave wo_a as BF16 (no quantization scales), - # so we use inverse RoPE in BF16 + regular matmul instead of - # the FP8 einsum path (which crashes on Blackwell SM100). - has_fp8_weights = hasattr(self.wo_a, 'weight_scale_inv') - - if not has_fp8_weights: - # BF16 wo_a path: inverse RoPE in BF16, then per-group BMM - # wo_a is a ColumnParallelLinear with is_bmm=True, meaning it - # operates per o-group. The FP8 path uses einsum "bhr,hdr->bhd" - # where h=n_local_groups. We must do the same grouping here. - o_inv = _apply_inv_rope_bf16( - o, positions, - self.rotary_emb.cos_sin_cache.to(torch.float32), - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - ) - heads_per_group = self.n_local_heads // self.n_local_groups - # o_inv: (num_tokens, n_local_heads, head_dim) - # -> (n_local_groups, num_tokens, heads_per_group * head_dim) - o_inv = o_inv.view( - num_tokens, self.n_local_groups, heads_per_group * self.head_dim - ).permute(1, 0, 2) - # wo_a weight is sharded by TP along output dim. - # Shape: (n_local_groups * o_lora_rank // tp, heads_per_group * head_dim) - # For BMM, we need weight shaped as (n_local_groups, o_lora_rank // tp, heads_per_group * head_dim) - wo_a_w = self.wo_a.weight.view( - self.n_local_groups, -1, heads_per_group * self.head_dim - ) - # BMM: (n_local_groups, num_tokens, in) @ (n_local_groups, in, out) -> (n_local_groups, num_tokens, out) - z = torch.bmm( - o_inv, - wo_a_w.transpose(1, 2), - ) - # -> (num_tokens, n_local_groups, o_lora_rank // tp) - z = z.permute(1, 0, 2) - # All-gather wo_a output across TP ranks, then flatten groups - if self.wo_a.gather_output and self.wo_a.tp_size > 1: - z = tensor_model_parallel_all_gather(z) - z = z.reshape(num_tokens, self.n_local_groups * self.o_lora_rank) - return self.wo_b(z) - - # FP8 wo_a path: fused inverse RoPE + FP8 quant + einsum - o_fp8, o_scale = fused_inv_rope_fp8_quant( - o, - positions, - self.rotary_emb.cos_sin_cache, - n_groups=self.n_local_groups, - heads_per_group=self.n_local_heads // self.n_local_groups, - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - tma_aligned_scales=self._tma_aligned_scales, - ) - - wo_a_fp8 = self.wo_a.weight - wo_a_scale = self.wo_a.weight_scale_inv - - z = torch.empty( - (num_tokens, self.n_local_groups, self.o_lora_rank), - device=o.device, - dtype=torch.bfloat16, - ) - torch.ops.vllm.deepseek_v4_fp8_einsum( - o_fp8, - o_scale, - wo_a_fp8, - wo_a_scale, - z, - "bhr,hdr->bhd", - list(self._einsum_recipe), - ) - - return self.wo_b(z.flatten(1)) - - def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: - aux_streams = self.aux_stream_list - if aux_streams is not None: - assert len(aux_streams) >= 3 - aux_streams = aux_streams[:3] - - # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs - # on aux streams 0..2 when their owning module exists. ln_events[0] - # is the fan-out start event; ln_events[1..3] are per-aux done events. - # On ROCm, aux_streams is None and execute_in_parallel runs serially. - aux_fns: list[Callable[[], Any] | None] = [None, None, None] - - if self.compressor is not None: - # Local ref so the closure keeps a non-None type for mypy. - compressor = self.compressor - - def compressor_kv_score() -> torch.Tensor: - return torch.mm( - hidden_states, - compressor.fused_wkv_wgate.weight.T, - out_dtype=torch.float32, - ) - - aux_fns[0] = compressor_kv_score - - if self.indexer is not None: - indexer = self.indexer - - def indexer_weights_proj() -> torch.Tensor: - # ReplicatedLinear returns (output, bias); bias is None. - weights, _ = indexer.weights_proj(hidden_states) - return weights - - def indexer_compressor_kv_score() -> torch.Tensor: - return torch.mm( - hidden_states, - indexer.compressor.fused_wkv_wgate.weight.T, - out_dtype=torch.float32, - ) - - aux_fns[1] = indexer_weights_proj - aux_fns[2] = indexer_compressor_kv_score - - def fused_wqa_wkv() -> torch.Tensor: - # MergedColumnParallelLinear returns (output, bias); bias is None. - qr_kv, _ = self.fused_wqa_wkv(hidden_states) - return qr_kv - - qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( - fused_wqa_wkv, - aux_fns, - self.ln_events[0], - self.ln_events[1:4], - aux_streams, - enable=hidden_states.shape[0] - <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD, - ) - - return qr_kv, kv_score, indexer_kv_score, indexer_weights - - def attention_impl( - self, - hidden_states: torch.Tensor, - positions: torch.Tensor, - out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place - ) -> None: - # ── Blackwell (SM100+) path ────────────────────────────────── - # FlashMLA and fused CUDA kernels don't work on SM100. - # Use CSA/SDPA attention with pure PyTorch instead. - cap = current_platform.get_device_capability() - if cap is not None and cap.major >= 10: - self._attention_impl_blackwell(hidden_states, positions, out) - return - - # ── Original path (SM90 and below) ─────────────────────────── - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - - qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states) - ) - - qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) - qr, kv = fused_q_kv_rmsnorm( - qr, - kv, - self.q_norm.weight.data, - self.kv_norm.weight.data, - self.eps, - ) - - # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride - # on the default stream so q stays on its consumer stream (mla_attn - # downstream reads q on default). Indexer/compressor go on aux for - # overlap with default's GEMM + cache write. - if self.indexer is not None: - aux_stream = ( - self.aux_stream_list[0] if self.aux_stream_list is not None else None - ) - indexer = self.indexer - # Local ref so the closure keeps a non-None type for mypy. - assert self.compressor is not None - compressor = self.compressor - - def wq_b_kv_insert_and_compress() -> torch.Tensor: - q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - compressor(kv_score, positions, self.rotary_emb) - return q - - q, _ = maybe_execute_in_parallel( - wq_b_kv_insert_and_compress, - lambda: indexer( - hidden_states, - qr, - indexer_kv_score, - indexer_weights, - positions, - self.indexer_rotary_emb, - ), - self.ln_events[0], - self.ln_events[1], - aux_stream, - ) - elif self.compressor is not None: - # wq_b + kv_insert on default, compressor on aux. - aux_stream = ( - self.aux_stream_list[0] if self.aux_stream_list is not None else None - ) - compressor = self.compressor - - def wq_b_kv_insert() -> torch.Tensor: - q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - return q - - q, _ = maybe_execute_in_parallel( - wq_b_kv_insert, - lambda: compressor(kv_score, positions, self.rotary_emb), - self.ln_events[0], - self.ln_events[1], - aux_stream, - ) - else: - # SWA-only layer: no compressor, no overlap. - q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - - # Handle dummy run (no metadata). - if not isinstance(attn_metadata, dict): - # Reserve _forward_prefill's bf16-gather workspace; the dummy - # run returns before mla_attn runs, so without this the shared - # workspace locks below the real prefill size. - sub = self.mla_attn - swa_only = sub.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio - ) - M = N + sub.window_size + sub.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) - out.zero_() - return - - # Pad q to FlashMLA-required head count (64 or 128) - if self.n_local_heads < self.padded_heads: - pad_size = self.padded_heads - self.n_local_heads - q = F.pad(q, (0, 0, 0, pad_size), value=0.0) - - # MLA attention writes into the pre-allocated `out` buffer - # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q, kv, positions, output=out) - - def _attention_impl_blackwell( - self, - hidden_states: torch.Tensor, - positions: torch.Tensor, - out: torch.Tensor, - ) -> None: - """Blackwell (SM100+) attention: KV cache-based, no FlashMLA. - - FIXED: Now writes KV to the paged cache and reads from it during decode. - Supports SWA, CSA (C4A), and HCA (C128A) attention. - - Pipeline: - 1. Project Q and KV (same as original) - 2. Apply RoPE to Q (in-place) - 3. Write KV to SWA paged cache (RoPE + fp8 quantize + insert) - 4. Run compressor (Triton, works on Blackwell) → compressed KV cache - 5. Run indexer (Triton, works on Blackwell) → topk_indices - 6. SWA layers: full decode attention with KV cache - 7. CSA/HCA layers: sparse attention on compressed KV + SWA + sink merge - """ - from vllm.model_executor.layers.csa_attention import ( - fused_qnorm_rope_kv_insert_py, - blackwell_attention_kv_write, - blackwell_attention_decode, - blackwell_csa_decode_attention, - causal_prefill_attention, - csa_sparse_prefill_attention, - ) - - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - - # Debug: check input for NaN - import sys as _sys - _hs_nan = torch.isnan(hidden_states).any().item() - if _hs_nan: - print(f"[BLACKWELL] INPUT NaN: cr={self.compress_ratio}", file=_sys.stderr, flush=True) - - qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states) - ) - qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) - qr, kv = fused_q_kv_rmsnorm( - qr, kv, - self.q_norm.weight.data, - self.kv_norm.weight.data, - self.eps, - ) - - # wq_b - q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - - # Run compressor on default stream (it's Triton-based, works on Blackwell) - if self.compressor is not None: - self.compressor(kv_score, positions, self.rotary_emb) - - # Run indexer if present - if self.indexer is not None: - self.indexer( - hidden_states, qr, indexer_kv_score, indexer_weights, - positions, self.indexer_rotary_emb, - ) - - # Get metadata - if not isinstance(attn_metadata, dict): - fused_qnorm_rope_kv_insert_py( - q, kv, None, None, positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, self.eps, 0, - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - ) - out.zero_() - return - - from vllm.v1.attention.backends.mla.sparse_swa import ( - DeepseekSparseSWAMetadata, - ) - swa_metadata = cast( - "DeepseekSparseSWAMetadata | None", - attn_metadata.get(self.swa_cache_layer.prefix), - ) - if swa_metadata is None: - fused_qnorm_rope_kv_insert_py( - q, kv, None, None, positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, self.eps, 0, - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - ) - out.zero_() - return - - # Apply per-head RMS norm + RoPE on Q (in-place) - swa_kv_cache = self.swa_cache_layer.kv_cache - swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) - fused_qnorm_rope_kv_insert_py( - q, kv, swa_kv_cache_2d, - swa_metadata.slot_mapping, - positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, - self.eps, - swa_metadata.block_size, - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - ) - - # Split prefill and decode - num_decode_tokens = swa_metadata.num_decode_tokens - num_prefills = swa_metadata.num_prefill_tokens - swa_only = self.compress_ratio <= 1 - - # CRITICAL FIX: Write KV to paged cache (RoPE + fp8 quant + insert) - if not hasattr(self, '_swa_inv_scale_cache'): - max_slots = swa_kv_cache.shape[0] * swa_kv_cache.shape[1] - self._swa_inv_scale_cache = torch.zeros( - max_slots, 1, dtype=torch.bfloat16, device=kv.device, - ) - import sys - print(f"[BLACKWELL] swa_kv_cache shape: {swa_kv_cache.shape}, " - f"block_size: {swa_metadata.block_size}, " - f"num_decode_tokens: {num_decode_tokens}, " - f"num_prefills: {num_prefills}, " - f"compress_ratio: {self.compress_ratio}, " - f"slot_mapping shape: {swa_metadata.slot_mapping.shape}, " - f"positions shape: {positions.shape}, " - f"kv shape: {kv.shape}", file=sys.stderr, flush=True) - blackwell_attention_kv_write( - kv, positions, swa_kv_cache, self._swa_inv_scale_cache, - swa_metadata.slot_mapping, swa_metadata.block_size, - self.rotary_emb.cos_sin_cache, - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - ) - - # Get compressed KV cache and indexer metadata for CSA/HCA - flashmla_metadata = None - if not swa_only: - flashmla_metadata = cast( - "FlashMLASparseMetadata | None", - attn_metadata.get(self.prefix), - ) - - o = torch.zeros( - hidden_states.shape[0], self.n_local_heads, self.head_dim, - dtype=torch.bfloat16, device=hidden_states.device, - ) - - # ── Decode attention ────────────────────────────────────── - if num_decode_tokens > 0: - import sys - print(f"[BLACKWELL] DECODE: {num_decode_tokens} tokens, swa_only={swa_only}", file=sys.stderr, flush=True) - if swa_only: - # SWA-only layers: full decode attention with KV cache - q_decode = q[:num_decode_tokens] - pos_decode = positions[:num_decode_tokens] - for t in range(num_decode_tokens): - o[t] = blackwell_attention_decode( - q_decode[t:t+1], pos_decode[t:t+1], - swa_kv_cache, self._swa_inv_scale_cache, - swa_metadata.slot_mapping[t:t+1], - swa_metadata.block_size, - self.scale, - self.window_size, - swa_indices=swa_metadata.decode_swa_indices, - swa_lens=swa_metadata.decode_swa_lens, - decode_token_idx=t, - ).squeeze(0) - else: - # CSA/HCA layers: sparse attention + SWA + sink merge - o[:num_decode_tokens] = blackwell_csa_decode_attention( - q[:num_decode_tokens], - positions[:num_decode_tokens], - swa_kv_cache, - self._swa_inv_scale_cache, - swa_metadata, - flashmla_metadata, - self.mla_attn.kv_cache if not swa_only else None, - self.compress_ratio, - self.scale, - self.window_size, - self.nope_head_dim, - self.rope_head_dim, - self.rotary_emb.cos_sin_cache, - self.mla_attn.attn_sink, - self.mla_attn.max_model_len, - ) - - # ── Prefill attention ───────────────────────────────────── - if num_prefills > 0: - import sys - print(f"[BLACKWELL] PREFILL: {num_prefills} tokens, swa_only={swa_only}", file=sys.stderr, flush=True) - q_prefill = q[num_decode_tokens:] - kv_rope_prefill = self._apply_rope_kv( - kv[num_decode_tokens:], positions[num_decode_tokens:], - ) - # Debug: check attention inputs - import sys as _sys - _q_nan = torch.isnan(q_prefill).any().item() - _kv_nan = torch.isnan(kv_rope_prefill).any().item() - if _q_nan or _kv_nan: - print(f"[BLACKWELL] PREFILL INPUTS NaN: q_nan={_q_nan} kv_nan={_kv_nan} cr={self.compress_ratio}", file=_sys.stderr, flush=True) - if swa_only: - o[num_decode_tokens:] = causal_prefill_attention( - q_prefill, kv_rope_prefill, self.scale, - ) - else: - # CSA/HCA prefill: sparse + SWA (fallback to full causal for now) - o[num_decode_tokens:] = causal_prefill_attention( - q_prefill, kv_rope_prefill, self.scale, - ) - # Debug: check attention output - import sys as _sys - _amax = o[num_decode_tokens:].amax().item() - _nan = torch.isnan(o[num_decode_tokens:]).any().item() - _std = o[num_decode_tokens:].float().std().item() - if _amax > 100 or _nan or _std < 0.001: - print(f"[BLACKWELL] PREFILL CHECK: amax={_amax:.4f} NaN={_nan} std={_std:.6f} cr={self.compress_ratio}", file=_sys.stderr, flush=True) - - # Write into the output buffer - if self.n_local_heads < self.padded_heads: - out[:, :self.n_local_heads, :] = o - out[:, self.n_local_heads:, :] = 0 - else: - out.copy_(o) - - def _apply_rope_q(self, q, positions): - """Apply GPT-J RoPE to Q in-place (fallback when no SWA metadata).""" - half = self.rope_head_dim // 2 - cos_q = self.rotary_emb.cos_sin_cache[positions, :half].unsqueeze(1).to(q.dtype) - sin_q = self.rotary_emb.cos_sin_cache[positions, half:].unsqueeze(1).to(q.dtype) - q_rope = q[:, :, self.nope_head_dim:].clone() - q[:, :, self.nope_head_dim:][:, :, 0::2] = q_rope[:, :, 0::2] * cos_q - q_rope[:, :, 1::2] * sin_q - q[:, :, self.nope_head_dim:][:, :, 1::2] = q_rope[:, :, 0::2] * sin_q + q_rope[:, :, 1::2] * cos_q - - def _apply_rope_kv(self, kv, positions): - """Apply GPT-J RoPE to KV latent and return the result.""" - half = self.rope_head_dim // 2 - cos = self.rotary_emb.cos_sin_cache[positions, :half].to(kv.dtype) - sin = self.rotary_emb.cos_sin_cache[positions, half:2*half].to(kv.dtype) - # kv: (T, HD) — apply RoPE to the rope portion (after nope_dim) - kv_rope = kv[:, self.nope_head_dim:].clone() - even = kv_rope[:, 0::2] - odd = kv_rope[:, 1::2] - out = kv.clone() - out[:, self.nope_head_dim:][:, 0::2] = even * cos - odd * sin - out[:, self.nope_head_dim:][:, 1::2] = even * sin + odd * cos - return out - - def _fused_qnorm_rope_kv_insert( - self, - q: torch.Tensor, - kv: torch.Tensor, - positions: torch.Tensor, - attn_metadata: ( - dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None - ), - ) -> None: - if not isinstance(attn_metadata, dict): - return - - swa_metadata = cast( - "DeepseekSparseSWAMetadata | None", - attn_metadata.get(self.swa_cache_layer.prefix), - ) - assert swa_metadata is not None - - swa_kv_cache = self.swa_cache_layer.kv_cache - swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) - - # Horizontally fused: - # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE - # KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert - # kv is unchanged; mla_attn reads kv solely via swa_kv_cache. - torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - q, - kv, - swa_kv_cache_2d, - swa_metadata.slot_mapping, - positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, - self.eps, - swa_metadata.block_size, - ) - - -def _apply_inv_rope_bf16( - o: torch.Tensor, - positions: torch.Tensor, - cos_sin_cache: torch.Tensor, - nope_dim: int, - rope_dim: int, -) -> torch.Tensor: - """Apply inverse RoPE to attention output in BF16. - - Inverse RoPE is just RoPE with sin -> -sin. - Uses GPT-J style (interleaved) rotary embedding. - """ - if rope_dim == 0 or o.numel() == 0: - return o - half_rot = rope_dim // 2 - o_f32 = o.to(torch.float32) - cache = cos_sin_cache.index_select(0, positions.to(torch.long)) - cos = cache[:, :half_rot].to(torch.float32) - sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) - view_shape = (positions.shape[0], 1, half_rot) - cos = cos.view(view_shape) - sin = sin.view(view_shape) - rope = o_f32[..., nope_dim:] - y_even = rope[..., 0::2] - y_odd = rope[..., 1::2] - # Inverse: sin → -sin (swap signs on cross terms) - rope_out = torch.stack( - (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin), - dim=-1, - ).flatten(-2) - o_f32 = o_f32.clone() - o_f32[..., nope_dim:] = rope_out - return o_f32.to(o.dtype) - - -def deepseek_v4_attention( - hidden_states: torch.Tensor, - positions: torch.Tensor, - out: torch.Tensor, - layer_name: str, -) -> None: - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - self.attention_impl(hidden_states, positions, out) - - -def deepseek_v4_attention_fake( - hidden_states: torch.Tensor, - positions: torch.Tensor, - out: torch.Tensor, - layer_name: str, -) -> None: - return None - - -direct_register_custom_op( - op_name="deepseek_v4_attention", - op_func=deepseek_v4_attention, - mutates_args=["out"], - fake_impl=deepseek_v4_attention_fake, -) - - -def deepseek_v4_fp8_einsum( - a: torch.Tensor, - a_scale: torch.Tensor, - b: torch.Tensor, - b_scale: torch.Tensor, - out: torch.Tensor, - equation: str, - recipe: list[int], -) -> None: - fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) - - -def deepseek_v4_fp8_einsum_fake( - a: torch.Tensor, - a_scale: torch.Tensor, - b: torch.Tensor, - b_scale: torch.Tensor, - out: torch.Tensor, - equation: str, - recipe: list[int], -) -> None: - return None - - -direct_register_custom_op( - op_name="deepseek_v4_fp8_einsum", - op_func=deepseek_v4_fp8_einsum, - mutates_args=["out"], - fake_impl=deepseek_v4_fp8_einsum_fake, -) - - -class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): - # FlashMLA FP8 sparse only supports 64 or 128 heads - SUPPORTED_HEAD_COUNTS = (64, 128) - - def __init__( - self, - num_heads: int, - head_dim: int, - scale: float, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - q_lora_rank: int | None, - kv_lora_rank: int, - compress_ratio: int, - window_size: int, - head_bytes: int, - swa_cache_layer: DeepseekV4SWACache, - attn_sink: torch.Tensor, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - # Sparse MLA Args - indexer: object | None = None, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream: torch.cuda.Stream | None = None, - **extra_impl_args, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.num_kv_heads = 1 - self.head_dim = head_dim - self.scale = scale - self.window_size = window_size - self.head_bytes = head_bytes - self.compress_ratio = compress_ratio - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.nope_head_dim = qk_nope_head_dim - self.rope_head_dim = qk_rope_head_dim - self.indexer = indexer - self.topk_indices_buffer = topk_indices_buffer - - self.prefix = prefix # Alias for compatibility with compressor - - self.aux_stream = aux_stream - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] - - # Determine padded head count for FlashMLA - if num_heads not in self.SUPPORTED_HEAD_COUNTS: - if num_heads < 64: - self.padded_heads = 64 - elif num_heads < 128: - self.padded_heads = 128 - else: - raise ValueError( - f"DeepseekV4MLAAttention does not support {num_heads} heads. " - f"Supported: <= 128 (will be padded to 64 or 128)" - ) - else: - self.padded_heads = num_heads - - # Store attention sink - assert attn_sink is not None - self.attn_sink: torch.Tensor = attn_sink - # Store SWA cache - assert swa_cache_layer is not None - self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer - - # Get vllm config for cache setup - vllm_config = get_current_vllm_config() - self.max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens - ) - self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now. - kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" - - assert kv_cache_dtype.startswith("fp8"), ( - f"DeepseekV4 only supports fp8 kv-cache format for now, " - f"got {kv_cache_dtype}" - ) - # On Blackwell (SM100+), FlashMLA kernels don't work. - # We use our own CSA/SDPA attention path. - _is_blackwell = ( - current_platform.get_device_capability() is not None - and current_platform.get_device_capability().major >= 10 - ) - if not _is_blackwell: - assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), ( - "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now" - ) - # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format - # Automatically convert fp8 kv-cache format to "fp8_ds_mla" - # On Blackwell, we use our own attention path, so keep standard fp8 - if not _is_blackwell and ( - issubclass(self.get_attn_backend(), FlashMLASparseBackend) - and kv_cache_dtype.startswith("fp8") - and kv_cache_dtype != "fp8_ds_mla" - ): - assert cache_config is not None - cache_config.cache_dtype = "fp8_ds_mla" - kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") - - self.kv_cache_dtype = kv_cache_dtype - - # Register with compilation context for metadata lookup - compilation_config = vllm_config.compilation_config - if prefix and prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - if prefix: - compilation_config.static_forward_context[prefix] = self - - self.kv_cache = torch.tensor([]) - - def get_attn_backend(self) -> type[AttentionBackend]: - cap = current_platform.get_device_capability() - if cap is not None and cap.major >= 10: - # Blackwell: FlashMLA doesn't work. Use our CSA/SDPA path. - # Return the base class so KV cache setup doesn't force fp8_ds_mla. - from vllm.v1.attention.backends.mla.sparse_swa import ( - DeepseekSparseSWABackend, - ) - return DeepseekSparseSWABackend - if current_platform.is_rocm(): - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( - DeepseekV4ROCMAiterMLASparseBackend, - ) - return DeepseekV4ROCMAiterMLASparseBackend - return DeepseekV4FlashMLASparseBackend - - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: - if ( - self.compress_ratio <= 1 - ): # SWA part. Allocated separately as DeepseekV4SWACache. - return None - cap = current_platform.get_device_capability() - _is_blackwell = cap is not None and cap.major >= 10 - if _is_blackwell: - # Blackwell: no FlashMLA, use standard fp8_e4m3 KV cache - # No 576B FlashMLA alignment needed - return MLAAttentionSpec( - block_size=vllm_config.cache_config.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=torch.uint8, - compress_ratio=self.compress_ratio, - cache_dtype_str=self.kv_cache_dtype, # "fp8" (not fp8_ds_mla) - alignment=None, - model_version="deepseek_v4", - ) - return MLAAttentionSpec( - block_size=vllm_config.cache_config.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=torch.uint8, - compress_ratio=self.compress_ratio, - cache_dtype_str=self.kv_cache_dtype, - alignment=576, # NOTE: FlashMLA requires 576B alignment - model_version="deepseek_v4", - ) - - def forward( - self, - q: torch.Tensor, - kv: torch.Tensor, - positions: torch.Tensor, - output: torch.Tensor, - ) -> None: - assert output.shape == q.shape, ( - f"output buffer shape {output.shape} must match q shape {q.shape}" - ) - assert output.dtype == q.dtype, ( - f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" - ) - - if current_platform.is_rocm(): - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( - DeepseekV4ROCMAiterMLASparseImpl, - ) - - DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output) - return - - # Get SWA and indexer metadata from forward context - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - assert isinstance(attn_metadata, dict) - flashmla_metadata = cast( - FlashMLASparseMetadata | None, attn_metadata.get(self.prefix) - ) - swa_metadata = cast( - "DeepseekSparseSWAMetadata | None", - attn_metadata.get(self.swa_cache_layer.prefix), - ) - assert swa_metadata is not None - - swa_only = self.compress_ratio <= 1 - # SWA-only layers (compress_ratio <= 1) don't have their own KV cache - # allocation, so self.kv_cache may be empty after profiling cleanup. - self_kv_cache = self.kv_cache if not swa_only else None - swa_kv_cache = self.swa_cache_layer.kv_cache - - # Split prefill and decode - num_decodes = swa_metadata.num_decodes - num_prefills = swa_metadata.num_prefills - num_decode_tokens = swa_metadata.num_decode_tokens - - if num_prefills > 0: - self._forward_prefill( - q=q[num_decode_tokens:], - positions=positions[num_decode_tokens:], - compressed_k_cache=self_kv_cache, - swa_k_cache=swa_kv_cache, - output=output[num_decode_tokens:], - attn_metadata=flashmla_metadata, - swa_metadata=swa_metadata, - ) - if num_decodes > 0: - self._forward_decode( - q=q[:num_decode_tokens], - kv_cache=self_kv_cache, - swa_metadata=swa_metadata, - attn_metadata=flashmla_metadata, - swa_only=swa_only, - output=output[:num_decode_tokens], - ) - - def _forward_decode( - self, - q: torch.Tensor, - kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1 - swa_metadata: "DeepseekSparseSWAMetadata", - attn_metadata: FlashMLASparseMetadata | None, - swa_only: bool, - output: torch.Tensor, - ) -> None: - num_decodes = swa_metadata.num_decodes - num_decode_tokens = swa_metadata.num_decode_tokens - - topk_indices = None - topk_lens = None - if not swa_only: - assert attn_metadata is not None - assert swa_metadata.is_valid_token is not None - block_size = attn_metadata.block_size // self.compress_ratio - is_valid = swa_metadata.is_valid_token[:num_decode_tokens] - if self.compress_ratio == 4: - # C4A: local indices differ per layer (filled by Indexer). - assert self.topk_indices_buffer is not None - global_indices, topk_lens = compute_global_topk_indices_and_lens( - self.topk_indices_buffer[:num_decode_tokens], - swa_metadata.token_to_req_indices, - attn_metadata.block_table[:num_decodes], - block_size, - is_valid, - ) - topk_indices = global_indices.view(num_decode_tokens, 1, -1) - else: - # C128A: pre-computed during metadata build. - topk_indices = attn_metadata.c128a_global_decode_topk_indices - topk_lens = attn_metadata.c128a_decode_topk_lens - - swa_indices = swa_metadata.decode_swa_indices - swa_lens = swa_metadata.decode_swa_lens - - # We treat queries in the same seq as different queries - # and later we only attend by generated indices. - # q arrives pre-padded to self.padded_heads by the outer wrapper. - q = q.unsqueeze(1) - - # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes) - # Use unsqueeze to preserve strides (handles padded blocks correctly) - swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) - # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) - if kv_cache is not None: - kv_cache = kv_cache.unsqueeze(-2) - - # One FlashMLASchedMeta per layer type, shared across all same-type - # layers within this decode step. The first forward call per type - # triggers the in-kernel planner (allocating tile_scheduler_metadata - # and num_splits via PyTorch's graph-aware allocator so CUDA graph - # capture reuses the same addresses on replay); subsequent same-type - # layers see have_initialized=True and skip the planner. - if self.compress_ratio <= 1: - tile_metadata = swa_metadata.tile_sched_swaonly - elif self.compress_ratio == 4: - tile_metadata = swa_metadata.tile_sched_c4a - elif self.compress_ratio == 128: - tile_metadata = swa_metadata.tile_sched_c128a - else: - raise ValueError( - f"Unsupported compress_ratio={self.compress_ratio}; " - "expected 1, 4, or 128." - ) - assert tile_metadata is not None, ( - "swa_metadata missing tile_sched entry for " - f"compress_ratio={self.compress_ratio}; " - "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not " - "allocate one for this layer type." - ) - - out, _ = flash_mla_with_kvcache( - q=q, - k_cache=swa_cache, - block_table=None, - head_dim_v=512, - tile_scheduler_metadata=tile_metadata, - cache_seqlens=None, - is_fp8_kvcache=True, - indices=swa_indices, - topk_length=swa_lens, - softmax_scale=self.scale, - attn_sink=self.attn_sink, - extra_k_cache=kv_cache if not swa_only else None, - extra_indices_in_kvcache=topk_indices, - extra_topk_length=topk_lens, - out=output.unsqueeze(1), - ) - - def _forward_prefill( - self, - q: torch.Tensor, - positions: torch.Tensor, - compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1 - swa_k_cache: torch.Tensor, - output: torch.Tensor, - attn_metadata: FlashMLASparseMetadata | None, - swa_metadata: "DeepseekSparseSWAMetadata", - ) -> None: - swa_only = attn_metadata is None - - num_prefills = swa_metadata.num_prefills - num_prefill_tokens = swa_metadata.num_prefill_tokens - num_decodes = swa_metadata.num_decodes - num_decode_tokens = swa_metadata.num_decode_tokens - - # Use pre-computed prefill metadata. - seq_lens = swa_metadata.prefill_seq_lens - gather_lens = swa_metadata.prefill_gather_lens - assert seq_lens is not None - assert gather_lens is not None - - # Derive prefill-local token offsets from the full query_start_loc_cpu. - query_start_loc_cpu = swa_metadata.query_start_loc_cpu - query_start_loc = swa_metadata.query_start_loc - assert query_start_loc_cpu is not None - assert query_start_loc is not None - prefill_token_base = query_start_loc_cpu[num_decodes] - - if not swa_only: - if self.compress_ratio == 4: - assert self.topk_indices_buffer is not None - topk_indices = self.topk_indices_buffer[num_decode_tokens:] - topk_indices = topk_indices[:num_prefill_tokens] - else: - # C128A: pre-computed during metadata build. - assert attn_metadata is not None - topk_indices = attn_metadata.c128a_prefill_topk_indices - top_k = topk_indices.shape[-1] - # Compressed region must fit the full compressed pool (seq_len // - # compress_ratio), not just top_k. top_k bounds how many indices - # the indexer selects, not the pool size it indexes into. - N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio - else: - # NOTE(woosuk): topk_indices will not be used for SWA-only layers. - assert self.topk_indices_buffer is not None - topk_indices = self.topk_indices_buffer[num_decode_tokens:] - top_k = 0 - N = 0 - - M = N + self.window_size + self.max_num_batched_tokens - num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE - - workspace_manager = current_workspace_manager() - kv = workspace_manager.get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - )[0] - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * PREFILL_CHUNK_SIZE - chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) - chunk_size = chunk_end - chunk_start - if not swa_only: - # Gather compressed KV - assert attn_metadata is not None - block_table = attn_metadata.block_table[num_decodes:] - dequantize_and_gather_k_cache( - kv[:chunk_size], - compressed_k_cache, - seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio, - gather_lens=None, - block_table=block_table[chunk_start:chunk_end], - block_size=attn_metadata.block_size // self.compress_ratio, - offset=0, - ) - - # Gather SWA KV - swa_block_table = swa_metadata.block_table[num_decodes:] - dequantize_and_gather_k_cache( - kv[:chunk_size], - swa_k_cache, - seq_lens=seq_lens[chunk_start:chunk_end], - gather_lens=gather_lens[chunk_start:chunk_end], - block_table=swa_block_table[chunk_start:chunk_end], - block_size=swa_metadata.block_size, - offset=N, - ) - - # Combine the topk indices and SWA indices for gathered KV cache - query_start = ( - query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base - ) - query_end = ( - query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base - ) - - combined_indices, combined_lens = combine_topk_swa_indices( - topk_indices[query_start:query_end], - query_start_loc[ - num_decodes + chunk_start : num_decodes + chunk_end + 1 - ], - seq_lens[chunk_start:chunk_end], - gather_lens[chunk_start:chunk_end], - self.window_size, - self.compress_ratio, - top_k, - M, - N, - ) - flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], - ) - - -class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): - def __init__( - self, - head_dim: int, - dtype: torch.dtype, - prefix: str, - cache_config: CacheConfig, - compress_ratio: int = 1, - ): - super().__init__() - self.kv_cache = torch.tensor([]) - self.head_dim = head_dim - self.prefix = prefix - self.cache_config = cache_config - self.dtype = dtype - self.compress_ratio = compress_ratio - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - # head_dim already carries the fp8 scale padding - # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout. - return MLAAttentionSpec( - block_size=self.cache_config.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=self.dtype, - compress_ratio=self.compress_ratio, - # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with - # the indexer's compressor state cache. V3.2 keeps the legacy layout. - alignment=576, - ) - - def forward(self): ... - - def get_attn_backend(self) -> type[AttentionBackend]: - return DeepseekV4IndexerBackend - - -class DeepseekV4Indexer(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - config: DeepseekV2Config | DeepseekV3Config, - hidden_size: int, - q_lora_rank: int, - quant_config: QuantizationConfig | None, - cache_config: CacheConfig | None, - topk_indices_buffer: torch.Tensor | None, - compress_ratio: int = 1, - prefix: str = "", - ): - super().__init__() - self.vllm_config = vllm_config - self.config = config - self.quant_config = quant_config - # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] - self.topk_tokens = config.index_topk - self.n_head = config.index_n_heads # 64 - self.head_dim = config.index_head_dim # 128 - self.rope_dim = config.qk_rope_head_dim # 64 - self.q_lora_rank = q_lora_rank # 1536 - self.compress_ratio = compress_ratio - self.use_fp4_kv = self.vllm_config.attention_config.use_fp4_indexer_cache - logger.info_once( - "Using %s indexer cache for Lightning Indexer.", - "MXFP4" if self.use_fp4_kv else "FP8", - ) - - # no tensor parallel, just replicated - self.wq_b = ReplicatedLinear( - self.q_lora_rank, - self.head_dim * self.n_head, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wq_b", - ) - self.weights_proj = ReplicatedLinear( - hidden_size, - self.n_head, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.weights_proj", - ) - self.k_norm = LayerNorm(self.head_dim, eps=1e-6) - self.softmax_scale = self.head_dim**-0.5 - - self.scale_fmt = "ue8m0" - self.quant_block_size = 128 # TODO: get from config - self.topk_indices_buffer = topk_indices_buffer - - self.max_model_len = ( - vllm_config.model_config.max_model_len // self.compress_ratio - ) - self.prefix = prefix - - self.max_total_seq_len = ( - get_max_prefill_buffer_size(vllm_config) // self.compress_ratio - ) - - assert cache_config is not None, "Deepseek V4 indexer requires cache_config" - # NOTE(yifan): FP8 indxer cache use the same layout as V3.2: - # head_dim bytes = 128 fp8 + 4 fp32 scale = 132. - # For FP4 indexer cache, we still allocate the same amount of memory as FP8, - # but only use the first half of the memory. - k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4 - self.k_cache = DeepseekV4IndexerCache( - head_dim=k_cache_head_dim, - dtype=torch.uint8, - prefix=f"{prefix}.k_cache", - cache_config=cache_config, - compress_ratio=self.compress_ratio, - ) - self.compressor = DeepseekCompressor( - vllm_config=vllm_config, - compress_ratio=self.compress_ratio, - hidden_size=hidden_size, - head_dim=self.head_dim, - rotate=True, - prefix=f"{prefix}.compressor", - k_cache_prefix=self.k_cache.prefix, - use_fp4_cache=self.use_fp4_kv, - ) - - self.indexer_op = SparseAttnIndexer( - self.k_cache, - self.quant_block_size, - self.scale_fmt, - self.topk_tokens, - self.head_dim, - self.max_model_len, - self.max_total_seq_len, - self.topk_indices_buffer, - skip_k_cache_insert=True, - use_fp4_cache=self.use_fp4_kv, - ) - - def forward( - self, - hidden_states: torch.Tensor, - qr: torch.Tensor, - compressed_kv_score: torch.Tensor, - indexer_weights: torch.Tensor, - positions: torch.Tensor, - rotary_emb: nn.Module, - ) -> torch.Tensor: - # ReplicatedLinear returns (output, bias); bias is None. - q, _ = self.wq_b(qr) - q = q.view(-1, self.n_head, self.head_dim) - k = self.compressor(compressed_kv_score, positions, rotary_emb) - q_quant, weights = fused_indexer_q_rope_quant( - positions, - q, - rotary_emb.cos_sin_cache, - indexer_weights, - self.softmax_scale, - self.n_head**-0.5, - use_fp4=self.use_fp4_kv, - ) - return self.indexer_op(hidden_states, q_quant, k, weights) diff --git a/vllm/patches/fused_moe/experts/cutedsl_moe.py b/vllm/patches/fused_moe/experts/cutedsl_moe.py deleted file mode 100644 index a2b07ad1..00000000 --- a/vllm/patches/fused_moe/experts/cutedsl_moe.py +++ /dev/null @@ -1,308 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""CuTeDSL NVFP4 MoE experts for DeepSeek-V4. - -Integrates the CuTeDSL NVFP4 grouped GEMM kernel into vLLM's FusedMoE -modular kernel framework. This is the proper integration path — no -monkey-patching, no post-load surgery. - -The CuTeDSL kernel is a Python-based CUTLASS kernel compiled via MLIR → PTX. -It handles: - - L1 GEMM (gate + up projections) - - SiLU activation with optional swiglu_limit clamping - - L2 GEMM (down projection) - - All with NVFP4 (float8_e4m3fn block scales + float32 global scales) - -CUDA-graph-safe: all intermediate buffers pre-allocated, no CPU-GPU syncs, -no dynamic shapes. -""" - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.activation import MoEActivation -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEParallelConfig, - FusedMoEQuantConfig, -) -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, - kNvfp4Dynamic, - kNvfp4Static, -) -from vllm.platforms import current_platform - -from cutedsl.runner import CuTeDSLMoERunner - - -class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular): - """CuTeDSL NVFP4 MoE experts using the custom CuTeDSL grouped GEMM kernel. - - Uses Standard activation format (non-batched). Handles input quantization - internally (expects_unquantized_inputs=True). - - Supports expert parallelism: remaps global→local expert IDs. - """ - - def __init__( - self, - moe_config: FusedMoEConfig, - quant_config: FusedMoEQuantConfig, - ): - super().__init__( - moe_config=moe_config, - quant_config=quant_config, - ) - assert quant_config.quant_dtype == "nvfp4", ( - "CuTeDSL MoE only supports nvfp4 quantization, " - f"got {quant_config.quant_dtype}" - ) - self.out_dtype = moe_config.in_dtype - self.hidden_dim = moe_config.hidden_dim - self.intermediate_size_per_partition = ( - moe_config.intermediate_size_per_partition - ) - self.topk = moe_config.experts_per_token - self.local_num_experts = moe_config.num_local_experts - self.global_num_experts = moe_config.num_experts - self.ep_rank = moe_config.moe_parallel_config.ep_rank - self.local_expert_offset = self.ep_rank * self.local_num_experts - # max_num_tokens from scheduler config (for buffer pre-allocation) - self.max_num_tokens = getattr(moe_config, 'max_num_tokens', 8192) - - # swiglu_limit: read from the FusedMoE layer in process_weights_after_loading - self._swiglu_limit = None - - # Runner is created in process_weights_after_loading - self._runner: CuTeDSLMoERunner | None = None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """Convert NVFP4 MoE weights into CuTeDSL kernel format. - - Reads w13/w2 weight tensors from the FusedMoE layer, converts them - to the CuTeDSL runner's expected format, and creates the runner. - Also folds the activation global scale (input_scale) into the - weight global scale (weight_scale_2) as the runner's alpha. - """ - num_experts = layer.w13_weight.shape[0] - hidden_size = self.hidden_dim - intermediate_size = self.intermediate_size_per_partition - device = layer.w13_weight.device - - # NOTE: For the CuTeDSL kernel, we do NOT fold input_scale into - # weight_scale_2. The CuTeDSL runner uses weight global scale - # (weight_scale_2) and activation global scale separately. - # The activation global scale is computed via warmup before first inference. - # - # Also, convert_to_nvfp4_moe_kernel_format already inverted input_scale - # (1.0 / a13_scale) for the quant config. We undo that inversion here - # to get the original input_scale, then use it as initial activation gs. - if layer.w13_input_scale is not None and not isinstance(layer.w13_input_scale, float): - # input_scale was inverted in convert_to_nvfp4_moe_kernel_format - # Original: input_scale = amax / (6.0 * 448.0) - # Inverted: 1.0 / input_scale = 6.0 * 448.0 / amax - # We need the original for activation gs - w13_input_scale_orig = 1.0 / layer.w13_input_scale - else: - w13_input_scale_orig = None - if layer.w2_input_scale is not None and not isinstance(layer.w2_input_scale, float): - w2_input_scale_orig = 1.0 / layer.w2_input_scale - else: - w2_input_scale_orig = None - - # Extract weights from the layer — checkpoint format, no copies yet. - # w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused - # w2_weight: (E, hidden, intermediate//2) uint8 — down - # w13_weight_scale: (E, 2*intermediate, hidden//16) fp8 - # w2_weight_scale: (E, hidden, intermediate//16) fp8 - w13_uint8 = layer.w13_weight.data # (E, 2*inter, hidden//2) - w2_uint8 = layer.w2_weight.data # (E, hidden, intermediate//2) - w13_sf = layer.w13_weight_scale.data # (E, 2*inter, hidden//16) = (E, N, K_sf) - w2_sf = layer.w2_weight_scale.data # (E, hidden, intermediate//16) = (E, N, K_sf) - w13_gs = layer.w13_weight_scale_2.data # (E,) or (E, 2) - w2_gs = layer.w2_weight_scale_2.data # (E,) or (E, 2) - - # View as fp4 — byte-preserving, NO copy - l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed) - l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed) - - # Ensure scales are float8_e4m3fn (no copy if already correct dtype) - if w13_sf.dtype != torch.float8_e4m3fn: - w13_sf = w13_sf.to(torch.float8_e4m3fn) - if w2_sf.dtype != torch.float8_e4m3fn: - w2_sf = w2_sf.to(torch.float8_e4m3fn) - - # Global scales - l1_gs_list = w13_gs.tolist() - l2_gs_list = w2_gs.tolist() - - # Free original weight tensors IMMEDIATELY. - # We have views into the same memory (l1_fp4, l2_fp4), but the runner - # will create its own copies in _ensure_stacked. Free the layer refs - # now so the memory can be reclaimed when the views are no longer held. - # NOTE: The modular kernel framework reads w1.shape[0] in its outer - # apply() before delegating to our expert impl, so we can't set the - # weights to None. Instead, replace with a shape-preserving dummy on CPU - # to free GPU memory while keeping the shape metadata accessible. - # Free the large weight tensors — they're now in the runner. - # Keep the scale tensors (small) because the framework's warmup - # and quant config construction needs them. - layer.w13_weight = torch.nn.Parameter(torch.empty( - num_experts, 2 * intermediate_size, hidden_size // 2, - device='cpu', dtype=torch.uint8), requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.empty( - num_experts, hidden_size, intermediate_size // 2, - device='cpu', dtype=torch.uint8), requires_grad=False) - - # Create the CuTeDSL runner - self._runner = CuTeDSLMoERunner( - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - max_num_tokens=self.max_num_tokens, - top_k=self.topk, - device=str(device), - experts_start_idx=self.local_expert_offset, - ) - # Pass stacked tensors in checkpoint format (E, N, K) — no copies needed - self._runner.prepare_weights_from_stacked( - l1_fp4, w13_sf, l1_gs_list, - l2_fp4, w2_sf, l2_gs_list, - ) - if self._swiglu_limit is not None: - self._runner.set_swiglu_limit(float(self._swiglu_limit)) - - # Read swiglu_limit from the FusedMoE layer (set by DeepseekV4MoE) - swiglu_limit = getattr(layer, 'swiglu_limit', None) - if swiglu_limit is not None: - self._swiglu_limit = swiglu_limit - self._runner.set_swiglu_limit(float(swiglu_limit)) - - # Set initial activation global scales from checkpoint input_scale. - # After undoing the inversion from convert_to_nvfp4_moe_kernel_format, - # w13_input_scale_orig = amax / (6.0 * 448.0), which IS the activation - # global scale that quantize_activation_nvfp4 expects. - # The warmup step (compute_activation_global_scales) will override - # this with an empirically computed value before the first inference. - if w13_input_scale_orig is not None: - # w13_input_scale_orig = amax / (6.0 * 448.0) = activation gs - # Mean across experts (they should be similar) - mean_l1_gs = float(w13_input_scale_orig.mean().item()) - if mean_l1_gs > 0: - self._runner._l1_activation_global_scale = mean_l1_gs - if w2_input_scale_orig is not None: - mean_l2_gs = float(w2_input_scale_orig.mean().item()) - if mean_l2_gs > 0: - self._runner._l2_activation_global_scale = mean_l2_gs - - # Note: activation global scale warmup must be done after - # process_weights_after_loading, before the first inference. - # This is handled by the model's load_weights or a separate warmup step. - - @property - def runner(self) -> CuTeDSLMoERunner | None: - return self._runner - - @staticmethod - def activation_format() -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - @staticmethod - def _supports_current_device() -> bool: - # CuTeDSL requires CUDA SM100 (Blackwell) - p = current_platform - return p.is_cuda() and p.is_device_capability_family(100) - - @staticmethod - def _supports_no_act_and_mul() -> bool: - return False - - @staticmethod - def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - SUPPORTED_W_A = [ - (kNvfp4Static, kNvfp4Dynamic), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - - @staticmethod - def _supports_activation(activation: MoEActivation) -> bool: - # We handle SiLU + swiglu_limit internally - return activation == MoEActivation.SILU - - @staticmethod - def _supports_parallel_config( - moe_parallel_config: FusedMoEParallelConfig, - ) -> bool: - return True - - def supports_expert_map(self) -> bool: - return False - - @property - def expects_unquantized_inputs(self) -> bool: - # Our runner handles activation quantization internally - return True - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - return TopKWeightAndReduceNoOP() - - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - global_num_experts: int, - local_num_experts: int, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: MoEActivation, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - # Our runner manages its own workspace internally (pre-allocated buffers) - workspace1 = (0,) - workspace2 = (0,) - # The output of the full 2-stage MoE pipeline is hidden_dim. - # K comes from hidden_states.size(-1) (full BF16 dimension, not packed). - output = (M, self.hidden_dim) - return (workspace1, workspace2, output) - - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: MoEActivation, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - workspace13: torch.Tensor | None, - workspace2: torch.Tensor | None, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - apply_router_weight_on_input: bool | None, - ): - assert self._runner is not None, ( - "CuTeDSL runner not initialized. " - "Call process_weights_after_loading first." - ) - - # Our runner expects topk_ids as global expert IDs. - # The modular kernel framework may pass local IDs with expert_map. - # We handle remapping internally via experts_start_idx. - result = self._runner.run( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - ) - - # Copy result into output tensor - output.copy_(result) diff --git a/vllm/patches/fused_moe/oracle/nvfp4.py b/vllm/patches/fused_moe/oracle/nvfp4.py deleted file mode 100644 index e29fc746..00000000 --- a/vllm/patches/fused_moe/oracle/nvfp4.py +++ /dev/null @@ -1,535 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum - -import torch - -import vllm.envs as envs -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.config.kernel import MoEBackend -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.all2all_utils import ( - maybe_make_prepare_finalize, -) -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEQuantConfig, - nvfp4_moe_quant_config, - nvfp4_w4a16_moe_quant_config, -) -from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - prepare_nvfp4_moe_layer_for_fi_or_cutlass, - prepare_nvfp4_moe_layer_for_flashinfer_cutedsl, -) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, - get_flashinfer_moe_backend, -) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_nvfp4_moe_layer_for_marlin, -) -from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( - kE2M1ToFloat_handle, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, -) - -logger = init_logger(__name__) - - -class NvFp4MoeBackend(Enum): - FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" - FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" - FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" - FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED" - VLLM_CUTLASS = "VLLM_CUTLASS" - MARLIN = "MARLIN" - CUTEDSL = "CUTEDSL" - EMULATION = "EMULATION" - - -FLASHINFER_NVFP4_MOE_BACKENDS = [ - NvFp4MoeBackend.FLASHINFER_TRTLLM, - NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.FLASHINFER_CUTEDSL, - NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, -] - -CUTEDSL_NVFP4_MOE_BACKENDS = [ - NvFp4MoeBackend.CUTEDSL, -] - -fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = { - FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS, - FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM, - FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL, -} - - -def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: - # Checks whether `backend` supports quantizing with scaling factors - # of all experts in Expert Parallel Mode when all experts are not - # on the same rank. - - return backend in FLASHINFER_NVFP4_MOE_BACKENDS or backend in CUTEDSL_NVFP4_MOE_BACKENDS - - -def backend_to_kernel_cls( - backend: NvFp4MoeBackend, -) -> list[type[mk.FusedMoEExperts]]: - if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import ( - TrtLlmNvFp4ExpertsModular, - TrtLlmNvFp4ExpertsMonolithic, - ) - - # NOTE: prefer Monolthic > Modular, so return Monolithic first. - return [ - TrtLlmNvFp4ExpertsMonolithic, - TrtLlmNvFp4ExpertsModular, - ] - - elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import ( # noqa: E501 - FlashInferExperts, - ) - - return [FlashInferExperts] - - elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: - from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( # noqa: E501 - FlashInferCuteDSLExperts, - ) - - return [FlashInferCuteDSLExperts] - - elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED: - from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501 - FlashInferCuteDSLBatchedExperts, - ) - - return [FlashInferCuteDSLBatchedExperts] - - elif backend == NvFp4MoeBackend.CUTEDSL: - from vllm.model_executor.layers.fused_moe.experts.cutedsl_moe import ( # noqa: E501 - CuTeDSLMoEExperts, - ) - - return [CuTeDSLMoEExperts] - - elif backend == NvFp4MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( - CutlassExpertsFp4, - ) - - return [CutlassExpertsFp4] - - elif backend == NvFp4MoeBackend.MARLIN: - from vllm.model_executor.layers.fused_moe.experts.marlin_moe import ( - MarlinExperts, - ) - - return [MarlinExperts] - elif backend == NvFp4MoeBackend.EMULATION: - from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import ( - Nvfp4QuantizationEmulationTritonExperts, - ) - - return [Nvfp4QuantizationEmulationTritonExperts] - else: - raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") - - -def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend: - """Map user's MoEBackend to NvFp4MoeBackend.""" - mapping = { - "cutlass": NvFp4MoeBackend.VLLM_CUTLASS, - "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM, - "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS, - "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL, - "cutedsl": NvFp4MoeBackend.CUTEDSL, - "marlin": NvFp4MoeBackend.MARLIN, - "emulation": NvFp4MoeBackend.EMULATION, - } - if backend := mapping.get(runner_backend): - return backend - raise ValueError( - f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. " - f"Expected one of {list(mapping.keys())}." - ) - - -def select_nvfp4_moe_backend( - config: FusedMoEConfig, - weight_key: QuantKey | None, - activation_key: QuantKey | None, -) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: - """ - Select the primary NvFP4 MoE backend - Note: Shape-specific fallbacks may still occur at runtime. - """ - - # NOTE: the kernels are selected in the following order. - AVAILABLE_BACKENDS = [ - NvFp4MoeBackend.FLASHINFER_TRTLLM, - NvFp4MoeBackend.FLASHINFER_CUTEDSL, - NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, - NvFp4MoeBackend.CUTEDSL, - NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.VLLM_CUTLASS, - NvFp4MoeBackend.MARLIN, - NvFp4MoeBackend.EMULATION, - ] - - use_batched = config.moe_parallel_config.use_batched_activation_format - activation_format = ( - mk.FusedMoEActivationFormat.BatchedExperts - if use_batched - else mk.FusedMoEActivationFormat.Standard - ) - - def _make_log_backend(backend: NvFp4MoeBackend): - available_backend_strs = [b.value for b in AVAILABLE_BACKENDS] - return ( - f"Using '{backend.value}' NvFp4 MoE backend out " - f"of potential backends: {available_backend_strs}." - ) - - def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str: - if reason: - return ( - f"NvFp4 MoE backend '{backend.value}' does not support the " - f"deployment configuration since {reason}." - ) - else: - return ( - f"NvFp4 MoE backend '{backend.value}' does not support the " - "deployment configuration." - ) - - def _return_or_raise( - backend: NvFp4MoeBackend, - config: FusedMoEConfig, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - activation_format: mk.FusedMoEActivationFormat, - ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: - for k_cls in backend_to_kernel_cls(backend): - supported, reason = k_cls.is_supported_config( - k_cls, config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, k_cls - - raise ValueError(_make_log_unsupported(backend, reason)) - - # Handle explicit moe_backend from user. - runner_backend = config.moe_backend - if runner_backend != "auto": - requested_backend = map_nvfp4_backend(runner_backend) - # For batched activation format, use batched variant if available. - if ( - activation_format == mk.FusedMoEActivationFormat.BatchedExperts - and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL - ): - requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED - return _return_or_raise( - requested_backend, config, weight_key, activation_key, activation_format - ) - - if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): - if not envs.VLLM_USE_FLASHINFER_MOE_FP4: - # If the user rejects FlashInfer remove those backends. - for b in FLASHINFER_NVFP4_MOE_BACKENDS: - AVAILABLE_BACKENDS.remove(b) - - elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): - # If user is explicit about backend, validate it. - backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] - return _return_or_raise( - backend, config, weight_key, activation_key, activation_format - ) - else: - # If the user is not explicit about the backend, try each. - for backend in FLASHINFER_NVFP4_MOE_BACKENDS: - for k_cls in backend_to_kernel_cls(backend): - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, k_cls - else: - logger.debug_once(_make_log_unsupported(backend, reason)) - - raise NotImplementedError( - "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " - "FlashInfer NVFP4 MoE backend supports the configuration." - ) - - if envs.VLLM_TEST_FORCE_FP8_MARLIN: - backend = NvFp4MoeBackend.MARLIN - return _return_or_raise( - backend, config, weight_key, activation_key, activation_format - ) - - # Select kernels in order of backend. - for backend in AVAILABLE_BACKENDS: - for k_cls in backend_to_kernel_cls(backend): - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, k_cls - else: - logger.debug_once(_make_log_unsupported(backend, reason)) - - raise NotImplementedError( - "No NvFp4 MoE backend supports the deployment configuration." - ) - - -def convert_to_nvfp4_moe_kernel_format( - nvfp4_backend: NvFp4MoeBackend, - layer: torch.nn.Module, - w13: torch.Tensor, - w13_scale: torch.Tensor, - w13_scale_2: torch.Tensor, - a13_scale: torch.Tensor | None, - w2: torch.Tensor, - w2_scale: torch.Tensor, - w2_scale_2: torch.Tensor, - a2_scale: torch.Tensor | None, - is_act_and_mul: bool, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - if nvfp4_backend == NvFp4MoeBackend.CUTEDSL: - # CuTeDSL kernel handles weight conversion in its own - # process_weights_after_loading. Pass through raw weights. - # Compute inverse activation scales for the quant config. - if a13_scale is not None: - a13_scale = 1.0 / a13_scale - if a2_scale is not None: - a2_scale = 1.0 / a2_scale - elif nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: - ( - w13, - w13_scale, - w13_scale_2, - a13_scale, - w2, - w2_scale, - w2_scale_2, - a2_scale, - ) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl( - layer=layer, - w13=w13, - w13_scale=w13_scale, - w13_scale_2=w13_scale_2, - a13_scale=a13_scale, - w2=w2, - w2_scale=w2_scale, - w2_scale_2=w2_scale_2, - a2_scale=a2_scale, - ) - elif ( - nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS - or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS - ): - ( - w13, - w13_scale, - w13_scale_2, - a13_scale, - w2, - w2_scale, - w2_scale_2, - a2_scale, - ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass( - backend=nvfp4_backend, - layer=layer, - w13=w13, - w13_scale=w13_scale, - w13_scale_2=w13_scale_2, - a13_scale=a13_scale, - w2=w2, - w2_scale=w2_scale, - w2_scale_2=w2_scale_2, - a2_scale=a2_scale, - is_act_and_mul=is_act_and_mul, - ) - elif nvfp4_backend == NvFp4MoeBackend.MARLIN: - a13_scale = None - a2_scale = None - ( - w13, - w13_scale, - w13_scale_2, - w2, - w2_scale, - w2_scale_2, - ) = prepare_nvfp4_moe_layer_for_marlin( - layer=layer, - w13=w13, - w13_scale=w13_scale, - w13_scale_2=w13_scale_2, - w2=w2, - w2_scale=w2_scale, - w2_scale_2=w2_scale_2, - is_act_and_mul=is_act_and_mul, - ) - elif nvfp4_backend == NvFp4MoeBackend.EMULATION: - # Move the E2M1 lookup table to the device now, because - # `.to(device)` is not allowed during CUDA graph capture. - kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(w13.device) - - if a13_scale is None or a2_scale is None: - raise ValueError( - "Activation global scales should not be None, got" - f" a13_scale={a13_scale}, a2_scale={a2_scale}" - ) - - if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1: - logger.warning_once( - "In NVFP4 linear, the activation global scale for inputs are different" - " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using" - " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()." - ) - - # 1. We take the max following e.g. quantization/utils/flashinfer_fp4_moe.py. - # 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant - # use the inverse scale directly (large global scale). - # NOTE: Before this point, `a13_scale` and `a2_scale` are such that: - # `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`, - # and `global_scale[expert_id]` are small (~1e-4). - # Taking the largest global scale likely results in overflowing the FP8 range - # for other experts - other selection strategies may be used. - a13_scale = 1.0 / a13_scale.max().to(torch.float32) - a2_scale = 1.0 / a2_scale.max().to(torch.float32) - else: - raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}") - - return ( - w13, - w13_scale, - w13_scale_2, - a13_scale, - w2, - w2_scale, - w2_scale_2, - a2_scale, - ) - - -def make_nvfp4_moe_quant_config( - backend: NvFp4MoeBackend, - w13_scale: torch.Tensor, - w2_scale: torch.Tensor, - w13_scale_2: torch.Tensor, - w2_scale_2: torch.Tensor, - a13_scale: torch.Tensor, - a2_scale: torch.Tensor, -) -> FusedMoEQuantConfig: - if backend == NvFp4MoeBackend.MARLIN: - return nvfp4_w4a16_moe_quant_config( - g1_alphas=w13_scale_2, - g2_alphas=w2_scale_2, - w1_scale=w13_scale, - w2_scale=w2_scale, - ) - elif backend == NvFp4MoeBackend.EMULATION: - return nvfp4_moe_quant_config( - g1_alphas=w13_scale_2, - g2_alphas=w2_scale_2, - a1_gscale=a13_scale, - a2_gscale=a2_scale, - w1_scale=w13_scale, - w2_scale=w2_scale, - ) - - # Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas. - # The expert's process_weights_after_loading will fuse activation - # scales in-place. Since the quant config references the same tensor - # as the registered parameter, EPLB rearrangement stays in sync. - return nvfp4_moe_quant_config( - g1_alphas=w13_scale_2, - g2_alphas=w2_scale_2, - a1_gscale=(1.0 / a13_scale), - a2_gscale=(1.0 / a2_scale), - w1_scale=w13_scale, - w2_scale=w2_scale, - # NOTE(rob): this is a hack until the MoE kernels - # create their own quant configs. TRTLLM kernel - # does not accept swizzled input quant scales. - is_scale_swizzled=( - backend - not in ( - NvFp4MoeBackend.FLASHINFER_TRTLLM, - NvFp4MoeBackend.FLASHINFER_CUTEDSL, - NvFp4MoeBackend.CUTEDSL, - ) - ), - ) - - -def make_nvfp4_moe_kernel( - moe_quant_config: FusedMoEQuantConfig, - moe_config: FusedMoEConfig, - experts_cls: type[mk.FusedMoEExperts], - routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, -) -> mk.FusedMoEKernel: - # Create Prepare/Finalize. - prepare_finalize = maybe_make_prepare_finalize( - moe=moe_config, - quant_config=moe_quant_config, - routing_tables=routing_tables, - allow_new_interface=True, - use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic), - ) - assert prepare_finalize is not None - - logger.info_once("Using %s", prepare_finalize.__class__.__name__) - - # Create Experts. - if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: - max_num_tokens = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens is not None - experts = experts_cls( - moe_config=moe_config, - quant_config=moe_quant_config, - max_num_tokens=max_num_tokens, - num_dispatchers=prepare_finalize.num_dispatchers(), - ) - else: - experts = experts_cls( - moe_config=moe_config, - quant_config=moe_quant_config, - ) - - kernel = mk.FusedMoEKernel( - prepare_finalize, - experts, - inplace=False, - ) - - # TODO(rob): update inplace logic to be part of the kernel. - return kernel diff --git a/vllm/patches/kernel.py b/vllm/patches/kernel.py deleted file mode 100644 index dc862e73..00000000 --- a/vllm/patches/kernel.py +++ /dev/null @@ -1,216 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib -from collections.abc import Callable -from dataclasses import asdict, fields -from typing import TYPE_CHECKING, Any, Literal - -from pydantic import Field, field_validator - -from vllm.config.utils import config, get_hash_factors, hash_factors -from vllm.logger import init_logger - -if TYPE_CHECKING: - from vllm.config import VllmConfig - -logger = init_logger(__name__) - - -@config -class IrOpPriorityConfig: - """ - Configuration for vLLM IR op priority for dispatching/lowering during the - forward pass. Each member is a list of strings, which will be passed to - vllm.ir.ops..set_priority() for the duration of the forward pass. - A single comma-separated string is accepted as well, - - If specified manually, platform defaults will be appended to the lists. - See KernelConfig.set_platform_defaults(). - """ - - rms_norm: list[str] = Field(default_factory=list) - """Priority list for vllm.ir.ops.rms_norm""" - - fused_add_rms_norm: list[str] = Field(default_factory=list) - """Priority list for vllm.ir.ops.fused_add_rms_norm""" - - def compute_hash(self) -> str: - """ - Produces a hash unique to the pass configuration. - Any new fields that affect compilation should be added to the hash. - Any future fields that don't affect compilation should be excluded. - - Also, manually add IR op impl UUIDs to make sure they affect the compile cache. - """ - factors = get_hash_factors(self, set()) - - # Implementations are hidden from Dynamo, - # so they don't show up in the traced files list. - from vllm.ir.op import IrOp - - assert "_impls" not in factors - factors["_impls"] = { - name: { - provider: IrOp.registry[name].impls[provider].uuid() for provider in p - } - for name, p in asdict(self).items() # type: ignore[call-overload] - } - - return hash_factors(factors) - - @field_validator("*", mode="before") - @classmethod - def _to_list_str(cls, value: str | list[str]): - if isinstance(value, str): - value = value.replace(" ", "").split(",") - - assert all(isinstance(v, str) for v in value) - return value - - @contextlib.contextmanager - def set_priority(self): - """ - Context manager to set the IR op priority for all op members. - It also imports IR kernel implementations for the current platform - to ensure all implementations are made available. - """ - from vllm.ir.op import IrOp - from vllm.platforms import current_platform - - current_platform.import_ir_kernels() - - with contextlib.ExitStack() as stack: - for field in fields(self): # type: ignore[arg-type] - op_priority = getattr(self, field.name) - assert op_priority is not None, ( - f"IR op priority for {field.name} must be set" - ) - logger.debug( - "Setting IR op priority for %s to %s", field.name, op_priority - ) - ir_op = IrOp.registry[field.name] - stack.enter_context(ir_op.set_priority(op_priority)) - - yield - - @classmethod - def with_default( - cls, default: list[str], /, **kwargs: list[str] - ) -> "IrOpPriorityConfig": - """ - A helper to create an IrOpPriorityConfig where fields not specified in kwargs - use the given default list. - """ - for field in fields(cls): # type: ignore[arg-type] - if field.name not in kwargs: - kwargs[field.name] = list(default) - - return cls(**kwargs) - - -MoEBackend = Literal[ - "auto", - "triton", - "deep_gemm", - "deep_gemm_mega_moe", - "cutlass", - "flashinfer_trtllm", - "flashinfer_cutlass", - "flashinfer_cutedsl", - "marlin", - "humming", - "triton_unfused", - "aiter", - "cutedsl", - "emulation", -] - - -@config -class KernelConfig: - """Configuration for kernel selection and warmup behavior.""" - - ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig) - """ - vLLM IR op priority for dispatching/lowering during the forward pass. - Platform defaults appended automatically during VllmConfig.__post_init__. - """ - - enable_flashinfer_autotune: bool = None # type: ignore[assignment] - """If True, run FlashInfer autotuning during kernel warmup.""" - - moe_backend: MoEBackend = "auto" - """Backend for MoE expert computation kernels. Available options: - - - "auto": Automatically select the best backend based on model and hardware - - "triton": Use Triton-based fused MoE kernels - - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only) - - "deep_gemm_mega_moe": Use DeepGEMM mega MoE kernels - - "cutlass": Use vLLM CUTLASS kernels - - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels - - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels - - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only) - - "marlin": Use Marlin kernels (weight-only quantization) - - "humming": Use Humming Mixed Precision kernels - - "triton_unfused": Use Triton unfused MoE kernels - - "aiter": Use AMD AITer kernels (ROCm only) - - "emulation": use BF16/FP16 GEMM, dequantizing weights and - running QDQ on activations. - """ - - @field_validator("moe_backend", mode="before") - @classmethod - def _normalize_moe_backend(cls, value: Any) -> Any: - if isinstance(value, str): - return value.lower().replace("-", "_") - return value - - def compute_hash(self) -> str: - """ - Produces a hash unique to the pass configuration. - Any new fields that affect compilation should be added to the hash. - Any future fields that don't affect compilation should be excluded. - """ - ignored_factors = { - "enable_flashinfer_autotune", - "ir_op_priority", # handled separately below - } - factors = get_hash_factors(self, ignored_factors) - factors["ir_op_priority"] = self.ir_op_priority.compute_hash() - return hash_factors(factors) - - @field_validator("enable_flashinfer_autotune", mode="wrap") - @classmethod - def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: - """Skip validation if the value is `None` when initialization is delayed.""" - if value is None: - return value - return handler(value) - - def set_platform_defaults(self, vllm_config: "VllmConfig") -> None: - """Set platform-specific defaults for the kernel config.""" - from vllm.platforms import current_platform - - platform_op_priority = current_platform.get_default_ir_op_priority(vllm_config) - logger.debug( - "Setting platform-specific IR op priority defaults: %s, user-defined: %s", - platform_op_priority, - self.ir_op_priority, - ) - for op_name, op_priority in asdict(platform_op_priority).items(): - current_op_priority: list[str] = getattr(self.ir_op_priority, op_name) - if current_op_priority is None: - setattr(self.ir_op_priority, op_name, op_priority) - else: - # Append platform-specific priorities - # Must be idempotent because vllm_config.set_platform_defaults() may be - # called multiple times (due to VllmConfig.__post_init__ manual call). - unique_op_priority = [ - op for op in op_priority if op not in current_op_priority - ] - current_op_priority.extend(unique_op_priority) - - logger.info( - "Final IR op priority after setting platform defaults: %s", - self.ir_op_priority, - ) diff --git a/vllm/patches/layers/csa_attention.py b/vllm/patches/layers/csa_attention.py deleted file mode 100644 index 2147f336..00000000 --- a/vllm/patches/layers/csa_attention.py +++ /dev/null @@ -1,356 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -CSA/HCA attention for Blackwell (SM100+). - -Replaces vLLM's FlashMLA + fused CUDA kernels with our own KV cache-based -attention pipeline. The previous version used raw KV for attention (no cache), -which produced garbage during decode because the KV cache was never written. - -Key changes from the broken version: -1. fused_qnorm_rope_kv_insert_py NOW WRITES KV to the paged cache (fp8) -2. full_sdpa_attention is replaced with cache-aware attention -3. KV is quantized to fp8 with per-token scale, RoPE applied before caching - -Architecture: -- KV latent: (T, HD=512) single head, shared across 128 Q heads -- KV Cache: fp8_e4m3 paged cache with per-token inverse scale -- Attention: BF16 (NVFP4 too lossy for Q×K^T) -""" - -import torch -import torch.nn.functional as F - - -def apply_gptj_rope(x, cos, sin, nope_dim): - out = x.clone() - even = x[..., nope_dim:][..., 0::2] - odd = x[..., nope_dim:][..., 1::2] - out[..., nope_dim:][..., 0::2] = even * cos - odd * sin - out[..., nope_dim:][..., 1::2] = even * sin + odd * cos - return out - - -def apply_inv_gptj_rope(x, cos, sin, nope_dim): - out = x.clone() - even = x[..., nope_dim:][..., 0::2] - odd = x[..., nope_dim:][..., 1::2] - out[..., nope_dim:][..., 0::2] = even * cos + odd * sin - out[..., nope_dim:][..., 1::2] = -even * sin + odd * cos - return out - - -# ── KV Cache Operations ────────────────────────────────────────────── - -def kv_quantize_fp8(kv_bf16): - """BF16 KV → fp8_e4m3 with per-token inverse scale.""" - amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) - fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device) - scale = fp8_max / amax - kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn) - inv_scale = (amax / fp8_max).to(torch.bfloat16) - return kv_fp8, inv_scale - - -def kv_dequantize_fp8(kv_fp8, inv_scale): - """fp8 KV → BF16.""" - return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16) - - -def paged_kv_write(kv_data, slot_mapping, cache, block_size): - """Write KV into paged cache. - - kv_data: (T, D) tensor to write (fp8 or bf16) - slot_mapping: (T,) slot indices - cache: (num_blocks, block_size, D) cache tensor (may be uint8) - """ - # Handle dtype mismatch: cache is uint8, kv_data is fp8 - if cache.dtype == torch.uint8 and kv_data.dtype == torch.float8_e4m3fn: - kv_to_write = kv_data.view(torch.uint8) - else: - kv_to_write = kv_data - - # Vectorized write using advanced indexing - block_indices = slot_mapping // block_size - offsets = slot_mapping % block_size - # Clamp to valid range (safety) - valid = (block_indices < cache.shape[0]) & (offsets < cache.shape[1]) - if valid.all(): - cache[block_indices, offsets] = kv_to_write - else: - # Fall back to per-token for partial writes - for t in range(kv_data.shape[0]): - bi = block_indices[t].item() - oi = offsets[t].item() - if bi < cache.shape[0] and oi < cache.shape[1]: - cache[bi, oi] = kv_to_write[t] - - -def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim): - """Read KV from paged cache. Returns fp8 or uint8. - - Vectorized version — uses advanced indexing instead of Python for loop. - """ - device = cache.device - # Compute block indices and offsets - slots = slot_mapping # (num_tokens,) - block_indices = slots // block_size - offsets = slots % block_size - - # Advanced indexing: cache[block_indices, offsets] -> (num_tokens, head_dim) - kv = cache[block_indices, offsets] - - # If cache is uint8, reinterpret as fp8 - if cache.dtype == torch.uint8: - kv = kv.view(torch.float8_e4m3fn) - return kv - - -# ── Attention ───────────────────────────────────────────────────────── - -def causal_prefill_attention(q, kv, scale): - """Full causal self-attention for prefill. q: (T, NH, HD), kv: (T, HD).""" - T, NH, HD = q.shape - q_t = q.permute(1, 0, 2) # (NH, T, HD) - kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T, HD) - out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale) - return out.permute(1, 0, 2) # (T, NH, HD) - - -def decode_attention(q, kv, scale): - """Decode attention: 1 query vs N cached KVs. - - q: (1, NH, HD) — single decode token - kv: (N, HD) — all cached KV (already with RoPE) - """ - NH = q.shape[1] - HD = q.shape[2] - q_t = q.permute(1, 0, 2) # (NH, 1, HD) - kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, N, HD) - out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale) - return out.permute(1, 0, 2) # (1, NH, HD) - - -# ── Fused Q norm + RoPE + KV cache write ───────────────────────────── - -def fused_qnorm_rope_kv_insert_py( - q, # (T, num_heads, head_dim) — modified in-place - kv, # (T, head_dim) — not modified - swa_kv_cache_2d, # paged KV cache (2D view) - slot_mapping, - positions, - cos_sin_cache, - eps, - block_size, - nope_dim, - rope_dim, -) -> None: - """Pure PyTorch: RoPE on Q only. - - Q is already normed (by fused_q_kv_rmsnorm), so we only apply RoPE. - The original CUDA kernel also does KV cache insert, but we handle that - separately in blackwell_attention_kv_write. - """ - T = q.shape[0] - if T == 0: - return - - # GPT-J RoPE on Q only (Q is already normed) - half = rope_dim // 2 - cos_q = cos_sin_cache[positions, :half].unsqueeze(1).to(q.dtype) - sin_q = cos_sin_cache[positions, half:2*half].unsqueeze(1).to(q.dtype) - q_rope = q[:, :, nope_dim:].clone() - q[:, :, nope_dim:][:, :, 0::2] = q_rope[:, :, 0::2] * cos_q - q_rope[:, :, 1::2] * sin_q - q[:, :, nope_dim:][:, :, 1::2] = q_rope[:, :, 0::2] * sin_q + q_rope[:, :, 1::2] * cos_q - - -def blackwell_attention_kv_write( - kv, # (T, HD) kv_normed — NOT RoPE'd yet - positions, # (T,) absolute positions - swa_kv_cache, # (num_blocks, block_size, HD) fp8 paged cache - swa_inv_scale, # (max_slots, 1) per-token inv scale - slot_mapping, # (T,) slot indices - block_size, # tokens per block - cos_sin_cache, # (max_pos, rope_dim) for RoPE - nope_dim, # 448 - rope_dim, # 64 -) -> None: - """Write KV to paged cache: apply RoPE → fp8 quantize → insert. - - This is the function that vLLM's Blackwell path was missing. - Without this, the KV cache is never written, and decode attention - produces garbage because it can't access prior tokens' KV. - """ - T = kv.shape[0] - if T == 0: - return - - # Apply GPT-J RoPE to KV - half = rope_dim // 2 - cos = cos_sin_cache[positions, :half].to(kv.dtype) - sin = cos_sin_cache[positions, half:2 * half].to(kv.dtype) - # Must use original values for both even and odd before modifying - kv_rope_nope = kv[:, nope_dim:].clone() - even = kv_rope_nope[:, 0::2] - odd = kv_rope_nope[:, 1::2] - new_even = even * cos - odd * sin - new_odd = even * sin + odd * cos - kv_rope = kv.clone() - kv_rope[:, nope_dim:][:, 0::2] = new_even - kv_rope[:, nope_dim:][:, 1::2] = new_odd - - # Quantize to fp8 - kv_fp8, inv_scale = kv_quantize_fp8(kv_rope) - - # Write to paged cache - paged_kv_write(kv_fp8, slot_mapping, swa_kv_cache, block_size) - - # Write inv_scale to flat cache - for t in range(T): - slot = slot_mapping[t].item() - swa_inv_scale[slot] = inv_scale[t] - - -def blackwell_attention_decode( - q, # (1, NH, HD) single decode query with RoPE - positions, # (1,) absolute position - swa_kv_cache, # (num_blocks, block_size, HD) fp8 SWA cache (uint8) - swa_inv_scale, # (max_slots, 1) per-token inv scale - slot_mapping, # (1,) slot for the new token (already written) - block_size, # tokens per block - scale, # 1/sqrt(HD) - window_size, # 128 - swa_indices=None, # (num_decode_tokens, window_size) pre-computed paged indices - swa_lens=None, # (num_decode_tokens,) number of valid indices per token - decode_token_idx=0, # which decode token this is -) -> torch.Tensor: - """Decode attention: read cached KV using paged indices, attend. - - Uses pre-computed swa_indices from vLLM's metadata for correct paged access. - Returns: (1, NH, HD) attention output. - """ - NH = q.shape[1] - HD = q.shape[2] - device = q.device - - if swa_indices is not None and swa_lens is not None: - # Use pre-computed paged indices from vLLM - num_valid = swa_lens[decode_token_idx].item() - indices = swa_indices[decode_token_idx, :num_valid] - block_indices = indices // block_size - offsets = indices % block_size - kv_cached_raw = swa_kv_cache[block_indices, offsets] - if swa_kv_cache.dtype == torch.uint8: - kv_cached_raw = kv_cached_raw.view(torch.float8_e4m3fn) - # Dequantize with per-token inverse scale - inv_scales = swa_inv_scale[indices] - kv_cached = kv_dequantize_fp8(kv_cached_raw, inv_scales) - else: - # Fallback: sequential slot access - pos = positions[0].item() - all_slots = torch.arange(pos + 1, dtype=torch.int64, device=device) - kv_cached_raw = paged_kv_read(all_slots, swa_kv_cache, block_size, pos + 1, HD) - kv_inv_scales = swa_inv_scale[all_slots] - kv_cached = kv_dequantize_fp8(kv_cached_raw, kv_inv_scales) - window_start = max(0, pos - window_size + 1) - kv_cached = kv_cached[window_start:] - - q_t = q.permute(1, 0, 2) - kv_exp = kv_cached.unsqueeze(0).expand(NH, -1, -1) - out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale) - return out.permute(1, 0, 2) - - -def full_sdpa_attention( - q: torch.Tensor, # (T, NH, HD) with RoPE - kv: torch.Tensor, # (T, HD) KV latent - scale: float, -) -> torch.Tensor: - """Full causal self-attention for PREFILL only. - - DEPRECATED: Use causal_prefill_attention instead. - Kept for backward compatibility with the existing vLLM patch. - """ - return causal_prefill_attention(q, kv, scale) - - -# ── CSA/HCA Decode Attention ───────────────────────────────────────── - -def blackwell_csa_decode_attention( - q, # (num_decode_tokens, NH, HD) with RoPE - positions, # (num_decode_tokens,) - swa_kv_cache, # (num_blocks, block_size, D) fp8 SWA cache - swa_inv_scale, # (max_slots, 1) per-token inv scale - swa_metadata, # DeepseekSparseSWAMetadata - flashmla_metadata, # FlashMLASparseMetadata (for topk_indices) - compressed_kv_cache, # (num_blocks, block_size, D) compressed KV cache - compress_ratio, # 4 or 128 - scale, # 1/sqrt(HD) - window_size, # 128 - nope_dim, # 448 - rope_dim, # 64 - cos_sin_cache, # (max_pos, rope_dim) - attn_sink, # (NH,) sink weights - max_model_len, # max sequence length -) -> torch.Tensor: - """CSA/HCA decode: sparse attention on compressed KV + SWA + sink merge. - - For each decode token: - 1. Get topk_indices from the indexer (already computed) - 2. Gather compressed KV at topk positions - 3. Sparse attention on gathered KV - 4. SWA attention from paged cache - 5. Merge with sink weights - """ - num_tokens, NH, HD = q.shape - device = q.device - block_size = swa_metadata.block_size - - output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device) - - # Get topk indices from the indexer - num_decodes = swa_metadata.num_decodes - is_valid = swa_metadata.is_valid_token[:num_tokens] - - if compress_ratio == 4: - # C4A: topk indices from indexer buffer - # These are computed by the indexer during this forward pass - # For now, we need to get them from the metadata - # The indexer fills topk_indices_buffer - pass - # C128A: pre-computed in the metadata - - # For now, fall back to SWA-only for CSA/HCA decode - # The sparse component will be added once we verify the SWA path works - for t in range(num_tokens): - output[t] = blackwell_attention_decode( - q[t:t+1], positions[t:t+1], - swa_kv_cache, swa_inv_scale, - swa_metadata.slot_mapping[t:t+1], - block_size, scale, window_size, - ).squeeze(0) - - return output - - -def csa_sparse_prefill_attention( - q, # (num_prefills, NH, HD) with RoPE - kv_rope, # (num_prefills, HD) KV with RoPE - compressed_kv_cache, # compressed KV cache - flashmla_metadata, # FlashMLASparseMetadata - swa_metadata, # DeepseekSparseSWAMetadata - compress_ratio, # 4 or 128 - scale, # 1/sqrt(HD) - window_size, # 128 - nope_dim, # 448 - rope_dim, # 64 - cos_sin_cache, # (max_pos, rope_dim) - attn_sink, # (NH,) sink weights - max_model_len, # max sequence length -) -> torch.Tensor: - """CSA/HCA prefill: sparse + SWA attention. - - For now, falls back to full causal attention (which is correct - but not optimal for long sequences). - """ - # Full causal attention is always correct for prefill - return causal_prefill_attention(q, kv_rope, scale) diff --git a/vllm/patches/layers/deepseek_compressor.py b/vllm/patches/layers/deepseek_compressor.py deleted file mode 100644 index 4bf89163..00000000 --- a/vllm/patches/layers/deepseek_compressor.py +++ /dev/null @@ -1,455 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Any, ClassVar, cast - -import torch -from torch import nn - -from vllm.config import VllmConfig, get_current_vllm_config -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, -) -from vllm.platforms import current_platform - -# Check at module load time if we're on Blackwell -_IS_BLACKWELL = False -try: - _cap = current_platform.get_device_capability() - if _cap is not None and _cap.major >= 10: - _IS_BLACKWELL = True -except Exception: - pass -from vllm.triton_utils import tl, triton -from vllm.v1.attention.backend import ( - AttentionBackend, - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - MultipleOf, -) -from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import ( - _fused_kv_compress_norm_rope_insert_indexer_attn, - _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, - _fused_kv_compress_norm_rope_insert_sparse_attn, -) -from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( - MXFP4_BLOCK_SIZE, -) -from vllm.v1.kv_cache_interface import ( - KVCacheSpec, - MLAAttentionSpec, - SlidingWindowMLASpec, -) - - -class CompressorBackend(AttentionBackend): - def __init__(self): - super().__init__() - - @staticmethod - def get_name() -> str: - return "CompressorBackend" - - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [MultipleOf(1)] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [512, 1024] - - @staticmethod - def get_builder_cls() -> type["CompressorMetadataBuilder"]: - return CompressorMetadataBuilder - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - assert num_kv_heads == 1 - return (num_blocks, block_size, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - return (0, 1, 2, 3) - return (0, 1, 2) - - -@dataclass -class CompressorMetadata: - block_table: torch.Tensor - slot_mapping: torch.Tensor - block_size: int - - token_to_req_indices: torch.Tensor | None = None # [num_tokens] - - -class CompressorMetadataBuilder(AttentionMetadataBuilder): - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) - mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec) - self.block_size = mla_spec.block_size - - self.token_to_req_indices = torch.zeros( - self.vllm_config.scheduler_config.max_num_batched_tokens, - dtype=torch.int32, - device=self.device, - ) - - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> CompressorMetadata: - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - num_reqs = common_attn_metadata.num_reqs - query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - x = torch.repeat_interleave(torch.arange(num_reqs), query_lens).pin_memory() - token_to_req_indices = self.token_to_req_indices[: x.shape[0]] - token_to_req_indices.copy_(x, non_blocking=True) - return CompressorMetadata( - block_table=common_attn_metadata.block_table_tensor.clamp_(min=0), - slot_mapping=common_attn_metadata.slot_mapping, - block_size=self.block_size, - token_to_req_indices=token_to_req_indices, - ) - - -class CompressorStateCache(torch.nn.Module, AttentionLayerBase): - def __init__( - self, - state_dim: int, - dtype: torch.dtype, - compress_ratio: int, - prefix: str, - ): - super().__init__() - self.state_dim = state_dim - self.dtype = dtype - self.prefix = prefix - self.kv_cache = torch.tensor([]) - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - assert self.dtype == torch.float32 - assert compress_ratio in [4, 128] - coff = 1 + (compress_ratio == 4) - self.sliding_window = coff * compress_ratio - # Block size is constrained by tensor sharing between compressor states - # and KV blocks. Since compressor states share the same physical tensor - # as KV blocks, they must use the same page size. - # The KV block shape [256//4, head_dim] = [64, 584] determines: - # - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4 - # - C128 compressor block shape [8, 512*2*4] -> block_size = 8 - # TODO(yifan): make block size automatically determined and configurable. - if compress_ratio == 4: - self.block_size = 4 - elif compress_ratio == 128: - self.block_size = 8 - else: - raise ValueError(f"Invalid compress ratio: {compress_ratio}") - - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - return SlidingWindowMLASpec( # only has one vector instead of K + V - block_size=self.block_size, - num_kv_heads=1, - head_size=self.state_dim, - dtype=self.dtype, - sliding_window=self.sliding_window, - alignment=576, # NOTE: FlashMLA requires 576B alignment - ) - - def forward(self): ... - - def get_attn_backend(self) -> type[AttentionBackend]: - return CompressorBackend - - -class DeepseekCompressor(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - compress_ratio: int, - hidden_size: int, - head_dim: int, - rotate: bool = False, - prefix: str = "", - k_cache_prefix="", - use_fp4_cache: bool = False, - ): - super().__init__() - self.compress_ratio = compress_ratio - self.hidden_size = hidden_size - self.head_dim = head_dim - self.rotate = rotate - self.prefix = prefix - self.k_cache_prefix = k_cache_prefix - self.use_fp4_cache = use_fp4_cache - - config = vllm_config.model_config.hf_config - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = self.head_dim - self.rope_head_dim - self.rms_norm_eps = config.rms_norm_eps - self.device = current_platform.device_type - self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs - self.max_model_len = vllm_config.model_config.max_model_len - - self.overlap = compress_ratio == 4 - self.coff = 1 + self.overlap - - state_dtype = torch.float32 - self.ape = nn.Parameter( - torch.empty( - (compress_ratio, self.coff * self.head_dim), - dtype=state_dtype, - device=self.device, - ), - requires_grad=False, - ) - - quant_config = vllm_config.quant_config - - self.fused_wkv_wgate = MergedColumnParallelLinear( - self.hidden_size, - [self.coff * self.head_dim, self.coff * self.head_dim], - bias=False, - return_bias=False, - quant_config=quant_config, - disable_tp=True, - prefix=f"{prefix}.fused_wkv_wgate", - ) - self.norm = RMSNorm(self.head_dim, self.rms_norm_eps) - - self.state_cache = CompressorStateCache( - state_dim=2 * self.coff * self.head_dim, # kv_state + score_state - dtype=state_dtype, - compress_ratio=compress_ratio, - prefix=f"{prefix}.state_cache", - ) - - # Save reference to static_forward_context for forward-time KV cache lookup. - # get_current_vllm_config() is only available during __init__, not forward. - self._static_forward_context = ( - vllm_config.compilation_config.static_forward_context - ) - - if self.head_dim == 512: - assert not use_fp4_cache, ( - "MXFP4 cache is only supported for indexer (head=128)" - ) - self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn - self._quant_block = 64 - self._token_stride = self.nope_head_dim + self.rope_head_dim * 2 - self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad - self._num_warps = 4 - elif self.head_dim == 128: - if use_fp4_cache: - self._fused_kernel = ( - _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn - ) - self._quant_block = MXFP4_BLOCK_SIZE - self._token_stride = self.head_dim // 2 - self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE - else: - self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn - self._quant_block = 128 - self._token_stride = self.head_dim - self._scale_dim = 4 # single float32 scale - self._num_warps = 1 - else: - raise ValueError( - f"Unsupported head_dim for fused quant+cache: {self.head_dim}" - ) - - def forward( - self, - # [num_tokens, 2 * self.coff * self.head_dim] - kv_score: torch.Tensor, - # [num_tokens] - positions: torch.Tensor, - rotary_emb, - ) -> None: - # Each of shape [num_tokens, coff * self.head_dim] - # input bf16, output are fp32 - kv, score = kv_score.split( - [self.coff * self.head_dim, self.coff * self.head_dim], dim=-1 - ) - - # Get the metadata and handle dummy profiling run. - attn_metadata = get_forward_context().attn_metadata - if not isinstance(attn_metadata, dict): - return - - state_metadata = cast( - CompressorMetadata, attn_metadata[self.state_cache.prefix] - ) - token_to_req_indices = state_metadata.token_to_req_indices - slot_mapping = state_metadata.slot_mapping - num_actual = slot_mapping.shape[0] - block_table = state_metadata.block_table - block_size = state_metadata.block_size - - # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim - state_cache = self.state_cache.kv_cache - # kv_state stored in first half, score_state stored in second half - state_width = state_cache.shape[-1] // 2 - pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False} - - # Store the KV and score (with fused APE addition) in the state. - # NOTE: PDL is disabled — both this kernel and _fused_kernel below - # depend on preceding kernel outputs (kv/score from the cublas GEMM; - # state_cache from this kernel) but neither emits/waits on PDL grid - # dependency primitives, so launch_pdl=True caused a read-after-write - # race and non-deterministic output. - _save_partial_states_kernel[(num_actual,)]( - kv, - kv.stride(0), - score, - score.stride(0), - self.ape, - self.ape.stride(0), - positions, - state_cache, - state_cache.stride(0), - state_cache.stride(1), - slot_mapping, - block_size, - HEAD_SIZE=kv.shape[-1], - TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), - STATE_WIDTH=state_width, - COMPRESS_RATIO=self.compress_ratio, - **pdl_kwargs, - ) - - # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. - # RoPE requirements (kernel applies forward GPT-J style rotation): - # - is_neox_style=False (interleaved pairs, NOT split-half) - # - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos, - # second half sin (per-pair, length rope_head_dim // 2 each) - # - applied to LAST rope_head_dim elements of head_dim - # - position used: (positions // compress_ratio) * compress_ratio - - # On Blackwell (SM100+), skip the fused kernel because: - # 1. The fused kernel does attention using FlashMLA which doesn't work on SM100 - # 2. Our Blackwell attention path handles everything separately - # Instead, we just save the state (done above) and let the attention - # path handle compression + RoPE + cache write + attention. - if _IS_BLACKWELL: - # Blackwell: state is already saved, skip fused kernel - return - - cos_sin_cache = rotary_emb.cos_sin_cache - k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) - kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache - - self._fused_kernel[(num_actual,)]( - # state cache - state_cache, - state_cache.stride(0), - state_cache.stride(1), - # metadata - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_table.stride(0), - block_size, - # RMSNorm - self.norm.weight, - self.rms_norm_eps, - # RoPE - cos_sin_cache, - cos_sin_cache.stride(0), - # KV cache - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size (tokens per block) - # constexprs - HEAD_SIZE=self.head_dim, - TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), - STATE_WIDTH=state_width, - COMPRESS_RATIO=self.compress_ratio, - OVERLAP=self.overlap, - ROPE_HEAD_DIM=self.rope_head_dim, - FP8_MAX=448.0, - QUANT_BLOCK=self._quant_block, - TOKEN_STRIDE=self._token_stride, - SCALE_DIM=self._scale_dim, - KV_BLOCK_STRIDE=kv_cache.stride(0), - num_warps=self._num_warps, - **pdl_kwargs, - ) - - -@triton.jit -def _save_partial_states_kernel( - kv_ptr, - kv_stride, - score_ptr, - score_stride, - ape_ptr, - ape_stride, - positions_ptr, - state_cache_ptr, - state_cache_stride0, - state_cache_stride1, - slot_mapping_ptr, - block_size, - HEAD_SIZE: tl.constexpr, - TRITON_BLOCK_SIZE: tl.constexpr, - # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide. - STATE_WIDTH: tl.constexpr, - COMPRESS_RATIO: tl.constexpr, -): - token_idx = tl.program_id(0) - slot_id = tl.load(slot_mapping_ptr + token_idx) - - # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used - # by vLLM). During CUDA graph replay the batch may contain padding - # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an - # illegal memory access. - if slot_id < 0: - return - - block_idx = slot_id // block_size - pos_in_block = slot_id % block_size - base_ptr = ( - state_cache_ptr - + block_idx * state_cache_stride0 - + pos_in_block * state_cache_stride1 - ) - - block = tl.arange(0, TRITON_BLOCK_SIZE) - mask = block < HEAD_SIZE - - kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask) - tl.store(base_ptr + block, kv, mask=mask) - - # Fused: score += ape[position % compress_ratio] - position = tl.load(positions_ptr + token_idx) - ape_row = position % COMPRESS_RATIO - ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask) - score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask) - tl.store( - base_ptr + STATE_WIDTH + block, - score + ape, - mask=mask, - ) diff --git a/vllm/patches/layers/mhc.py b/vllm/patches/layers/mhc.py deleted file mode 100644 index e46e3fb5..00000000 --- a/vllm/patches/layers/mhc.py +++ /dev/null @@ -1,195 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Patched MHC layer — pure PyTorch, no TileLang. -# Replaces the TileLang-based mhc.py from the vLLM nightly image. -# The original imports tilelang at the top level and JIT-compiles kernels -# which don't work correctly on Blackwell (SM100). - -import torch -from vllm.utils.torch_utils import direct_register_custom_op - - -# ── Pure PyTorch MHC implementations ────────────────────────────────── - -def mhc_pre( - residual: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_eps: float, - hc_pre_eps: float, - hc_sinkhorn_eps: float, - hc_post_mult_value: float, - sinkhorn_repeat: int, - n_splits: int = 1, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert residual.dtype == torch.bfloat16 - assert fn.dtype == torch.float32 - assert hc_scale.dtype == torch.float32 - assert hc_base.dtype == torch.float32 - - hc_mult = residual.shape[-2] - hidden_size = residual.shape[-1] - hc_mult2 = hc_mult * hc_mult - hc_mult3 = hc_mult * 2 + hc_mult2 - hc_hidden_size = hc_mult * hidden_size - outer_shape = residual.shape[:-2] - - residual_flat = residual.view(-1, hc_mult, hidden_size) - num_tokens = residual_flat.shape[0] - - x = residual_flat.view(num_tokens, hc_hidden_size).to(torch.float32) - mixes = torch.matmul(x, fn.t()) - sqrsum = x.square().sum(dim=-1, keepdim=True) - mixes = mixes * torch.rsqrt(sqrsum / hc_hidden_size + rms_eps) - - pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult] - pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps - - post_logits = mixes[:, hc_mult:2 * hc_mult] * hc_scale[1] + hc_base[hc_mult:2 * hc_mult] - post_mix = torch.sigmoid(post_logits) * hc_post_mult_value - - comb_logits = (mixes[:, 2 * hc_mult:] - .view(num_tokens, hc_mult, hc_mult) - * hc_scale[2] - + hc_base[2 * hc_mult:].view(1, hc_mult, hc_mult)) - comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps - comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) - for _ in range(sinkhorn_repeat - 1): - comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps) - comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) - - layer_input = torch.sum( - pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1 - ).to(torch.bfloat16) - - return ( - post_mix.view(*outer_shape, hc_mult, 1), - comb_mix.view(*outer_shape, hc_mult, hc_mult), - layer_input.view(*outer_shape, hidden_size), - ) - - -def _mhc_pre_fake( - residual, fn, hc_scale, hc_base, rms_eps, hc_pre_eps, - hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits=1, -): - hc_mult = residual.shape[-2] - hidden_size = residual.shape[-1] - outer_shape = residual.shape[:-2] - return ( - torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device), - torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device), - torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device), - ) - - -def mhc_post( - x: torch.Tensor, - residual: torch.Tensor, - post_layer_mix: torch.Tensor, - comb_res_mix: torch.Tensor, -) -> torch.Tensor: - mixed_residual = torch.einsum( - "...ij,...ih->...jh", - comb_res_mix.to(torch.float32), - residual.to(torch.float32), - ) - post_term = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32) - return (mixed_residual + post_term).to(residual.dtype) - - -def _mhc_post_fake(x, residual, post_layer_mix, comb_res_mix): - return torch.empty_like(residual) - - -def mhc_fused_post_pre( - x: torch.Tensor, - residual: torch.Tensor, - post_layer_mix: torch.Tensor, - comb_res_mix: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_eps: float, - hc_pre_eps: float, - hc_sinkhorn_eps: float, - hc_post_mult_value: float, - sinkhorn_repeat: int, - n_splits: int = 1, - tile_n: int = 1, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - new_residual = mhc_post(x, residual, post_layer_mix, comb_res_mix) - post_mix, res_mix, layer_input = mhc_pre( - new_residual, fn, hc_scale, hc_base, - rms_eps, hc_pre_eps, hc_sinkhorn_eps, - hc_post_mult_value, sinkhorn_repeat, n_splits, - ) - return new_residual, post_mix, res_mix, layer_input - - -def _mhc_fused_post_pre_fake( - x, residual, post_layer_mix, comb_res_mix, fn, hc_scale, hc_base, - rms_eps, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value, - sinkhorn_repeat, n_splits=1, tile_n=1, -): - hc_mult = residual.shape[-2] - hidden_size = residual.shape[-1] - outer_shape = residual.shape[:-2] - return ( - torch.empty_like(residual), - torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device), - torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device), - torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device), - ) - - -def _hc_head_fused_kernel( - hs_flat: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - out: torch.Tensor, - hidden_size: int, - rms_eps: float, - hc_eps: float, - hc_mult: int, -) -> None: - if hs_flat.shape[0] == 0: - return - x_flat = hs_flat.reshape(hs_flat.shape[0], hc_mult * hidden_size).to(torch.float32) - mixes = torch.matmul(x_flat, fn.t()) - sqrsum = x_flat.square().sum(dim=-1, keepdim=True) - rsqrt = torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps) - pre_mix = torch.sigmoid(mixes * rsqrt * hc_scale[0] + hc_base) + hc_eps - result = torch.sum(pre_mix.unsqueeze(-1) * hs_flat.to(torch.float32), dim=1).to(out.dtype) - out.copy_(result) - - -# ── Register as torch custom ops ────────────────────────────────────── - -direct_register_custom_op( - op_name="mhc_pre", - op_func=mhc_pre, - mutates_args=[], - fake_impl=_mhc_pre_fake, -) - -direct_register_custom_op( - op_name="mhc_post", - op_func=mhc_post, - mutates_args=[], - fake_impl=_mhc_post_fake, -) - -direct_register_custom_op( - op_name="mhc_fused_post_pre", - op_func=mhc_fused_post_pre, - mutates_args=[], - fake_impl=_mhc_fused_post_pre_fake, -) - -direct_register_custom_op( - op_name="hc_head_fused_kernel", - op_func=_hc_head_fused_kernel, - mutates_args=["out"], -) diff --git a/vllm/patches/patch_compressor_cache.py b/vllm/patches/patch_compressor_cache.py deleted file mode 100644 index d278ae59..00000000 --- a/vllm/patches/patch_compressor_cache.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 -"""Patch DeepseekCompressor cache for Blackwell: remove FlashMLA alignment.""" -import sys - -def patch(path): - with open(path, 'r') as f: - content = f.read() - - if "CLAWMINE_PATCH_COMPRESSOR" in content: - print("Already patched, skipping") - return - - old = """ return SlidingWindowMLASpec( # only has one vector instead of K + V - block_size=self.block_size, - num_kv_heads=1, - head_size=self.state_dim, - dtype=self.dtype, - sliding_window=self.sliding_window, - alignment=576, # NOTE: FlashMLA requires 576B alignment - )""" - - new = """ # CLAWMINE_PATCH_COMPRESSOR: No FlashMLA alignment on Blackwell - from vllm.platforms import current_platform - _is_blackwell = ( - current_platform.get_device_capability() is not None - and current_platform.get_device_capability().major >= 10 - ) - return SlidingWindowMLASpec( # only has one vector instead of K + V - block_size=self.block_size, - num_kv_heads=1, - head_size=self.state_dim, - dtype=self.dtype, - sliding_window=self.sliding_window, - alignment=None if _is_blackwell else 576, - )""" - - if old not in content: - print("ERROR: Could not find the code to patch in " + path) - sys.exit(1) - - content = content.replace(old, new) - - with open(path, 'w') as f: - f.write(content) - print("Patched DeepseekCompressor for Blackwell") - -if __name__ == "__main__": - patch(sys.argv[1]) diff --git a/vllm/patches/patch_debug_layers.py b/vllm/patches/patch_debug_layers.py deleted file mode 100644 index 4af91637..00000000 --- a/vllm/patches/patch_debug_layers.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -"""Patch _allocate_kv_cache_tensors to print the layer name mismatch.""" -import sys - -def patch(path): - with open(path, 'r') as f: - content = f.read() - - if "CLAWMINE_DEBUG_LAYERS" in content: - print("Already patched, skipping") - return - - old = """ assert layer_names == set(kv_cache_raw_tensors.keys()), ( - "Some layers are not correctly initialized" - )""" - - new = """ # CLAWMINE_DEBUG_LAYERS: print mismatch instead of asserting - missing = layer_names - set(kv_cache_raw_tensors.keys()) - extra = set(kv_cache_raw_tensors.keys()) - layer_names - if missing or extra: - print(f"CLAWMINE DEBUG: missing layers ({len(missing)}): {sorted(missing)[:20]}") - print(f"CLAWMINE DEBUG: extra layers ({len(extra)}): {sorted(extra)[:20]}") - print(f"CLAWMINE DEBUG: expected ({len(layer_names)}), got ({len(kv_cache_raw_tensors.keys())})") - assert layer_names == set(kv_cache_raw_tensors.keys()), ( - "Some layers are not correctly initialized" - )""" - - if old not in content: - print("ERROR: Could not find the code to patch") - sys.exit(1) - - content = content.replace(old, new) - - with open(path, 'w') as f: - f.write(content) - print("Patched gpu_model_runner.py for debug layer names") - -if __name__ == "__main__": - patch(sys.argv[1]) diff --git a/vllm/patches/patch_indexer_cache.py b/vllm/patches/patch_indexer_cache.py deleted file mode 100644 index 4c001f28..00000000 --- a/vllm/patches/patch_indexer_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -"""Patch DeepseekV4IndexerCache on Blackwell: remove FlashMLA alignment. - -Same as patch_swa_cache but for the indexer cache class. -""" -import sys - -def patch(path): - with open(path, 'r') as f: - content = f.read() - - if "CLAWMINE_PATCH_INDEXER_CACHE" in content: - print("Already patched, skipping") - return - - # Patch the indexer cache's get_kv_cache_spec to remove FlashMLA alignment on Blackwell - old = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - # head_dim already carries the fp8 scale padding - # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout. - return MLAAttentionSpec( - block_size=self.cache_config.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=self.dtype, - compress_ratio=self.compress_ratio, - # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with - # the indexer's compressor state cache. V3.2 keeps the legacy layout. - alignment=576, - )""" - - new = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - # CLAWMINE_PATCH_INDEXER_CACHE: No FlashMLA alignment on Blackwell - from vllm.platforms import current_platform - _is_blackwell = ( - current_platform.get_device_capability() is not None - and current_platform.get_device_capability().major >= 10 - ) - return MLAAttentionSpec( - block_size=self.cache_config.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=self.dtype, - compress_ratio=self.compress_ratio, - alignment=None if _is_blackwell else 576, - )""" - - if old not in content: - print("ERROR: Could not find the code to patch in " + path) - sys.exit(1) - - content = content.replace(old, new) - - with open(path, 'w') as f: - f.write(content) - print("Patched DeepseekV4IndexerCache for Blackwell") - -if __name__ == "__main__": - patch(sys.argv[1]) diff --git a/vllm/patches/patch_kv_cache_utils.py b/vllm/patches/patch_kv_cache_utils.py deleted file mode 100644 index bd7faab6..00000000 --- a/vllm/patches/patch_kv_cache_utils.py +++ /dev/null @@ -1,135 +0,0 @@ -#!/usr/bin/env python3 -"""Patch vLLM kv_cache_utils.py to handle DeepseekV4 SWA page sizes. - -DeepseekV4 has three cache types: -- C128A (HCA): compress_ratio=128, very small page size -- C4A (CSA): compress_ratio=4, medium page size -- SWA: compress_ratio=1, large page size - -The upstream code assumes SWA page sizes <= MLA page sizes and pads -SWA pages to match MLA. This breaks when SWA pages are LARGER than -MLA pages (which is always the case for DeepseekV4). - -Our fix: when SWA pages exceed MLA pages, put them in their own -separate cache group without padding. -""" -import sys - -def patch(path): - with open(path, 'r') as f: - content = f.read() - - if "CLAWMINE_PATCH_KV_CACHE" in content: - print("Already patched, skipping") - return - - # The old code: asserts SWA pages <= MLA pages, then pads SWA to MLA - old = """ assert max(sm_page_sizes) <= max(all_page_sizes) - - # Unify page size by padding layers' page_size to the nearest larger page_size. - # Compute candidate (nearest larger page_size) for each unique page size. - size_to_candidate: dict[int, int] = {} - for ps in sm_page_sizes: - size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps) - # Pad and collect layer names per page size. - for layer_name, layer_spec in sm_spec.kv_cache_specs.items(): - current_size = layer_spec.page_size_bytes - candidate = size_to_candidate[current_size] - if current_size < candidate: - object.__setattr__(layer_spec, "page_size_padded", candidate) - layers_per_size[candidate].append(layer_name) - # NOTE(yifan): for now, inside a UniformKV group, each page_size should - # have the same number of layers. This also means we don't need to pad layers - # inside a partial-full layer tuple. - assert len(set(len(layers) for layers in layers_per_size.values())) == 1 - num_layers_per_size = len(next(iter(layers_per_size.values()))) - - # Split layers inside each UniformKV group for aligned #(layers). - # See `_get_kv_cache_groups_uniform_page_size` for more details. - num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples) - layer_tuples = list(zip(*layers_per_size.values())) - for i in range(num_tuple_groups): - group_layer_tuples = layer_tuples[i::num_tuple_groups] - # Flatten tuples and build dict for from_specs - group_layer_names = [ - name for layer_tuple in group_layer_tuples for name in layer_tuple - ] - group_layer_specs = { - name: sm_spec.kv_cache_specs[name] for name in group_layer_names - } - sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs) - assert sub_sm_spec is not None - swa_mla_groups.append( - KVCacheGroupSpec( - layer_names=group_layer_names, - kv_cache_spec=sub_sm_spec, - ) - )""" - - # The new code: handle both cases - new = """ # CLAWMINE_PATCH_KV_CACHE: Handle DeepseekV4 where SWA page sizes - # can be larger than MLA page sizes. Two cases: - # 1. All SWA pages <= some MLA page: original padding logic - # 2. Some SWA pages > all MLA pages: separate cache group, no padding - max_mla_page = max(all_page_sizes) - can_pad = max(sm_page_sizes) <= max_mla_page - - if can_pad: - # Original logic: pad SWA pages to nearest MLA page - size_to_candidate: dict[int, int] = {} - for ps in sm_page_sizes: - size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps) - for layer_name, layer_spec in sm_spec.kv_cache_specs.items(): - current_size = layer_spec.page_size_bytes - candidate = size_to_candidate[current_size] - if current_size < candidate: - object.__setattr__(layer_spec, "page_size_padded", candidate) - layers_per_size[candidate].append(layer_name) - assert len(set(len(layers) for layers in layers_per_size.values())) == 1 - num_layers_per_size = len(next(iter(layers_per_size.values()))) - num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples) - layer_tuples = list(zip(*layers_per_size.values())) - for i in range(num_tuple_groups): - group_layer_tuples = layer_tuples[i::num_tuple_groups] - group_layer_names = [ - name for layer_tuple in group_layer_tuples for name in layer_tuple - ] - group_layer_specs = { - name: sm_spec.kv_cache_specs[name] for name in group_layer_names - } - sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs) - assert sub_sm_spec is not None - swa_mla_groups.append( - KVCacheGroupSpec( - layer_names=group_layer_names, - kv_cache_spec=sub_sm_spec, - ) - ) - else: - # SWA pages are larger than MLA pages. - # Put each SWA layer in its own cache group (no padding needed). - # This is the DeepseekV4 Blackwell case where compress_ratio=1 - # layers have much larger pages than compressed layers. - for layer_name, layer_spec in sm_spec.kv_cache_specs.items(): - group_layer_specs = {layer_name: layer_spec} - sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs) - if sub_sm_spec is not None: - swa_mla_groups.append( - KVCacheGroupSpec( - layer_names=[layer_name], - kv_cache_spec=sub_sm_spec, - ) - )""" - - if old not in content: - print("ERROR: Could not find the code to patch") - sys.exit(1) - - content = content.replace(old, new) - - with open(path, 'w') as f: - f.write(content) - print("Patched kv_cache_utils.py for DeepseekV4 SWA page sizes") - -if __name__ == "__main__": - patch(sys.argv[1]) diff --git a/vllm/patches/patch_swa_cache.py b/vllm/patches/patch_swa_cache.py deleted file mode 100644 index 2fa6ff10..00000000 --- a/vllm/patches/patch_swa_cache.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -"""Patch DeepseekV4SWACache on Blackwell: remove FlashMLA alignment and model_version. - -On Blackwell (SM100+), FlashMLA doesn't work. We use our own CSA/SDPA attention. -The SWA cache should use standard fp8 format (not fp8_ds_mla) and no FlashMLA alignment. -""" -import sys - -def patch(path): - with open(path, 'r') as f: - content = f.read() - - if "CLAWMINE_PATCH_SWA_CACHE" in content: - print("Already patched, skipping") - return - - # Patch the get_kv_cache_spec method to remove FlashMLA-specific values on Blackwell - old = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - return SlidingWindowMLASpec( - block_size=self.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=self.dtype, - sliding_window=self.window_size, - cache_dtype_str=self.cache_config.cache_dtype, - alignment=576, # NOTE: FlashMLA requires 576B alignment - model_version="deepseek_v4", - )""" - - new = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - # CLAWMINE_PATCH_SWA_CACHE: On Blackwell, no FlashMLA = no 576B alignment - # Use standard fp8 format (not fp8_ds_mla), no model_version override - from vllm.platforms import current_platform - _is_blackwell = ( - current_platform.get_device_capability() is not None - and current_platform.get_device_capability().major >= 10 - ) - if _is_blackwell: - return SlidingWindowMLASpec( - block_size=self.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=self.dtype, - sliding_window=self.window_size, - cache_dtype_str=self.cache_config.cache_dtype, - alignment=None, # No FlashMLA alignment on Blackwell - model_version=None, # Don't use 584B deepseek_v4 format - ) - return SlidingWindowMLASpec( - block_size=self.block_size, - num_kv_heads=1, - head_size=self.head_dim, - dtype=self.dtype, - sliding_window=self.window_size, - cache_dtype_str=self.cache_config.cache_dtype, - alignment=576, # NOTE: FlashMLA requires 576B alignment - model_version="deepseek_v4", - )""" - - if old not in content: - print("ERROR: Could not find the code to patch in " + path) - sys.exit(1) - - content = content.replace(old, new) - - with open(path, 'w') as f: - f.write(content) - print("Patched DeepseekV4SWACache for Blackwell") - -if __name__ == "__main__": - patch(sys.argv[1]) diff --git a/vllm/patches/register_cutedsl_kernel.py b/vllm/patches/register_cutedsl_kernel.py deleted file mode 100644 index 681fc34e..00000000 --- a/vllm/patches/register_cutedsl_kernel.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin -# Patch vLLM's linear kernel __init__.py to register the CuTeDSL NVFP4 kernel. -# This inserts our kernel at the TOP of the _POSSIBLE_NVFP4_KERNELS list, -# so it gets selected first on Blackwell GPUs. - -import sys - -def patch_init(path): - with open(path, 'r') as f: - content = f.read() - - # Add import after the existing flashinfer import block - import_line = ( - "from vllm.model_executor.kernels.linear.nvfp4.cutedsl import (\n" - " CuTeDSLNvFp4LinearKernel,\n" - ")\n" - ) - # Insert after the marlin import block - marker = "from vllm.model_executor.kernels.linear.nvfp4.marlin import (" - if "CuTeDSLNvFp4LinearKernel" in content: - print("CuTeDSL kernel already registered, skipping") - return - idx = content.find(marker) - if idx == -1: - print("ERROR: Could not find marlin import marker") - sys.exit(1) - # Find end of marlin import block - end = content.find("\n\n", idx) - content = content[:end] + "\n" + import_line + content[end:] - - # Insert CuTeDSLNvFp4LinearKernel at TOP of _POSSIBLE_NVFP4_KERNELS CUDA list - old = " PlatformEnum.CUDA: [\n FlashInferCutlassNvFp4LinearKernel," - new = " PlatformEnum.CUDA: [\n CuTeDSLNvFp4LinearKernel,\n FlashInferCutlassNvFp4LinearKernel," - content = content.replace(old, new) - - # Also add to _NVFP4_BACKEND_TO_KERNEL so VLLM_NVFP4_GEMM_BACKEND=cutedsl works - old_backend = ' "emulation": EmulationNvFp4LinearKernel,\n}' - new_backend = ' "emulation": EmulationNvFp4LinearKernel,\n "cutedsl": CuTeDSLNvFp4LinearKernel,\n}' - content = content.replace(old_backend, new_backend) - - with open(path, 'w') as f: - f.write(content) - print("Patched CuTeDSL NVFP4 kernel into", path) - -if __name__ == "__main__": - patch_init(sys.argv[1])