Nuke vllm because this keeps confusing people

This commit is contained in:
2026-05-19 23:04:36 +00:00
parent 02b57071be
commit 02b9c1ac20
18 changed files with 0 additions and 6677 deletions

View File

@@ -26,72 +26,3 @@ RUN cd /root/nvfp4-megamoe-kernel && pip install -e .
COPY cutedsl/ /root/nvfp4-megamoe-kernel/cutedsl/
ENV PYTHONPATH="/root/nvfp4-megamoe-kernel:${PYTHONPATH}"
# Patch vLLM — overwrite model files and register architecture.
# Every target path below is hard-coded to the image's Python 3.12
# dist-packages install of vLLM; each ARG can be overridden at build time.
ARG VLLM_MODELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models
ARG VLLM_LAYERS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers
# NOTE(review): VLLM_QUANT_DIR and VLLM_LOADER_DIR are declared but not
# referenced by any instruction in this section.
ARG VLLM_QUANT_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization
ARG VLLM_FUSED_MOE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe
ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader
# Core model patches (whole-file overwrites of the installed modules)
COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py
# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell)
# The nightly image has all MHC in layers/mhc.py (imports tilelang at top level).
# Our replacement is pure PyTorch — no tilelang dependency at all.
COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py
# CSA/HCA attention kernel (replaces FlashMLA on Blackwell)
COPY vllm/patches/layers/csa_attention.py ${VLLM_LAYERS_DIR}/csa_attention.py
# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel)
ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4
COPY vllm/kernels/linear/nvfp4/cutedsl.py ${VLLM_NVFP4_DIR}/cutedsl.py
# Patch KV cache utils to handle DeepseekV4 SWA page sizes > MLA page sizes
# (SWA layers have larger page sizes than compressed MLA layers on Blackwell)
# Pattern used by all "patch_*.py" steps below: copy the script to /tmp, run it
# with the target file as its only argument, then delete the script.
ARG VLLM_CORE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/core
COPY vllm/patches/patch_kv_cache_utils.py /tmp/patch_kv_cache_utils.py
RUN python3 /tmp/patch_kv_cache_utils.py ${VLLM_CORE_DIR}/kv_cache_utils.py && rm /tmp/patch_kv_cache_utils.py
# Patch SWA cache and Indexer cache for Blackwell (no FlashMLA alignment)
ARG VLLM_SPARSE_SWA_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla
# NOTE(review): VLLM_LAYERS_DIR2 has the same default value as VLLM_LAYERS_DIR above.
ARG VLLM_LAYERS_DIR2=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers
COPY vllm/patches/patch_swa_cache.py /tmp/patch_swa_cache.py
RUN python3 /tmp/patch_swa_cache.py ${VLLM_SPARSE_SWA_DIR}/sparse_swa.py && rm /tmp/patch_swa_cache.py
# The next two scripts edit files that were themselves overwritten by the
# "Core model patches" COPY steps above (same paths with default ARG values).
COPY vllm/patches/patch_indexer_cache.py /tmp/patch_indexer_cache.py
RUN python3 /tmp/patch_indexer_cache.py ${VLLM_LAYERS_DIR2}/deepseek_v4_attention.py && rm /tmp/patch_indexer_cache.py
COPY vllm/patches/patch_compressor_cache.py /tmp/patch_compressor_cache.py
RUN python3 /tmp/patch_compressor_cache.py ${VLLM_LAYERS_DIR2}/deepseek_compressor.py && rm /tmp/patch_compressor_cache.py
# Debug: print layer name mismatch
# NOTE(review): this debug-only patch modifies gpu_model_runner.py in the built image.
ARG VLLM_WORKER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/worker
COPY vllm/patches/patch_debug_layers.py /tmp/patch_debug_layers.py
RUN python3 /tmp/patch_debug_layers.py ${VLLM_WORKER_DIR}/gpu_model_runner.py && rm /tmp/patch_debug_layers.py
# Register CuTeDSL kernel in vLLM's linear kernel selection
ARG VLLM_LINEAR_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear
COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py
RUN python3 /tmp/register_cutedsl_kernel.py ${VLLM_LINEAR_DIR}/__init__.py && rm /tmp/register_cutedsl_kernel.py
# Config patches (add cutedsl to MoEBackend)
ARG VLLM_CONFIG_DIR=/usr/local/lib/python3.12/dist-packages/vllm/config
COPY vllm/patches/kernel.py ${VLLM_CONFIG_DIR}/kernel.py
# NVFP4 MoE backend registration
COPY vllm/patches/fused_moe/oracle/nvfp4.py ${VLLM_FUSED_MOE_DIR}/oracle/nvfp4.py
COPY vllm/patches/fused_moe/experts/cutedsl_moe.py ${VLLM_FUSED_MOE_DIR}/experts/cutedsl_moe.py
# Register DeepseekV4ForCausalLM model architecture (if not already in upstream)
# The grep guard makes this idempotent: sed only runs when registry.py has no
# "DeepseekV4ForCausalLM" entry, and inserts one after the DeepseekV32 entry.
RUN grep -q '"DeepseekV4ForCausalLM"' ${VLLM_MODELS_DIR}/registry.py || \
sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \
${VLLM_MODELS_DIR}/registry.py
# Verify: import smoke test — the build fails if torch, vllm,
# nvfp4_megamoe_kernel or cutlass cannot be imported.
RUN python3 -c "import torch; print(f'PyTorch {torch.__version__} OK')" && \
python3 -c "import vllm; print('vLLM OK')" && \
python3 -c "import nvfp4_megamoe_kernel; print('NVFP4 kernel OK')" && \
python3 -c "import cutlass; print('CuTeDSL OK')"

View File

@@ -1,157 +0,0 @@
"""CuTeDSL NVFP4 Quantization Method for vLLM
Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
After process_weights_after_loading, the module's quant_method is swapped
to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL
via torch.library.custom_op (opaque to torch.compile).
"""
import torch
from vllm.model_executor.layers.linear import LinearMethodBase
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
class CuTeDSLNvfp4Method(LinearMethodBase):
    """One-shot quant method that hands a loaded NVFP4 layer over to CuTeDSL.

    ``process_weights_after_loading`` does all the work:

    1. Reads ``layer.weight`` (viewed as packed FP4), ``layer.weight_scale``
       (block scales) and ``layer.weight_scale_2`` (global scale(s)).
    2. Transposes weight and block scales so K comes first.
    3. Builds a ``CuTeDSLNvfp4Linear`` runner and registers it.
    4. Runs a one-token warmup to set the activation global scale.
    5. Replaces ``layer.weight`` with a BF16 zero tensor and deletes the
       original NVFP4 scale attributes.
    6. Swaps ``layer.quant_method`` to ``CuTeDSLNvfp4LinearMethod``.
    """

    def __init__(self, is_fused: bool = False):
        """
        Args:
            is_fused: True when the layer fuses two sub-projections and may
                carry two global scales in ``weight_scale_2``. In that case
                the larger scale is used for the whole layer and each half's
                ratio to it is folded into the block scales.
        """
        self.is_fused = is_fused

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """Do nothing: the weights are expected to exist already on ``layer``."""
        return None

    @staticmethod
    def _fold_scale_ratio(layer, block_sf, first, second, global_scale, n_out):
        """Rescale each half of ``block_sf`` by its global scale / ``global_scale``.

        ``block_sf`` is (K_sf, N); the column split point is
        ``layer.logical_widths[0]`` when exactly two widths are present,
        otherwise ``n_out // 2``. The multiply is done in float32 and the
        result is cast back to float8_e4m3fn.
        """
        rescaled = block_sf.float()
        widths = getattr(layer, 'logical_widths', None)
        if widths is not None and len(widths) == 2:
            boundary = widths[0]
        else:
            boundary = n_out // 2
        rescaled[:, :boundary] *= (first / global_scale)
        rescaled[:, boundary:] *= (second / global_scale)
        return rescaled.to(torch.float8_e4m3fn)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert ``layer`` in place so its forward runs through CuTeDSL."""
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        packed = layer.weight.data  # (out, in // 2): two FP4 values per byte
        dev = packed.device
        n_out = packed.shape[0]
        n_in = packed.shape[1] * 2

        # Reinterpret as packed FP4, then lay out as (K_packed, N).
        fp4_kn = packed.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: force float8_e4m3fn, then (N, K_sf) -> (K_sf, N).
        block_sf = layer.weight_scale.data
        if block_sf.dtype != torch.float8_e4m3fn:
            block_sf = block_sf.to(torch.float8_e4m3fn)
        block_sf = block_sf.permute(1, 0).contiguous()

        # Global scale: a single value, or two values for a fused layer.
        scale_2 = layer.weight_scale_2.data
        if not (self.is_fused and scale_2.numel() == 2):
            global_scale = scale_2.max().item()
        else:
            first = scale_2[0].item()
            second = scale_2[1].item()
            global_scale = max(first, second)
            if first != second:
                block_sf = self._fold_scale_ratio(
                    layer, block_sf, first, second, global_scale, n_out
                )

        gemm_runner = CuTeDSLNvfp4Linear(
            in_features=n_in,
            out_features=n_out,
            device=dev,
        )
        gemm_runner.fp4 = [fp4_kn]
        gemm_runner.sf = [block_sf]
        gemm_runner.gs = [global_scale]
        gemm_runner.finalize_weights()

        # The layer keeps only an integer handle; the custom op looks the
        # runner up in the registry by this id.
        layer._cutedsl_runner_id = register_runner(gemm_runner)
        layer._cutedsl_out_features = n_out

        # Warmup on a single random token to set the activation global scale.
        # Per the original author's note, the checkpoint's input_scale does
        # not match what the runtime activation quantizer expects and yields
        # garbage output, so it is deliberately not used here.
        with torch.no_grad():
            warmup_input = torch.randn(1, n_in,
                                       dtype=torch.bfloat16, device=dev) * 2.0
            gemm_runner.compute_activation_global_scale(warmup_input)
        del warmup_input
        torch.cuda.empty_cache()

        # Keep a BF16 placeholder on the same device: per the original
        # comment, some vLLM code paths read layer.weight directly.
        placeholder = torch.zeros(n_out, n_in, dtype=torch.bfloat16, device=dev)
        layer.weight = torch.nn.Parameter(placeholder, requires_grad=False)

        # Drop the original NVFP4 parameters; failures are ignored on purpose.
        for name in ("weight_scale", "weight_scale_2", "input_scale",
                     "input_global_scale", "input_global_scale_inv",
                     "weight_global_scale", "alpha", "weight_scale_inv"):
            if not hasattr(layer, name):
                continue
            try:
                delattr(layer, name)
            except Exception:
                pass

        # From here on the layer uses the forward-only method.
        layer.quant_method = CuTeDSLNvfp4LinearMethod()

    def apply(self, layer, x, bias=None):
        """Always raise: this method must have been swapped out before forward."""
        raise NotImplementedError(
            "CuTeDSLNvfp4Method should be replaced by "
            "CuTeDSLNvfp4LinearMethod after process_weights_after_loading"
        )
class CuTeDSLNvfp4LinearMethod(LinearMethodBase):
    """Forward-only quant method that dispatches to the registered CuTeDSL runner."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """Do nothing: this method never creates parameters."""
        return None

    def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        """Run the NVFP4 GEMM for ``layer`` on ``x`` and add ``bias`` if given.

        Reads ``layer._cutedsl_runner_id`` and ``layer._cutedsl_out_features``,
        which must have been set on the layer beforehand.
        """
        gemm_out = nvfp4_linear_gemm(
            x, layer._cutedsl_runner_id, layer._cutedsl_out_features,
        )
        return gemm_out if bias is None else gemm_out + bias

View File

@@ -1,133 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CuTeDSL NVFP4 Linear Kernel for vLLM.
Registers as an NvFp4LinearKernel so that vLLM kernel selection
(init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM.
Uses torch.library.custom_op to make Dynamo (torch.compile) treat the
GEMM as opaque. The runner's _run_impl is already cudagraph-safe.
"""
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
logger = init_logger(__name__)
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via the CuTeDSL framework (Blackwell SM100+)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether the device's major compute capability is >= 10.

        Args:
            compute_capability: Capability to test, or None to query
                ``current_platform``. An int is decoded as
                ``major * 10 + minor`` (assumed vLLM encoding — TODO confirm);
                an object with a ``.major`` attribute is also accepted.

        Returns:
            ``(True, None)`` when supported, else ``(False, reason)``.
        """
        cap = compute_capability or current_platform.get_device_capability()
        if cap is not None:
            # The signature promises an int, which has no ``.major``; the
            # original read ``cap.major`` unconditionally and raised
            # AttributeError whenever an int was actually passed.
            major = cap // 10 if isinstance(cap, int) else cap.major
            if major >= 10:
                return True, None
        return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)"

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert NVFP4 weights into CuTeDSL kernel format.

        Reads ``layer.weight``, ``layer.weight_scale`` and
        ``layer.weight_global_scale``; builds and registers a
        ``CuTeDSLNvfp4Linear`` runner; stores the runner id and output width
        on the layer; replaces ``layer.weight`` with a BF16 zero tensor; and
        deletes the original NVFP4 scale attributes.
        """
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        w_uint8 = layer.weight.data  # (out, in // 2): two FP4 values per byte
        device = w_uint8.device
        out_features = w_uint8.shape[0]
        in_features = w_uint8.shape[1] * 2

        # Reinterpret as packed FP4 and lay out as (K_packed, N).
        w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: force float8_e4m3fn, then (N, K_sf) -> (K_sf, N).
        sf = layer.weight_scale.data
        if sf.dtype != torch.float8_e4m3fn:
            sf = sf.to(torch.float8_e4m3fn)
        sf = sf.permute(1, 0).contiguous()

        # Check the element count BEFORE calling .item(): .item() raises on a
        # 2-element tensor, which previously made the dual-scale branch
        # unreachable.
        global_scale = layer.weight_global_scale.data
        if global_scale.numel() == 2:
            # Two fused sub-projections: use the larger global scale for the
            # whole layer and fold each half's ratio into its block scales.
            gs0 = global_scale[0].item()
            gs1 = global_scale[1].item()
            gs = max(gs0, gs1)
            if gs0 != gs1:
                sf_f32 = sf.float()
                logical_widths = getattr(layer, 'logical_widths', None)
                if logical_widths is not None and len(logical_widths) == 2:
                    split_point = logical_widths[0]
                else:
                    split_point = out_features // 2
                sf_f32[:, :split_point] *= (gs0 / gs)
                sf_f32[:, split_point:] *= (gs1 / gs)
                sf = sf_f32.to(torch.float8_e4m3fn)
        else:
            gs = global_scale.item()

        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=str(device),
        )
        runner.fp4 = [w_fp4]
        runner.sf = [sf]
        runner.gs = [gs]
        runner.finalize_weights()

        # Warmup on a single random token to set the activation global scale.
        # Per the original author's note, the checkpoint's input_scale does
        # not match what the runtime activation quantizer expects and yields
        # garbage output, so it is deliberately not used here.
        with torch.no_grad():
            sample = torch.randn(
                1, in_features,
                dtype=torch.bfloat16, device=str(device),
            ) * 2.0
            runner.compute_activation_global_scale(sample)
        del sample
        torch.cuda.empty_cache()

        # Register the runner and store the ID (not the runner itself).
        layer._cutedsl_runner_id = register_runner(runner)
        layer._cutedsl_out_features = out_features

        # Replace weight with a BF16 placeholder on the same device: per the
        # original comment, some vLLM code paths read layer.weight directly.
        layer.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.bfloat16,
                        device=device),
            requires_grad=False,
        )

        # Best-effort cleanup of the original NVFP4 parameters. A failure here
        # only leaves memory unreleased, so it is logged rather than raised.
        for attr in ("weight_scale", "weight_global_scale",
                     "input_global_scale", "input_global_scale_inv",
                     "alpha", "weights_padding_cols", "weight_scale_2",
                     "input_scale"):
            if hasattr(layer, attr):
                try:
                    delattr(layer, attr)
                except Exception as exc:
                    logger.debug("Could not delete %s from layer: %s", attr, exc)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the registered CuTeDSL GEMM on ``x`` and add ``bias`` if given."""
        result = nvfp4_linear_gemm(
            x,
            layer._cutedsl_runner_id,
            layer._cutedsl_out_features,
        )
        if bias is not None:
            result = result + bias
        return result

View File

@@ -1,536 +0,0 @@
"""
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
CUDA-graph-compatible design:
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
- Extra slots (beyond real tokens) are zero and contribute nothing to output
- Fixed-shape tensors throughout the forward pass
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
During capture, num_tokens equals the budget — all shapes are fixed.
During replay, inputs are padded to the budget size. Our runner always
processes max_slots = budget * top_k rows; padding rows are zeros.
Dynamo compatibility: uses torch.library.custom_op via cutedsl.custom_ops
so torch.compile (fullgraph mode) treats the GEMM as an opaque black box.
The runner's _run_impl is already cudagraph-safe.
"""
import torch
from cutedsl.bridge import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
quantize_to_nvfp4,
make_b_k_major,
assemble_scales_3d_side,
run_nvfp4_grouped_gemm,
)
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
class CuTeDSLMoERunner:
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
"""
def __init__(self, num_experts, hidden_size, intermediate_size,
max_num_tokens=8192, top_k=8, device="cuda",
experts_start_idx=0):
self.num_experts = num_experts
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.max_num_tokens = max_num_tokens
self.top_k = top_k
self.device = device
self.experts_start_idx = experts_start_idx
self._swiglu_limit = None # Set via set_swiglu_limit()
# Weight storage (set before _ensure_stacked)
self.l1_fp4 = None
self.l1_sf = None
self.l1_gs = None
self.l2_fp4 = None
self.l2_sf = None
self.l2_gs = None
# Stacked weight tensors (set in _ensure_stacked)
self._l1_mat_b = None
self._l2_mat_b = None
self._l1_scale_b = None
self._l2_scale_b = None
self._l1_gsb = None
self._l2_gsb = None
# Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688)
# Overridden in finalize_weights with checkpoint input_scale or warmup value
self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
self._token_indices = None
self._expert_id_range = None
self._expert_offsets_buf = None
self._per_expert_scale_bufs_l1 = None
self._per_expert_scale_bufs_l2 = None
self._padded_x_sf_buf_l1 = None
self._padded_x_sf_buf_l2 = None
self._l1_gsa_buf = None
self._l2_gsa_buf = None
self._output_buf = None
self._row_indices_buf = None
self._padded_hidden_buf = None
self._padded_activated_buf = None # unused, using shared
self._padded_expert_offsets_buf = None
self._max_chunks_per_expert = cutedsl_ceil_div(
self.max_num_tokens * self.top_k, self.num_experts * 128
)
self._buffers_allocated = False
def set_swiglu_limit(self, limit: float | None):
"""Set the swiglu_limit for activation clamping."""
self._swiglu_limit = limit
def _fill_token_indices(self):
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times).
Builds on CPU first, then copies to GPU, to ensure correctness
regardless of CuTeDSL JIT GPU memory corruption.
"""
src = torch.arange(self.max_num_tokens, dtype=torch.int32)
cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
self._token_indices.copy_(cpu_indices)
def _allocate_buffers(self):
"""Pre-allocate scale buffers at max size for cudagraph compatibility."""
# Per-expert scale buffers: separate L1/L2 since K_sf differs
K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4
K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4
self._per_expert_scale_bufs_l1 = [
torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
for _ in range(self.num_experts)
]
self._per_expert_scale_bufs_l2 = [
torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
for _ in range(self.num_experts)
]
# Initialize shared buffers dict (if not already)
device_key = str(self.device)
if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'):
CuTeDSLMoERunner._shared_padded_bufs = {}
if device_key not in CuTeDSLMoERunner._shared_padded_bufs:
CuTeDSLMoERunner._shared_padded_bufs[device_key] = {}
# Padded x_sf buffers: SHARED across all runners (not per-layer)
max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128
if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
'xsf_l1': torch.zeros(
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn),
'xsf_l2': torch.zeros(
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn),
'output': torch.zeros(
self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device
),
})
self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1']
self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2']
self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output']
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
# Row indices for scale assembly (max_num_tokens * top_k slots)
self._row_indices_buf = torch.arange(
self.max_num_tokens * self.top_k, device=self.device
)
# Padded hidden/activated: SHARED across all runners (not per-layer)
max_rows_per_expert = self._max_chunks_per_expert * 128
padded_max_slots = self.num_experts * max_rows_per_expert
if 'hidden' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
'hidden': torch.zeros(
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
),
'hidden_fp4': torch.zeros(
padded_max_slots, self.hidden_size // 2, dtype=torch.uint8, device=self.device
).view(torch.float4_e2m1fn_x2),
'activated': torch.zeros(
padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device
),
'activated_fp4': torch.zeros(
padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
).view(torch.float4_e2m1fn_x2),
})
self._shared_bufs = CuTeDSLMoERunner._shared_padded_bufs[device_key]
# Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed)
self._padded_expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
max_rows_per_expert = self._max_chunks_per_expert * 128
self._padded_expert_offsets_buf[1:] = torch.arange(
1, self.num_experts + 1, dtype=torch.int32, device=self.device
) * max_rows_per_expert
self._buffers_allocated = True
def _ensure_stacked(self):
if self._l1_mat_b is not None:
return
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation)
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
self.l1_fp4 = None
self.l1_sf = None
self.l1_gs = None
self.l2_fp4 = None
self.l2_sf = None
self.l2_gs = None
# Allocate buffers AFTER JIT compilation
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
# tensors allocated before/during compilation may be zeroed)
#
# _token_indices: GPU tensor for cudagraph compatibility.
# CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking
# (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm
# is triggered on the first run() call; we refill _token_indices after
# that first call via the _needs_token_refill flag.
self._token_indices = torch.zeros(
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
)
self._fill_token_indices()
self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run
self._expert_id_range = torch.arange(
self.num_experts, dtype=torch.int32
).to(self.device)
self._expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._allocate_buffers()
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
self.l1_fp4 = l1_fp4
self.l1_sf = l1_sf
self.l1_gs = l1_gs
self.l2_fp4 = l2_fp4
self.l2_sf = l2_sf
self.l2_gs = l2_gs
self._l1_mat_b = None
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16):
l1_w_t = l1_w.T
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w_t)
self.l1_fp4.append(w_fp4)
self.l1_sf.append(w_sf)
self.l1_gs.append(w_gs)
l2_w_t = l2_w.T
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l2_w_t)
self.l2_fp4.append(w_fp4)
self.l2_sf.append(w_sf)
self.l2_gs.append(w_gs)
self._l1_mat_b = None
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
padded_expert_offsets,
padded_x_sf_buf, per_expert_bufs):
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
The buffer is 128-row aligned per expert (from padded_expert_offsets),
so the full-buffer swizzle produces the correct layout. The GEMM reads
scale_a using padded_expert_offsets, matching the scatter layout.
"""
K_sf = x_sf.shape[1]
padded_x_sf = padded_x_sf_buf
padded_x_sf.zero_()
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
total_rows = x_sf.shape[0]
row_indices = self._row_indices_buf[:total_rows]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
dst_rows = padded_expert_offsets[expert_assign] + local_row
padded_x_sf[dst_rows, :K_sf] = x_sf
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
rows = padded_x_sf.shape[0]
cols = padded_x_sf.shape[1]
R = rows // 128
C = cols // 4
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
return swizzled.reshape(rows, cols)
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
"""Compute activation global scales from a warmup forward pass.
Called BEFORE cudagraph capture. Uses the SAME padded GEMM path as run()
to ensure kernel JIT happens with the same layout, and L2 gs is computed
from actual L1 output (not an approximation).
"""
self._ensure_stacked()
device = hidden_states_sample.device
num_tokens = hidden_states_sample.shape[0]
top_k = topk_ids.shape[1]
with torch.no_grad():
# Build slot mapping (same as run())
flat_ids = topk_ids.reshape(-1)
num_slots = num_tokens * top_k
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_token_ids = token_indices[sort_idx]
slot_hidden = hidden_states_sample[sorted_token_ids]
# L1: get exact gs from quantize_to_nvfp4
_, _, l1_gs = quantize_to_nvfp4(slot_hidden)
# Quantize slot_hidden for GEMM
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)
expert_id_range = self._expert_id_range
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
padded_expert_offsets = self._padded_expert_offsets_buf
padded_expert_offsets.zero_()
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
# Compute padded_dst (same as run())
row_indices = self._row_indices_buf[:num_slots]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
padded_dst = padded_expert_offsets[expert_assign] + local_row
# Scatter x_fp4 into padded layout
padded_x_fp4 = self._shared_bufs['hidden_fp4']
padded_x_fp4.view(torch.uint8).zero_()
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
l1_scale_a = self._assemble_scales_cudagraph_safe(
slot_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device)
l1_out = run_nvfp4_grouped_gemm(
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
# Extract real token outputs
l1_out_real = l1_out[padded_dst]
# L2: get exact gs from SiLU(gate)*up
gate = l1_out_real[:, :self.intermediate_size]
up = l1_out_real[:, self.intermediate_size:]
gate_silu = torch.nn.functional.silu(gate)
if self._swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
activated = gate_silu * up
_, _, l2_gs = quantize_to_nvfp4(activated)
self._l1_activation_global_scale = l1_gs
self._l2_activation_global_scale = l2_gs
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
"""Forward: route tokens to experts, GEMM, combine.
Uses torch.library.custom_op (nvfp4::moe_gemm) so torch.compile
treats this as an opaque op. The custom op calls _run_impl internally.
"""
if not hasattr(self, '_runner_id'):
self._runner_id = register_runner(self)
return nvfp4_moe_gemm(
hidden_states, topk_weights, topk_ids,
self._runner_id, self.hidden_size,
)
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
"""Run the NVFP4 MoE forward pass.
Handles global→local expert ID remapping for expert parallelism.
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
Each expert's slots are padded to multiples of 128 for the GEMM.
expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...].
scale_a is produced at those same offsets.
"""
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
device = hidden_states.device
self._ensure_stacked()
# -- Remap global expert IDs to local IDs --
local_ids = topk_ids - self.experts_start_idx
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
safe_ids = local_ids.clamp(0, self.num_experts - 1)
safe_weights = topk_weights * local_mask.float()
# -- Build slot mapping --
flat_ids = safe_ids.reshape(-1)
flat_weights = safe_weights.reshape(-1)
num_slots = num_tokens * top_k
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_weights = flat_weights[sort_idx]
sorted_token_ids = token_indices[sort_idx]
# Expert offsets (real token counts)
expert_id_range = self._expert_id_range
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
# Pad each expert to 128-row alignment (GPU-only computation)
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
padded_expert_offsets = self._padded_expert_offsets_buf
padded_expert_offsets.zero_()
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
total_padded_slots = padded_expert_offsets[self.num_experts]
# -- Gather hidden states into slot order, compute padded_dst --
slot_hidden = hidden_states[sorted_token_ids]
row_indices = self._row_indices_buf[:num_slots]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
padded_dst = padded_expert_offsets[expert_assign] + local_row
# === L1: gate + up ===
# Quantize slot_hidden (sorted tokens), NOT padded_hidden.
# padded_hidden is padded with zeros; quantizing it produces
# x_sf rows at padded positions, but x_sf[:num_slots] would
# only get scales for the first num_slots PADDED rows (expert 0),
# not the scattered token positions. Quantizing slot_hidden
# gives x_sf with num_slots rows (one per token), which the
# scale assembly correctly scatters into padded layout.
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(
slot_hidden, self._l1_activation_global_scale
)
# Scatter x_fp4 into padded layout for the GEMM
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
padded_x_fp4 = self._shared_bufs['hidden_fp4']
padded_x_fp4.view(torch.uint8).zero_()
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
l1_scale_a = self._assemble_scales_cudagraph_safe(
slot_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
l1_out = run_nvfp4_grouped_gemm(
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
# Extract real token outputs from padded GEMM output
l1_out_real = l1_out[padded_dst]
# === SiLU(gate) * up (with swiglu_limit clamp) ===
gate = l1_out_real[:, :self.intermediate_size]
up = l1_out_real[:, self.intermediate_size:]
gate_silu = torch.nn.functional.silu(gate)
# Apply DeepSeek-V4 swiglu_limit: clamp both silu(gate) and up
if self._swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
activated = gate_silu * up
# === L2: down ===
# Quantize activated (per-token), scatter into padded FP4 buffer
slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4(
activated, self._l2_activation_global_scale
)
padded_activated_fp4 = self._shared_bufs['activated_fp4']
padded_activated_fp4.view(torch.uint8).zero_()
padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)
l2_scale_a = self._assemble_scales_cudagraph_safe(
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
l2_out = run_nvfp4_grouped_gemm(
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
)
l2_out_real = l2_out[padded_dst]
# === Scatter -> final output ===
y = self._output_buf[:num_tokens]
y.zero_()
weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
y.scatter_add_(
0,
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
weighted_out,
)
# Refill _token_indices after GEMM JIT on first call
# (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM)
if self._needs_token_refill:
self._fill_token_indices()
self._needs_token_refill = False
return y

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,308 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CuTeDSL NVFP4 MoE experts for DeepSeek-V4.
Integrates the CuTeDSL NVFP4 grouped GEMM kernel into vLLM's FusedMoE
modular kernel framework. This is the proper integration path — no
monkey-patching, no post-load surgery.
The CuTeDSL kernel is a Python-based CUTLASS kernel compiled via MLIR → PTX.
It handles:
- L1 GEMM (gate + up projections)
- SiLU activation with optional swiglu_limit clamping
- L2 GEMM (down projection)
- All with NVFP4 (float8_e4m3fn block scales + float32 global scales)
CUDA-graph-safe: all intermediate buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes.
"""
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from cutedsl.runner import CuTeDSLMoERunner
class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
    """CuTeDSL NVFP4 MoE experts using the custom CuTeDSL grouped GEMM kernel.

    Uses the Standard (non-batched) activation format and handles input
    quantization internally (``expects_unquantized_inputs`` is True).

    All computation is delegated to a ``CuTeDSLMoERunner`` that is created in
    ``process_weights_after_loading``. For expert parallelism the runner is
    given ``experts_start_idx`` (first global expert ID owned by this rank) so
    it can remap global to local expert IDs.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        """Record the MoE geometry; the runner itself is built after loading.

        Args:
            moe_config: Layer-level MoE configuration (sizes, top-k, EP rank).
            quant_config: Quantization configuration; must be nvfp4.
        """
        super().__init__(
            moe_config=moe_config,
            quant_config=quant_config,
        )
        assert quant_config.quant_dtype == "nvfp4", (
            "CuTeDSL MoE only supports nvfp4 quantization, "
            f"got {quant_config.quant_dtype}"
        )
        self.out_dtype = moe_config.in_dtype
        self.hidden_dim = moe_config.hidden_dim
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.topk = moe_config.experts_per_token
        self.local_num_experts = moe_config.num_local_experts
        self.global_num_experts = moe_config.num_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank
        # First global expert ID hosted on this expert-parallel rank.
        self.local_expert_offset = self.ep_rank * self.local_num_experts
        # max_num_tokens is used for buffer pre-allocation; falls back to 8192
        # when the config does not carry the attribute.
        self.max_num_tokens = getattr(moe_config, 'max_num_tokens', 8192)
        # swiglu_limit is read from the FusedMoE layer in
        # process_weights_after_loading.
        self._swiglu_limit = None
        # Runner is created in process_weights_after_loading.
        self._runner: CuTeDSLMoERunner | None = None

    @staticmethod
    def _undo_scale_inversion(scale):
        """Return ``1.0 / scale`` for tensor scales; ``None`` for None/float."""
        if scale is None or isinstance(scale, float):
            return None
        return 1.0 / scale

    @staticmethod
    def _positive_mean(scale) -> float | None:
        """Return the mean of ``scale`` as a float if it is > 0, else ``None``."""
        if scale is None:
            return None
        mean = float(scale.mean().item())
        return mean if mean > 0 else None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Build the CuTeDSL runner from the layer's NVFP4 MoE weights.

        Reads the w13/w2 packed weights, block scales and global scales from
        ``layer``, hands them to a new ``CuTeDSLMoERunner`` and replaces the
        layer's packed weight parameters with uninitialized CPU tensors of the
        same shape.

        The activation global scale (input_scale) is NOT folded into the
        weight global scale (weight_scale_2): the runner keeps the two
        separately. The checkpoint input scales only seed the runner's
        activation global scales.

        Args:
            layer: The FusedMoE layer holding ``w13_*`` / ``w2_*`` tensors and,
                optionally, a ``swiglu_limit`` attribute.
        """
        num_experts = layer.w13_weight.shape[0]
        hidden_size = self.hidden_dim
        intermediate_size = self.intermediate_size_per_partition
        device = layer.w13_weight.device

        # convert_to_nvfp4_moe_kernel_format already inverted input_scale
        # (1.0 / a13_scale) for the quant config. Undo that inversion to get
        # the original input_scale = amax / (6.0 * 448.0), which is used below
        # as the initial activation global scale.
        w13_input_scale_orig = self._undo_scale_inversion(layer.w13_input_scale)
        w2_input_scale_orig = self._undo_scale_inversion(layer.w2_input_scale)

        # Checkpoint-format tensors (no copies yet):
        #   w13_weight:       (E, 2*intermediate, hidden//2) uint8, gate + up fused
        #   w2_weight:        (E, hidden, intermediate//2)   uint8, down
        #   w13_weight_scale: (E, 2*intermediate, hidden//16) fp8
        #   w2_weight_scale:  (E, hidden, intermediate//16)   fp8
        #   w*_weight_scale_2: (E,) or (E, 2)
        w13_uint8 = layer.w13_weight.data
        w2_uint8 = layer.w2_weight.data
        w13_sf = layer.w13_weight_scale.data
        w2_sf = layer.w2_weight_scale.data
        w13_gs = layer.w13_weight_scale_2.data
        w2_gs = layer.w2_weight_scale_2.data

        # Reinterpret the packed bytes as fp4 pairs: byte-preserving, no copy.
        l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2)
        l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2)

        # Block scales must be float8_e4m3fn (no copy if already that dtype).
        if w13_sf.dtype != torch.float8_e4m3fn:
            w13_sf = w13_sf.to(torch.float8_e4m3fn)
        if w2_sf.dtype != torch.float8_e4m3fn:
            w2_sf = w2_sf.to(torch.float8_e4m3fn)

        # Global scales as plain Python lists.
        l1_gs_list = w13_gs.tolist()
        l2_gs_list = w2_gs.tolist()

        # Drop the layer's references to the large weight tensors. The modular
        # kernel framework reads w1.shape[0] in its outer apply() before
        # delegating to this class, so the parameters cannot be set to None;
        # shape-preserving CPU placeholders keep the metadata accessible.
        # The (small) scale tensors stay on the layer because the framework's
        # warmup and quant config construction need them.
        layer.w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size // 2,
                device='cpu', dtype=torch.uint8),
            requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size // 2,
                device='cpu', dtype=torch.uint8),
            requires_grad=False)

        # Create the CuTeDSL runner.
        self._runner = CuTeDSLMoERunner(
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            max_num_tokens=self.max_num_tokens,
            top_k=self.topk,
            device=str(device),
            experts_start_idx=self.local_expert_offset,
        )
        # Pass stacked tensors in checkpoint format (E, N, K).
        self._runner.prepare_weights_from_stacked(
            l1_fp4, w13_sf, l1_gs_list,
            l2_fp4, w2_sf, l2_gs_list,
        )

        # swiglu_limit: the value on the layer (set by DeepseekV4MoE) wins;
        # otherwise reuse a value stored on this object by an earlier call.
        swiglu_limit = getattr(layer, 'swiglu_limit', None)
        if swiglu_limit is None:
            swiglu_limit = self._swiglu_limit
        else:
            self._swiglu_limit = swiglu_limit
        if swiglu_limit is not None:
            self._runner.set_swiglu_limit(float(swiglu_limit))

        # Seed the activation global scales from the checkpoint input scales
        # (mean across experts; only strictly positive means are applied).
        # The warmup step (compute_activation_global_scales) is expected to
        # override these with empirically computed values before the first
        # inference; that warmup is not performed here.
        mean_l1_gs = self._positive_mean(w13_input_scale_orig)
        if mean_l1_gs is not None:
            self._runner._l1_activation_global_scale = mean_l1_gs
        mean_l2_gs = self._positive_mean(w2_input_scale_orig)
        if mean_l2_gs is not None:
            self._runner._l2_activation_global_scale = mean_l2_gs

    @property
    def runner(self) -> CuTeDSLMoERunner | None:
        """The underlying runner, or ``None`` before weights are processed."""
        return self._runner

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """This implementation consumes the Standard activation format."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """True only on CUDA devices of capability family 100 (Blackwell)."""
        p = current_platform
        return p.is_cuda() and p.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """The non-act-and-mul mode is not supported."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Only static-NVFP4 weights with dynamic-NVFP4 activations."""
        SUPPORTED_W_A = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Only SiLU; SiLU + swiglu_limit is handled inside the runner."""
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Every parallel configuration is reported as supported."""
        return True

    def supports_expert_map(self) -> bool:
        """Expert maps are not supported."""
        return False

    @property
    def expects_unquantized_inputs(self) -> bool:
        """True: the runner quantizes activations itself."""
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return the no-op top-k weight/reduce implementation."""
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return ``(workspace1, workspace2, output)`` shapes.

        Both workspaces are empty because the runner manages its own
        pre-allocated buffers; the output is ``(M, hidden_dim)``.
        """
        workspace1 = (0,)
        workspace2 = (0,)
        # The output of the full 2-stage MoE pipeline is hidden_dim wide.
        output = (M, self.hidden_dim)
        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool | None,
    ):
        """Run the MoE through the runner and copy the result into ``output``.

        Only ``hidden_states``, ``topk_weights`` and ``topk_ids`` are
        forwarded; the remaining arguments are accepted for interface
        compatibility and ignored here.
        """
        assert self._runner is not None, (
            "CuTeDSL runner not initialized. "
            "Call process_weights_after_loading first."
        )
        # The runner expects topk_ids as global expert IDs and handles the
        # remapping internally via experts_start_idx.
        result = self._runner.run(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
        # Copy result into the framework-provided output tensor.
        output.copy_(result)

View File

@@ -1,535 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
kE2M1ToFloat_handle,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
logger = init_logger(__name__)
class NvFp4MoeBackend(Enum):
    """Identifiers of the NVFP4 MoE expert backends known to this module.

    Each member's value is its own name as a string.
    """

    # FlashInfer-provided backends.
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
    FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
    FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
    FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
    # Other backends.
    VLLM_CUTLASS = "VLLM_CUTLASS"
    MARLIN = "MARLIN"
    CUTEDSL = "CUTEDSL"
    EMULATION = "EMULATION"
# Backends provided by FlashInfer. They are handled as a group, e.g. removed
# from the candidate list when VLLM_USE_FLASHINFER_MOE_FP4 is set to a falsy
# value.
FLASHINFER_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.FLASHINFER_TRTLLM,
    NvFp4MoeBackend.FLASHINFER_CUTLASS,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
]
# Backends backed by the standalone CuTeDSL kernel.
CUTEDSL_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.CUTEDSL,
]
# Translation from FlashInfer's own backend enum to this module's enum.
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
    FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
    FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
    FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
    """Tell whether ``backend`` can quantize with all experts' scaling factors.

    This matters in Expert Parallel mode, where the experts are not all on
    the same rank. The FlashInfer and CuTeDSL backend groups qualify.
    """
    capable_backends = (
        *FLASHINFER_NVFP4_MOE_BACKENDS,
        *CUTEDSL_NVFP4_MOE_BACKENDS,
    )
    return backend in capable_backends
def backend_to_kernel_cls(
backend: NvFp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
TrtLlmNvFp4ExpertsModular,
TrtLlmNvFp4ExpertsMonolithic,
)
# NOTE: prefer Monolthic > Modular, so return Monolithic first.
return [
TrtLlmNvFp4ExpertsMonolithic,
TrtLlmNvFp4ExpertsModular,
]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts,
)
return [FlashInferExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( # noqa: E501
FlashInferCuteDSLExperts,
)
return [FlashInferCuteDSLExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501
FlashInferCuteDSLBatchedExperts,
)
return [FlashInferCuteDSLBatchedExperts]
elif backend == NvFp4MoeBackend.CUTEDSL:
from vllm.model_executor.layers.fused_moe.experts.cutedsl_moe import ( # noqa: E501
CuTeDSLMoEExperts,
)
return [CuTeDSLMoEExperts]
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
CutlassExpertsFp4,
)
return [CutlassExpertsFp4]
elif backend == NvFp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.experts.marlin_moe import (
MarlinExperts,
)
return [MarlinExperts]
elif backend == NvFp4MoeBackend.EMULATION:
from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import (
Nvfp4QuantizationEmulationTritonExperts,
)
return [Nvfp4QuantizationEmulationTritonExperts]
else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
    """Translate the user-facing ``moe_backend`` string to an NvFp4MoeBackend.

    Raises:
        ValueError: If ``runner_backend`` has no NvFP4 MoE counterpart.
    """
    by_name = {
        "cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
        "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
        "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        "cutedsl": NvFp4MoeBackend.CUTEDSL,
        "marlin": NvFp4MoeBackend.MARLIN,
        "emulation": NvFp4MoeBackend.EMULATION,
    }
    selected = by_name.get(runner_backend)
    if not selected:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
            f"Expected one of {list(by_name.keys())}."
        )
    return selected
def select_nvfp4_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
    """Pick the primary NvFP4 MoE backend and its expert class.

    Resolution order:
      1. An explicit ``config.moe_backend`` (anything other than "auto").
      2. The VLLM_USE_FLASHINFER_MOE_FP4 / VLLM_FLASHINFER_MOE_BACKEND
         environment settings.
      3. VLLM_TEST_FORCE_FP8_MARLIN.
      4. The first supported backend in the priority list below.

    Note: Shape-specific fallbacks may still occur at runtime.

    Raises:
        ValueError: If an explicitly requested backend does not support the
            configuration.
        NotImplementedError: If no candidate backend supports it.
    """
    # NOTE: the kernels are tried in the following order.
    candidates = [
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
        NvFp4MoeBackend.CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.VLLM_CUTLASS,
        NvFp4MoeBackend.MARLIN,
        NvFp4MoeBackend.EMULATION,
    ]

    if config.moe_parallel_config.use_batched_activation_format:
        activation_format = mk.FusedMoEActivationFormat.BatchedExperts
    else:
        activation_format = mk.FusedMoEActivationFormat.Standard

    def _make_log_backend(backend: NvFp4MoeBackend):
        # Reads `candidates` at call time, so it reflects any filtering below.
        names = [b.value for b in candidates]
        return (
            f"Using '{backend.value}' NvFp4 MoE backend out "
            f"of potential backends: {names}."
        )

    def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
        message = (
            f"NvFp4 MoE backend '{backend.value}' does not support the "
            "deployment configuration"
        )
        return f"{message} since {reason}." if reason else f"{message}."

    def _return_or_raise(
        backend: NvFp4MoeBackend,
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Mandatory backend: the first supported class wins, otherwise raise
        # with the reason reported by the last class tried.
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return backend, k_cls
        raise ValueError(_make_log_unsupported(backend, reason))

    def _first_supported(backends):
        # Optional backends: return the first (backend, class) that supports
        # the configuration, logging each rejection; None if none does.
        for backend in backends:
            for k_cls in backend_to_kernel_cls(backend):
                supported, reason = k_cls.is_supported_config(
                    k_cls,
                    config,
                    weight_key,
                    activation_key,
                    activation_format,
                )
                if supported:
                    logger.info_once(_make_log_backend(backend))
                    return backend, k_cls
                logger.debug_once(_make_log_unsupported(backend, reason))
        return None

    # Handle explicit moe_backend from user.
    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_nvfp4_backend(runner_backend)
        # For batched activation format, use batched variant if available.
        wants_batched_cutedsl = (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL
        )
        if wants_batched_cutedsl:
            requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
        return _return_or_raise(requested_backend)

    if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
        if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
            # If the user rejects FlashInfer remove those backends.
            candidates = [
                b for b in candidates if b not in FLASHINFER_NVFP4_MOE_BACKENDS
            ]
        elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
            # If user is explicit about backend, validate it.
            return _return_or_raise(
                fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
            )
        else:
            # If the user is not explicit about the backend, try each.
            found = _first_supported(FLASHINFER_NVFP4_MOE_BACKENDS)
            if found is not None:
                return found
            raise NotImplementedError(
                "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
                "FlashInfer NVFP4 MoE backend supports the configuration."
            )

    if envs.VLLM_TEST_FORCE_FP8_MARLIN:
        return _return_or_raise(NvFp4MoeBackend.MARLIN)

    # Select kernels in order of backend.
    found = _first_supported(candidates)
    if found is not None:
        return found
    raise NotImplementedError(
        "No NvFp4 MoE backend supports the deployment configuration."
    )
def convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Convert NVFP4 MoE weights and scales to the layout ``nvfp4_backend`` needs.

    Returns:
        ``(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2,
        a2_scale)`` after the backend-specific preparation.

    Raises:
        ValueError: For an unknown backend, or when the EMULATION backend is
            given a ``None`` activation scale.
    """
    if nvfp4_backend == NvFp4MoeBackend.CUTEDSL:
        # The CuTeDSL kernel converts weights in its own
        # process_weights_after_loading, so the raw weights pass through.
        # Only the activation scales are inverted, for the quant config.
        if a13_scale is not None:
            a13_scale = 1.0 / a13_scale
        if a2_scale is not None:
            a2_scale = 1.0 / a2_scale
    elif nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        prepared = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
        )
        (w13, w13_scale, w13_scale_2, a13_scale,
         w2, w2_scale, w2_scale_2, a2_scale) = prepared
    elif (
        nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
    ):
        prepared = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
            backend=nvfp4_backend,
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
            is_act_and_mul=is_act_and_mul,
        )
        (w13, w13_scale, w13_scale_2, a13_scale,
         w2, w2_scale, w2_scale_2, a2_scale) = prepared
    elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
        # Marlin is handed no activation scales.
        a13_scale = None
        a2_scale = None
        prepared = prepare_nvfp4_moe_layer_for_marlin(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            is_act_and_mul=is_act_and_mul,
        )
        w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2 = prepared
    elif nvfp4_backend == NvFp4MoeBackend.EMULATION:
        # Move the E2M1 lookup table to the device now, because
        # `.to(device)` is not allowed during CUDA graph capture.
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(w13.device)
        if a13_scale is None or a2_scale is None:
            raise ValueError(
                "Activation global scales should not be None, got"
                f" a13_scale={a13_scale}, a2_scale={a2_scale}"
            )
        if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1:
            logger.warning_once(
                "In NVFP4 linear, the activation global scale for inputs are different"
                " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using"
                " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()."
            )
        # 1. The max is taken, following e.g.
        #    quantization/utils/flashinfer_fp4_moe.py.
        # 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant use the
        #    inverse scale directly (large global scale).
        # NOTE: Before this point, `a13_scale` and `a2_scale` are such that
        # `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`,
        # and `global_scale[expert_id]` are small (~1e-4).
        # Taking the largest global scale likely overflows the FP8 range for
        # other experts - other selection strategies may be used.
        a13_scale = 1.0 / a13_scale.max().to(torch.float32)
        a2_scale = 1.0 / a2_scale.max().to(torch.float32)
    else:
        raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
def make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build the FusedMoEQuantConfig matching ``backend``.

    MARLIN gets a weight-only (w4a16) config. EMULATION receives the
    activation scales as given. Every other backend receives their
    reciprocals plus the ``is_scale_swizzled`` flag.
    """
    if backend == NvFp4MoeBackend.MARLIN:
        return nvfp4_w4a16_moe_quant_config(
            g1_alphas=w13_scale_2,
            g2_alphas=w2_scale_2,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
        )

    if backend == NvFp4MoeBackend.EMULATION:
        return nvfp4_moe_quant_config(
            g1_alphas=w13_scale_2,
            g2_alphas=w2_scale_2,
            a1_gscale=a13_scale,
            a2_gscale=a2_scale,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
        )

    # NOTE(rob): this is a hack until the MoE kernels create their own quant
    # configs. TRTLLM kernel does not accept swizzled input quant scales.
    unswizzled_backends = (
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.CUTEDSL,
    )

    # w13_scale_2 / w2_scale_2 are passed directly as g1/g2_alphas. Per the
    # original note, the expert's process_weights_after_loading fuses the
    # activation scales in-place; since the quant config references the same
    # tensor as the registered parameter, EPLB rearrangement stays in sync.
    return nvfp4_moe_quant_config(
        g1_alphas=w13_scale_2,
        g2_alphas=w2_scale_2,
        a1_gscale=(1.0 / a13_scale),
        a2_gscale=(1.0 / a2_scale),
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        is_scale_swizzled=(backend not in unswizzled_backends),
    )
def make_nvfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a FusedMoEKernel from a prepare/finalize stage and experts.

    ``experts_cls`` is instantiated with extra ``max_num_tokens`` and
    ``num_dispatchers`` arguments when the prepare/finalize stage uses the
    BatchedExperts activation format.
    """
    # Create Prepare/Finalize.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
    )
    assert prepare_finalize is not None
    logger.info_once("Using %s", prepare_finalize.__class__.__name__)

    # Create Experts.
    experts_kwargs = {
        "moe_config": moe_config,
        "quant_config": moe_quant_config,
    }
    is_batched = (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    )
    if is_batched:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts_kwargs["max_num_tokens"] = max_num_tokens
        experts_kwargs["num_dispatchers"] = prepare_finalize.num_dispatchers()
    experts = experts_cls(**experts_kwargs)

    # TODO(rob): update inplace logic to be part of the kernel.
    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        inplace=False,
    )

View File

@@ -1,216 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Callable
from dataclasses import asdict, fields
from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, field_validator
from vllm.config.utils import config, get_hash_factors, hash_factors
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@config
class IrOpPriorityConfig:
    """
    Configuration for vLLM IR op priority for dispatching/lowering during the
    forward pass. Each member is a list of strings, which will be passed to
    vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
    A single comma-separated string is accepted as well.

    If specified manually, platform defaults will be appended to the lists.
    See KernelConfig.set_platform_defaults().
    """

    rms_norm: list[str] = Field(default_factory=list)
    """Priority list for vllm.ir.ops.rms_norm"""
    fused_add_rms_norm: list[str] = Field(default_factory=list)
    """Priority list for vllm.ir.ops.fused_add_rms_norm"""

    def compute_hash(self) -> str:
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.

        Also, manually add IR op impl UUIDs to make sure they affect the compile cache.
        """
        factors = get_hash_factors(self, set())
        # Implementations are hidden from Dynamo,
        # so they don't show up in the traced files list.
        from vllm.ir.op import IrOp

        assert "_impls" not in factors
        # Map each op name to {provider: implementation uuid} for every
        # provider listed in that op's priority list.
        factors["_impls"] = {
            name: {
                provider: IrOp.registry[name].impls[provider].uuid() for provider in p
            }
            for name, p in asdict(self).items()  # type: ignore[call-overload]
        }
        return hash_factors(factors)

    @field_validator("*", mode="before")
    @classmethod
    def _to_list_str(cls, value: str | list[str]):
        """Accept a comma-separated string (spaces dropped) or a list of strings."""
        if isinstance(value, str):
            value = value.replace(" ", "").split(",")
        assert all(isinstance(v, str) for v in value)
        return value

    @contextlib.contextmanager
    def set_priority(self):
        """
        Context manager to set the IR op priority for all op members.
        It also imports IR kernel implementations for the current platform
        to ensure all implementations are made available.
        """
        from vllm.ir.op import IrOp
        from vllm.platforms import current_platform

        current_platform.import_ir_kernels()
        # ExitStack keeps every per-op priority context open until the caller's
        # `with` block ends.
        with contextlib.ExitStack() as stack:
            for field in fields(self):  # type: ignore[arg-type]
                op_priority = getattr(self, field.name)
                assert op_priority is not None, (
                    f"IR op priority for {field.name} must be set"
                )
                logger.debug(
                    "Setting IR op priority for %s to %s", field.name, op_priority
                )
                ir_op = IrOp.registry[field.name]
                stack.enter_context(ir_op.set_priority(op_priority))
            yield

    @classmethod
    def with_default(
        cls, default: list[str], /, **kwargs: list[str]
    ) -> "IrOpPriorityConfig":
        """
        A helper to create an IrOpPriorityConfig where fields not specified in kwargs
        use the given default list.
        """
        for field in fields(cls):  # type: ignore[arg-type]
            if field.name not in kwargs:
                # Copy so the fields do not share one list object.
                kwargs[field.name] = list(default)
        return cls(**kwargs)
# Accepted values for KernelConfig.moe_backend. User input is lower-cased and
# has "-" replaced by "_" before validation (see
# KernelConfig._normalize_moe_backend).
MoEBackend = Literal[
    "auto",
    "triton",
    "deep_gemm",
    "deep_gemm_mega_moe",
    "cutlass",
    "flashinfer_trtllm",
    "flashinfer_cutlass",
    "flashinfer_cutedsl",
    "marlin",
    "humming",
    "triton_unfused",
    "aiter",
    "cutedsl",
    "emulation",
]
@config
class KernelConfig:
    """Configuration for kernel selection and warmup behavior."""

    ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig)
    """
    vLLM IR op priority for dispatching/lowering during the forward pass.
    Platform defaults appended automatically during VllmConfig.__post_init__.
    """
    enable_flashinfer_autotune: bool = None  # type: ignore[assignment]
    """If True, run FlashInfer autotuning during kernel warmup."""
    moe_backend: MoEBackend = "auto"
    """Backend for MoE expert computation kernels. Available options:

    - "auto": Automatically select the best backend based on model and hardware
    - "triton": Use Triton-based fused MoE kernels
    - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)
    - "deep_gemm_mega_moe": Use DeepGEMM mega MoE kernels
    - "cutlass": Use vLLM CUTLASS kernels
    - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels
    - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
    - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
    - "marlin": Use Marlin kernels (weight-only quantization)
    - "humming": Use Humming Mixed Precision kernels
    - "triton_unfused": Use Triton unfused MoE kernels
    - "aiter": Use AMD AITer kernels (ROCm only)
    - "cutedsl": Use the custom CuTeDSL NVFP4 grouped GEMM kernels
        (NVFP4 only, CUDA SM100)
    - "emulation": use BF16/FP16 GEMM, dequantizing weights and
        running QDQ on activations.
    """

    @field_validator("moe_backend", mode="before")
    @classmethod
    def _normalize_moe_backend(cls, value: Any) -> Any:
        """Lower-case string values and replace "-" with "_"; pass others through."""
        if isinstance(value, str):
            return value.lower().replace("-", "_")
        return value

    def compute_hash(self) -> str:
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.
        """
        ignored_factors = {
            "enable_flashinfer_autotune",
            "ir_op_priority",  # handled separately below
        }
        factors = get_hash_factors(self, ignored_factors)
        factors["ir_op_priority"] = self.ir_op_priority.compute_hash()
        return hash_factors(factors)

    @field_validator("enable_flashinfer_autotune", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialization is delayed."""
        if value is None:
            return value
        return handler(value)

    def set_platform_defaults(self, vllm_config: "VllmConfig") -> None:
        """Set platform-specific defaults for the kernel config.

        For every IR op, the platform's default priorities are appended to the
        user-defined list, skipping entries that are already present.
        """
        from vllm.platforms import current_platform

        platform_op_priority = current_platform.get_default_ir_op_priority(vllm_config)
        logger.debug(
            "Setting platform-specific IR op priority defaults: %s, user-defined: %s",
            platform_op_priority,
            self.ir_op_priority,
        )
        for op_name, op_priority in asdict(platform_op_priority).items():
            current_op_priority: list[str] = getattr(self.ir_op_priority, op_name)
            if current_op_priority is None:
                setattr(self.ir_op_priority, op_name, op_priority)
            else:
                # Append platform-specific priorities
                # Must be idempotent because vllm_config.set_platform_defaults() may be
                # called multiple times (due to VllmConfig.__post_init__ manual call).
                unique_op_priority = [
                    op for op in op_priority if op not in current_op_priority
                ]
                current_op_priority.extend(unique_op_priority)

        logger.info(
            "Final IR op priority after setting platform defaults: %s",
            self.ir_op_priority,
        )

View File

@@ -1,356 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
CSA/HCA attention for Blackwell (SM100+).
Replaces vLLM's FlashMLA + fused CUDA kernels with our own KV cache-based
attention pipeline. The previous version used raw KV for attention (no cache),
which produced garbage during decode because the KV cache was never written.
Key changes from the broken version:
1. blackwell_attention_kv_write NOW WRITES KV to the paged cache (fp8); fused_qnorm_rope_kv_insert_py only applies RoPE to Q
2. full_sdpa_attention is replaced with cache-aware attention
3. KV is quantized to fp8 with per-token scale, RoPE applied before caching
Architecture:
- KV latent: (T, HD=512) single head, shared across 128 Q heads
- KV Cache: fp8_e4m3 paged cache with per-token inverse scale
- Attention: BF16 (NVFP4 too lossy for Q×K^T)
"""
import torch
import torch.nn.functional as F
def apply_gptj_rope(x, cos, sin, nope_dim):
    """Apply interleaved-pair (GPT-J style) rotary embedding to a copy of `x`.

    Only the channels from `nope_dim` onward are rotated; channel pairs are
    (even, odd) neighbours. `cos` / `sin` must broadcast against the
    even/odd halves of that trailing slice. `x` itself is not modified.
    """
    rotated = x.clone()
    rope_in = x[..., nope_dim:]
    rope_out = rotated[..., nope_dim:]  # view into `rotated`
    x_even, x_odd = rope_in[..., 0::2], rope_in[..., 1::2]
    rope_out[..., 0::2] = x_even * cos - x_odd * sin
    rope_out[..., 1::2] = x_even * sin + x_odd * cos
    return rotated
def apply_inv_gptj_rope(x, cos, sin, nope_dim):
    """Apply the inverse interleaved-pair (GPT-J style) rotation to a copy of `x`.

    Same layout as the forward rotation, with the sign of the `sin` terms
    flipped. Only channels from `nope_dim` onward change; `x` is not modified.
    """
    unrotated = x.clone()
    rope_in = x[..., nope_dim:]
    rope_out = unrotated[..., nope_dim:]  # view into `unrotated`
    x_even, x_odd = rope_in[..., 0::2], rope_in[..., 1::2]
    rope_out[..., 0::2] = x_even * cos + x_odd * sin
    rope_out[..., 1::2] = -x_even * sin + x_odd * cos
    return unrotated
# ── KV Cache Operations ──────────────────────────────────────────────
def kv_quantize_fp8(kv_bf16):
"""BF16 KV → fp8_e4m3 with per-token inverse scale."""
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
scale = fp8_max / amax
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
inv_scale = (amax / fp8_max).to(torch.bfloat16)
return kv_fp8, inv_scale
def kv_dequantize_fp8(kv_fp8, inv_scale):
"""fp8 KV → BF16."""
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
def paged_kv_write(kv_data, slot_mapping, cache, block_size):
    """Write KV rows into a paged cache.

    Args:
        kv_data: (T, D) tensor to write (fp8 or bf16).
        slot_mapping: (T,) flat slot indices; slot `s` lands at
            `cache[s // block_size, s % block_size]`.
        cache: (num_blocks, block_size, D) cache tensor (may be uint8).
        block_size: number of slots per cache block.

    Rows whose slot is negative (padding) or falls outside the cache are
    skipped instead of being written.
    """
    # Handle dtype mismatch: cache is uint8, kv_data is fp8. Reinterpret the
    # bytes rather than converting values.
    if cache.dtype == torch.uint8 and kv_data.dtype == torch.float8_e4m3fn:
        kv_to_write = kv_data.view(torch.uint8)
    else:
        kv_to_write = kv_data

    block_indices = slot_mapping // block_size
    offsets = slot_mapping % block_size

    # A negative slot must be rejected explicitly: floor-division maps -1 to
    # block -1 / offset block_size - 1, which would otherwise pass the upper
    # bound checks and silently overwrite the last row of the cache.
    valid = (
        (slot_mapping >= 0)
        & (block_indices < cache.shape[0])
        & (offsets < cache.shape[1])
    )
    if valid.all():
        cache[block_indices, offsets] = kv_to_write
    else:
        # Partial write: keep only the in-range rows, still vectorized.
        cache[block_indices[valid], offsets[valid]] = kv_to_write[valid]
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
"""Read KV from paged cache. Returns fp8 or uint8.
Vectorized version — uses advanced indexing instead of Python for loop.
"""
device = cache.device
# Compute block indices and offsets
slots = slot_mapping # (num_tokens,)
block_indices = slots // block_size
offsets = slots % block_size
# Advanced indexing: cache[block_indices, offsets] -> (num_tokens, head_dim)
kv = cache[block_indices, offsets]
# If cache is uint8, reinterpret as fp8
if cache.dtype == torch.uint8:
kv = kv.view(torch.float8_e4m3fn)
return kv
# ── Attention ─────────────────────────────────────────────────────────
def causal_prefill_attention(q, kv, scale):
"""Full causal self-attention for prefill. q: (T, NH, HD), kv: (T, HD)."""
T, NH, HD = q.shape
q_t = q.permute(1, 0, 2) # (NH, T, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T, HD)
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
return out.permute(1, 0, 2) # (T, NH, HD)
def decode_attention(q, kv, scale):
"""Decode attention: 1 query vs N cached KVs.
q: (1, NH, HD) — single decode token
kv: (N, HD) — all cached KV (already with RoPE)
"""
NH = q.shape[1]
HD = q.shape[2]
q_t = q.permute(1, 0, 2) # (NH, 1, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, N, HD)
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
return out.permute(1, 0, 2) # (1, NH, HD)
# ── Fused Q norm + RoPE + KV cache write ─────────────────────────────
def fused_qnorm_rope_kv_insert_py(
    q,  # (T, num_heads, head_dim) — modified in-place
    kv,  # (T, head_dim) — not modified
    swa_kv_cache_2d,  # paged KV cache (2D view)
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
    nope_dim,
    rope_dim,
) -> None:
    """Pure PyTorch: in-place interleaved-pair (GPT-J) RoPE on Q only.

    Q is expected to be normed already, so no norm is applied here, and this
    function does not touch the KV cache: `kv`, `swa_kv_cache_2d`,
    `slot_mapping`, `eps` and `block_size` are accepted for signature
    compatibility but never read. Each row of `cos_sin_cache` is read as
    `rope_dim // 2` cos values followed by `rope_dim // 2` sin values.
    """
    if q.shape[0] == 0:
        return

    half = rope_dim // 2
    # One (cos, sin) row per token, with a head axis added for broadcasting.
    cos_q = cos_sin_cache[positions, :half].unsqueeze(1).to(q.dtype)
    sin_q = cos_sin_cache[positions, half:2 * half].unsqueeze(1).to(q.dtype)

    rope_view = q[:, :, nope_dim:]  # view: writes land in `q`
    # Snapshot the un-rotated values so both outputs use the original inputs.
    snapshot = rope_view.clone()
    src_even = snapshot[:, :, 0::2]
    src_odd = snapshot[:, :, 1::2]
    rope_view[:, :, 0::2] = src_even * cos_q - src_odd * sin_q
    rope_view[:, :, 1::2] = src_even * sin_q + src_odd * cos_q
def blackwell_attention_kv_write(
    kv,  # (T, HD) kv_normed — NOT RoPE'd yet
    positions,  # (T,) absolute positions
    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 paged cache
    swa_inv_scale,  # (max_slots, 1) per-token inv scale
    slot_mapping,  # (T,) slot indices
    block_size,  # tokens per block
    cos_sin_cache,  # (max_pos, rope_dim) for RoPE
    nope_dim,  # 448
    rope_dim,  # 64
) -> None:
    """Write KV to paged cache: apply RoPE → fp8 quantize → insert.

    Steps:
      1. Interleaved-pair (GPT-J) RoPE on the channels from `nope_dim` onward
         of a copy of `kv` (`kv` itself is left untouched).
      2. Per-token fp8 quantization via `kv_quantize_fp8`.
      3. Paged insert of the fp8 rows via `paged_kv_write`.
      4. Store of each token's inverse scale at `swa_inv_scale[slot]`.

    Tokens whose slot is negative (padding) get no inverse-scale entry, so
    they cannot clobber the last slot through negative indexing.
    """
    if kv.shape[0] == 0:
        return

    # Apply GPT-J RoPE to KV. Each cos_sin_cache row is read as
    # rope_dim // 2 cos values followed by rope_dim // 2 sin values.
    half = rope_dim // 2
    cos = cos_sin_cache[positions, :half].to(kv.dtype)
    sin = cos_sin_cache[positions, half:2 * half].to(kv.dtype)

    # Read even/odd from the untouched input so both rotated halves are
    # computed from the original values.
    rope_src = kv[:, nope_dim:]
    even = rope_src[:, 0::2]
    odd = rope_src[:, 1::2]
    kv_rope = kv.clone()
    kv_rope[:, nope_dim:][:, 0::2] = even * cos - odd * sin
    kv_rope[:, nope_dim:][:, 1::2] = even * sin + odd * cos

    # Quantize to fp8
    kv_fp8, inv_scale = kv_quantize_fp8(kv_rope)

    # Write to paged cache
    paged_kv_write(kv_fp8, slot_mapping, swa_kv_cache, block_size)

    # Write inv_scale to the flat per-slot table in one indexed assignment
    # instead of a Python loop with one host sync (.item()) per token.
    keep = slot_mapping >= 0
    target_slots = slot_mapping[keep]
    scales = inv_scale[keep].to(swa_inv_scale.dtype)
    # Match the trailing shape of the table ((max_slots, 1) or (max_slots,)).
    swa_inv_scale[target_slots] = scales.reshape(-1, *swa_inv_scale.shape[1:])
def blackwell_attention_decode(
    q,  # (1, NH, HD) single decode query with RoPE
    positions,  # (1,) absolute position
    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 SWA cache (uint8)
    swa_inv_scale,  # (max_slots, 1) per-token inv scale
    slot_mapping,  # (1,) slot for the new token (already written)
    block_size,  # tokens per block
    scale,  # 1/sqrt(HD)
    window_size,  # 128
    swa_indices=None,  # (num_decode_tokens, window_size) pre-computed paged indices
    swa_lens=None,  # (num_decode_tokens,) number of valid indices per token
    decode_token_idx=0,  # which decode token this is
) -> torch.Tensor:
    """Decode attention over cached KV for a single query token.

    With both `swa_indices` and `swa_lens` given, the first
    `swa_lens[decode_token_idx]` entries of `swa_indices[decode_token_idx]`
    are used as flat cache slots. Otherwise the fallback reads slots
    `0 .. positions[0]` in order and keeps only the trailing `window_size`
    of them. Gathered rows are dequantized with their per-slot inverse scale
    and used as both keys and values for every head. `slot_mapping` is not
    read.

    Returns: (1, NH, HD) attention output.
    """
    num_heads = q.shape[1]
    head_dim = q.shape[2]
    device = q.device

    have_paged_indices = swa_indices is not None and swa_lens is not None
    if have_paged_indices:
        # Use pre-computed paged indices from vLLM
        valid_count = swa_lens[decode_token_idx].item()
        slots = swa_indices[decode_token_idx, :valid_count]
        raw_rows = swa_kv_cache[slots // block_size, slots % block_size]
        if swa_kv_cache.dtype == torch.uint8:
            # Byte storage: reinterpret as fp8 without converting values.
            raw_rows = raw_rows.view(torch.float8_e4m3fn)
        # Dequantize with per-token inverse scale
        kv_cached = kv_dequantize_fp8(raw_rows, swa_inv_scale[slots])
    else:
        # Fallback: sequential slot access
        last_pos = positions[0].item()
        slots = torch.arange(last_pos + 1, dtype=torch.int64, device=device)
        raw_rows = paged_kv_read(
            slots, swa_kv_cache, block_size, last_pos + 1, head_dim
        )
        kv_cached = kv_dequantize_fp8(raw_rows, swa_inv_scale[slots])
        # Keep only the sliding window that ends at the current position.
        first_kept = max(0, last_pos - window_size + 1)
        kv_cached = kv_cached[first_kept:]

    shared_kv = kv_cached.unsqueeze(0).expand(num_heads, -1, -1)
    attn = F.scaled_dot_product_attention(
        q.permute(1, 0, 2),
        shared_kv,
        shared_kv,
        is_causal=False,
        scale=scale,
    )
    return attn.permute(1, 0, 2)
def full_sdpa_attention(
    q: torch.Tensor,  # (T, NH, HD) with RoPE
    kv: torch.Tensor,  # (T, HD) KV latent
    scale: float,
) -> torch.Tensor:
    """Full causal self-attention for PREFILL only.

    DEPRECATED: Use causal_prefill_attention instead.
    Kept for backward compatibility with the existing vLLM patch.

    This is a pure alias: the arguments are forwarded unchanged and the
    result of `causal_prefill_attention` is returned as-is.
    """
    return causal_prefill_attention(q, kv, scale)
# ── CSA/HCA Decode Attention ─────────────────────────────────────────
def blackwell_csa_decode_attention(
    q,  # (num_decode_tokens, NH, HD) with RoPE
    positions,  # (num_decode_tokens,)
    swa_kv_cache,  # (num_blocks, block_size, D) fp8 SWA cache
    swa_inv_scale,  # (max_slots, 1) per-token inv scale
    swa_metadata,  # DeepseekSparseSWAMetadata
    flashmla_metadata,  # FlashMLASparseMetadata (for topk_indices)
    compressed_kv_cache,  # (num_blocks, block_size, D) compressed KV cache
    compress_ratio,  # 4 or 128
    scale,  # 1/sqrt(HD)
    window_size,  # 128
    nope_dim,  # 448
    rope_dim,  # 64
    cos_sin_cache,  # (max_pos, rope_dim)
    attn_sink,  # (NH,) sink weights
    max_model_len,  # max sequence length
) -> torch.Tensor:
    """CSA/HCA decode attention — currently SWA-only.

    Intended pipeline per decode token:
      1. Get topk_indices from the indexer (already computed)
      2. Gather compressed KV at topk positions
      3. Sparse attention on gathered KV
      4. SWA attention from paged cache
      5. Merge with sink weights

    Only step 4 is implemented: each token is run through
    `blackwell_attention_decode` without pre-computed paged indices, so that
    function's sequential-slot fallback is used. The sparse/compressed
    inputs and the sink weights are accepted but not consumed yet.

    Returns a (num_tokens, NH, HD) bfloat16 tensor.
    """
    num_tokens, NH, HD = q.shape
    device = q.device
    block_size = swa_metadata.block_size
    output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)

    # Read for the future sparse path (topk indices come from the indexer
    # buffer for C4A and from the metadata for C128A); not used below.
    num_decodes = swa_metadata.num_decodes
    is_valid = swa_metadata.is_valid_token[:num_tokens]

    # For now, fall back to SWA-only for CSA/HCA decode.
    # The sparse component will be added once we verify the SWA path works.
    for t in range(num_tokens):
        token = slice(t, t + 1)
        per_token = blackwell_attention_decode(
            q[token],
            positions[token],
            swa_kv_cache,
            swa_inv_scale,
            swa_metadata.slot_mapping[token],
            block_size,
            scale,
            window_size,
        )
        output[t] = per_token.squeeze(0)
    return output
def csa_sparse_prefill_attention(
    q,  # (num_prefills, NH, HD) with RoPE
    kv_rope,  # (num_prefills, HD) KV with RoPE
    compressed_kv_cache,  # compressed KV cache
    flashmla_metadata,  # FlashMLASparseMetadata
    swa_metadata,  # DeepseekSparseSWAMetadata
    compress_ratio,  # 4 or 128
    scale,  # 1/sqrt(HD)
    window_size,  # 128
    nope_dim,  # 448
    rope_dim,  # 64
    cos_sin_cache,  # (max_pos, rope_dim)
    attn_sink,  # (NH,) sink weights
    max_model_len,  # max sequence length
) -> torch.Tensor:
    """CSA/HCA prefill: sparse + SWA attention.

    For now, falls back to full causal attention (which is correct
    but not optimal for long sequences).

    Only `q`, `kv_rope` and `scale` are used; every other argument is
    accepted for interface compatibility and ignored by this fallback.
    """
    # Full causal attention is always correct for prefill
    return causal_prefill_attention(q, kv_rope, scale)

View File

@@ -1,455 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, cast
import torch
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
)
from vllm.platforms import current_platform
# Check at module load time if we're on Blackwell.
# `_IS_BLACKWELL` stays False unless the platform reports a device capability
# whose major version is >= 10.
_IS_BLACKWELL = False
try:
    _cap = current_platform.get_device_capability()
    if _cap is not None and _cap.major >= 10:
        _IS_BLACKWELL = True
except Exception:
    # Best-effort probe: any failure while querying the capability simply
    # leaves the flag at False.
    pass
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
MXFP4_BLOCK_SIZE,
)
from vllm.v1.kv_cache_interface import (
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
)
class CompressorBackend(AttentionBackend):
    """Attention backend descriptor for the compressor state cache.

    Only describes the cache layout and the metadata builder class; the
    cache holds a single KV head per token.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def get_name() -> str:
        return "CompressorBackend"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        # Any block size is accepted (every integer is a multiple of 1).
        return [MultipleOf(1)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [512, 1024]

    @staticmethod
    def get_builder_cls() -> type["CompressorMetadataBuilder"]:
        return CompressorMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # The head axis is dropped from the cache shape, so exactly one KV
        # head is required. `cache_dtype_str` does not affect the shape.
        assert num_kv_heads == 1
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # Identity permutation in both cases: dimensions keep their order.
        if include_num_layers_dimension:
            return (0, 1, 2, 3)
        return (0, 1, 2)
@dataclass
class CompressorMetadata:
    """Per-step metadata produced by `CompressorMetadataBuilder.build`."""

    # Block table taken from the common attention metadata (clamped to >= 0
    # by the builder).
    block_table: torch.Tensor
    # Slot mapping taken from the common attention metadata.
    slot_mapping: torch.Tensor
    # Block size of the KV cache spec the builder was created for.
    block_size: int
    # For each token, the index of the request it belongs to.
    token_to_req_indices: torch.Tensor | None = None  # [num_tokens]
class CompressorMetadataBuilder(AttentionMetadataBuilder):
    """Builds `CompressorMetadata` from the common attention metadata."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
        mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec)
        self.block_size = mla_spec.block_size
        # Persistent device buffer sized for the largest batch; build() fills
        # and returns a prefix slice of it instead of allocating per step.
        self.token_to_req_indices = torch.zeros(
            self.vllm_config.scheduler_config.max_num_batched_tokens,
            dtype=torch.int32,
            device=self.device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CompressorMetadata:
        """Assemble the metadata for one step.

        `common_prefix_len` and `fast_build` are accepted but not used here.
        """
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_reqs = common_attn_metadata.num_reqs
        # Per-request query length from consecutive cumulative start offsets.
        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # Request index repeated once per token of that request; pinned so the
        # copy into the device buffer below can be non-blocking.
        x = torch.repeat_interleave(torch.arange(num_reqs), query_lens).pin_memory()
        token_to_req_indices = self.token_to_req_indices[: x.shape[0]]
        token_to_req_indices.copy_(x, non_blocking=True)
        return CompressorMetadata(
            # NOTE: clamp_ is in-place, so the tensor held by
            # common_attn_metadata is modified as well.
            block_table=common_attn_metadata.block_table_tensor.clamp_(min=0),
            slot_mapping=common_attn_metadata.slot_mapping,
            block_size=self.block_size,
            token_to_req_indices=token_to_req_indices,
        )
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
    """Cache-only layer holding the compressor's partial states.

    It registers itself in the compilation config's static forward context
    under `prefix`, describes its cache as a `SlidingWindowMLASpec`, and has
    a no-op `forward`.
    """

    def __init__(
        self,
        state_dim: int,
        dtype: torch.dtype,
        compress_ratio: int,
        prefix: str,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.dtype = dtype
        self.prefix = prefix
        # Empty placeholder tensor for the cache.
        self.kv_cache = torch.tensor([])
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        assert self.dtype == torch.float32
        assert compress_ratio in [4, 128]
        # coff is 2 when compress_ratio == 4 and 1 otherwise, so the sliding
        # window is 8 for ratio 4 and 128 for ratio 128.
        coff = 1 + (compress_ratio == 4)
        self.sliding_window = coff * compress_ratio
        # Block size is constrained by tensor sharing between compressor states
        # and KV blocks. Since compressor states share the same physical tensor
        # as KV blocks, they must use the same page size.
        # The KV block shape [256//4, head_dim] = [64, 584] determines:
        # - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
        # - C128 compressor block shape [8, 512*2*4] -> block_size = 8
        # TODO(yifan): make block size automatically determined and configurable.
        if compress_ratio == 4:
            self.block_size = 4
        elif compress_ratio == 128:
            self.block_size = 8
        else:
            # Only reachable when the assert above is disabled (python -O).
            raise ValueError(f"Invalid compress ratio: {compress_ratio}")

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # `vllm_config` is not consulted; the spec is built from instance state.
        return SlidingWindowMLASpec(  # only has one vector instead of K + V
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.state_dim,
            dtype=self.dtype,
            sliding_window=self.sliding_window,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
        )

    # Intentionally empty: the layer does no computation of its own.
    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        return CompressorBackend
class DeepseekCompressor(nn.Module):
    """KV compressor module for DeepSeek sparse attention layers.

    `forward` stores per-token partial states (kv plus score-with-APE) into
    the compressor state cache and, except on Blackwell, launches a fused
    kernel that compresses, normalizes, applies RoPE, quantizes and writes
    the result into the KV cache registered under `k_cache_prefix`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        compress_ratio: int,
        hidden_size: int,
        head_dim: int,
        rotate: bool = False,
        prefix: str = "",
        k_cache_prefix="",
        use_fp4_cache: bool = False,
    ):
        super().__init__()
        self.compress_ratio = compress_ratio
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.rotate = rotate
        self.prefix = prefix
        self.k_cache_prefix = k_cache_prefix
        self.use_fp4_cache = use_fp4_cache
        config = vllm_config.model_config.hf_config
        self.rope_head_dim = config.qk_rope_head_dim
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.rms_norm_eps = config.rms_norm_eps
        self.device = current_platform.device_type
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len
        # Overlap is enabled only for compress_ratio == 4; coff is then 2
        # (otherwise 1) and widens the kv/score projections accordingly.
        self.overlap = compress_ratio == 4
        self.coff = 1 + self.overlap
        state_dtype = torch.float32
        # One row per position inside a compression group; forward() adds
        # row (position % compress_ratio) to the score.
        self.ape = nn.Parameter(
            torch.empty(
                (compress_ratio, self.coff * self.head_dim),
                dtype=state_dtype,
                device=self.device,
            ),
            requires_grad=False,
        )
        quant_config = vllm_config.quant_config
        # Single merged projection with two equally sized outputs
        # (presumably producing the kv/score pair that forward() splits —
        # the call site is outside this class; confirm against the caller).
        self.fused_wkv_wgate = MergedColumnParallelLinear(
            self.hidden_size,
            [self.coff * self.head_dim, self.coff * self.head_dim],
            bias=False,
            return_bias=False,
            quant_config=quant_config,
            disable_tp=True,
            prefix=f"{prefix}.fused_wkv_wgate",
        )
        self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)
        self.state_cache = CompressorStateCache(
            state_dim=2 * self.coff * self.head_dim,  # kv_state + score_state
            dtype=state_dtype,
            compress_ratio=compress_ratio,
            prefix=f"{prefix}.state_cache",
        )
        # Save reference to static_forward_context for forward-time KV cache lookup.
        # get_current_vllm_config() is only available during __init__, not forward.
        self._static_forward_context = (
            vllm_config.compilation_config.static_forward_context
        )
        # Pick the fused kernel and its quantization layout constants from
        # head_dim (512 or 128) and, for 128, from the fp4-cache flag.
        if self.head_dim == 512:
            assert not use_fp4_cache, (
                "MXFP4 cache is only supported for indexer (head=128)"
            )
            self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
            self._quant_block = 64
            self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
            self._scale_dim = self.nope_head_dim // 64 + 1  # 7 real + 1 pad
            self._num_warps = 4
        elif self.head_dim == 128:
            if use_fp4_cache:
                self._fused_kernel = (
                    _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
                )
                self._quant_block = MXFP4_BLOCK_SIZE
                self._token_stride = self.head_dim // 2
                self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
            else:
                self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
                self._quant_block = 128
                self._token_stride = self.head_dim
                self._scale_dim = 4  # single float32 scale
            self._num_warps = 1
        else:
            raise ValueError(
                f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
            )

    def forward(
        self,
        # [num_tokens, 2 * self.coff * self.head_dim]
        kv_score: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        rotary_emb,
    ) -> None:
        """Save partial states and, off Blackwell, run the fused cache write.

        Returns None in every case; all results are written into caches.
        Returns early without doing anything when the forward context's
        attention metadata is not a dict (dummy profiling run).
        """
        # Each of shape [num_tokens, coff * self.head_dim]
        # input bf16, output are fp32
        kv, score = kv_score.split(
            [self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
        )
        # Get the metadata and handle dummy profiling run.
        attn_metadata = get_forward_context().attn_metadata
        if not isinstance(attn_metadata, dict):
            return
        state_metadata = cast(
            CompressorMetadata, attn_metadata[self.state_cache.prefix]
        )
        token_to_req_indices = state_metadata.token_to_req_indices
        slot_mapping = state_metadata.slot_mapping
        # One kernel program is launched per slot_mapping entry.
        num_actual = slot_mapping.shape[0]
        block_table = state_metadata.block_table
        block_size = state_metadata.block_size
        # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
        state_cache = self.state_cache.kv_cache
        # kv_state stored in first half, score_state stored in second half
        state_width = state_cache.shape[-1] // 2
        # The launch_pdl keyword is only passed on non-ROCm platforms.
        pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False}
        # Store the KV and score (with fused APE addition) in the state.
        # NOTE: PDL is disabled — both this kernel and _fused_kernel below
        # depend on preceding kernel outputs (kv/score from the cublas GEMM;
        # state_cache from this kernel) but neither emits/waits on PDL grid
        # dependency primitives, so launch_pdl=True caused a read-after-write
        # race and non-deterministic output.
        _save_partial_states_kernel[(num_actual,)](
            kv,
            kv.stride(0),
            score,
            score.stride(0),
            self.ape,
            self.ape.stride(0),
            positions,
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            slot_mapping,
            block_size,
            HEAD_SIZE=kv.shape[-1],
            TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            **pdl_kwargs,
        )
        # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
        # RoPE requirements (kernel applies forward GPT-J style rotation):
        # - is_neox_style=False (interleaved pairs, NOT split-half)
        # - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
        #   second half sin (per-pair, length rope_head_dim // 2 each)
        # - applied to LAST rope_head_dim elements of head_dim
        # - position used: (positions // compress_ratio) * compress_ratio
        # On Blackwell (SM100+), skip the fused kernel because:
        # 1. The fused kernel does attention using FlashMLA which doesn't work on SM100
        # 2. Our Blackwell attention path handles everything separately
        # Instead, we just save the state (done above) and let the attention
        # path handle compression + RoPE + cache write + attention.
        if _IS_BLACKWELL:
            # Blackwell: state is already saved, skip fused kernel
            return
        cos_sin_cache = rotary_emb.cos_sin_cache
        # Metadata and cache of the target KV cache layer, both looked up by
        # `k_cache_prefix`.
        k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
        kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
        self._fused_kernel[(num_actual,)](
            # state cache
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            # metadata
            token_to_req_indices,
            positions,
            slot_mapping,
            block_table,
            block_table.stride(0),
            block_size,
            # RMSNorm
            self.norm.weight,
            self.rms_norm_eps,
            # RoPE
            cos_sin_cache,
            cos_sin_cache.stride(0),
            # KV cache
            kv_cache,
            k_cache_metadata.slot_mapping,
            kv_cache.shape[1],  # paged KV cache block size (tokens per block)
            # constexprs
            HEAD_SIZE=self.head_dim,
            TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            OVERLAP=self.overlap,
            ROPE_HEAD_DIM=self.rope_head_dim,
            FP8_MAX=448.0,
            QUANT_BLOCK=self._quant_block,
            TOKEN_STRIDE=self._token_stride,
            SCALE_DIM=self._scale_dim,
            KV_BLOCK_STRIDE=kv_cache.stride(0),
            num_warps=self._num_warps,
            **pdl_kwargs,
        )
# Triton kernel, one program per token: copies the token's kv row into the
# first STATE_WIDTH lanes of its state-cache slot and stores
# score + ape[position % COMPRESS_RATIO] into the lanes starting at
# STATE_WIDTH. Tokens with a negative slot id are skipped.
@triton.jit
def _save_partial_states_kernel(
    kv_ptr,
    kv_stride,
    score_ptr,
    score_stride,
    ape_ptr,
    ape_stride,
    positions_ptr,
    state_cache_ptr,
    state_cache_stride0,
    state_cache_stride1,
    slot_mapping_ptr,
    block_size,
    HEAD_SIZE: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
    # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
    STATE_WIDTH: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
):
    token_idx = tl.program_id(0)
    slot_id = tl.load(slot_mapping_ptr + token_idx)
    # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used
    # by vLLM). During CUDA graph replay the batch may contain padding
    # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an
    # illegal memory access.
    if slot_id < 0:
        return
    # Split the flat slot id into (block, position within block) and compute
    # the address of that cache row.
    block_idx = slot_id // block_size
    pos_in_block = slot_id % block_size
    base_ptr = (
        state_cache_ptr
        + block_idx * state_cache_stride0
        + pos_in_block * state_cache_stride1
    )
    # TRITON_BLOCK_SIZE lanes are launched; the mask disables those at or
    # beyond HEAD_SIZE.
    block = tl.arange(0, TRITON_BLOCK_SIZE)
    mask = block < HEAD_SIZE
    kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask)
    tl.store(base_ptr + block, kv, mask=mask)
    # Fused: score += ape[position % compress_ratio]
    position = tl.load(positions_ptr + token_idx)
    ape_row = position % COMPRESS_RATIO
    ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask)
    score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask)
    tl.store(
        base_ptr + STATE_WIDTH + block,
        score + ape,
        mask=mask,
    )

View File

@@ -1,195 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Patched MHC layer — pure PyTorch, no TileLang.
# Replaces the TileLang-based mhc.py from the vLLM nightly image.
# The original imports tilelang at the top level and JIT-compiles kernels
# which don't work correctly on Blackwell (SM100).
import torch
from vllm.utils.torch_utils import direct_register_custom_op
# ── Pure PyTorch MHC implementations ──────────────────────────────────
def mhc_pre(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure PyTorch MHC "pre" step.

    `residual` is (..., hc_mult, hidden_size) in bfloat16; `fn`, `hc_scale`
    and `hc_base` must be float32. The flattened residual is projected by
    `fn`, scaled by its inverse RMS, and split into three groups of mixing
    logits (pre, post, combine). The combine weights are normalized by
    alternating column/row normalization.

    Returns:
        post_mix:    (..., hc_mult, 1)
        comb_mix:    (..., hc_mult, hc_mult)
        layer_input: (..., hidden_size) bfloat16, the pre-mix weighted sum
                     of the residual streams.

    `n_splits` is accepted but not used.
    """
    assert residual.dtype == torch.bfloat16
    assert fn.dtype == torch.float32
    assert hc_scale.dtype == torch.float32
    assert hc_base.dtype == torch.float32

    outer_shape = residual.shape[:-2]
    hc_mult, hidden_size = residual.shape[-2], residual.shape[-1]
    flat_width = hc_mult * hidden_size

    streams = residual.view(-1, hc_mult, hidden_size)
    num_tokens = streams.shape[0]
    x = streams.view(num_tokens, flat_width).to(torch.float32)

    # Project, then scale by the inverse RMS of the flattened residual.
    projected = torch.matmul(x, fn.t())
    inv_rms = torch.rsqrt(x.square().sum(dim=-1, keepdim=True) / flat_width + rms_eps)
    mixes = projected * inv_rms

    # Column layout of `mixes`: [pre | post | combine (hc_mult x hc_mult)].
    pre_end = hc_mult
    post_end = 2 * hc_mult

    pre_logits = mixes[:, :pre_end] * hc_scale[0] + hc_base[:pre_end]
    pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps

    post_logits = mixes[:, pre_end:post_end] * hc_scale[1] + hc_base[pre_end:post_end]
    post_mix = torch.sigmoid(post_logits) * hc_post_mult_value

    comb_raw = mixes[:, post_end:].view(num_tokens, hc_mult, hc_mult)
    comb_bias = hc_base[post_end:].view(1, hc_mult, hc_mult)
    comb_logits = comb_raw * hc_scale[2] + comb_bias

    # Softmax over rows followed by one column normalization counts as the
    # first iteration; the remaining iterations alternate row / column.
    comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps
    comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
    for _ in range(sinkhorn_repeat - 1):
        comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
        comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)

    weighted = pre_mix.unsqueeze(-1) * streams.to(torch.float32)
    layer_input = torch.sum(weighted, dim=1).to(torch.bfloat16)

    return (
        post_mix.view(*outer_shape, hc_mult, 1),
        comb_mix.view(*outer_shape, hc_mult, hc_mult),
        layer_input.view(*outer_shape, hidden_size),
    )
def _mhc_pre_fake(
residual, fn, hc_scale, hc_base, rms_eps, hc_pre_eps,
hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits=1,
):
hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
outer_shape = residual.shape[:-2]
return (
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
)
def mhc_post(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Pure PyTorch MHC "post" step.

    Output stream j is the sum over input streams i of
    `comb_res_mix[..., i, j] * residual[..., i, :]`, plus
    `post_layer_mix * x` broadcast over the stream axis. All math runs in
    float32; the result is cast back to `residual.dtype`.
    """
    comb_f32 = comb_res_mix.to(torch.float32)
    residual_f32 = residual.to(torch.float32)
    mixed_residual = torch.einsum("...ij,...ih->...jh", comb_f32, residual_f32)

    layer_out = x.unsqueeze(-2).to(torch.float32)  # add the stream axis
    post_term = post_layer_mix.to(torch.float32) * layer_out

    return (mixed_residual + post_term).to(residual.dtype)
def _mhc_post_fake(x, residual, post_layer_mix, comb_res_mix):
    # Shape/dtype-only stand-in for `mhc_post`: the real op returns a tensor
    # with the shape and dtype of `residual`, so an uninitialized tensor like
    # `residual` is enough here.
    return torch.empty_like(residual)
def mhc_fused_post_pre(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
    tile_n: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run `mhc_post` and feed its result straight into `mhc_pre`.

    Returns `(new_residual, post_mix, res_mix, layer_input)`, i.e. the
    `mhc_post` output followed by the three `mhc_pre` outputs computed from
    it. `tile_n` is accepted but not used.
    """
    new_residual = mhc_post(x, residual, post_layer_mix, comb_res_mix)
    pre_outputs = mhc_pre(
        new_residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits,
    )
    post_mix, res_mix, layer_input = pre_outputs
    return new_residual, post_mix, res_mix, layer_input
def _mhc_fused_post_pre_fake(
x, residual, post_layer_mix, comb_res_mix, fn, hc_scale, hc_base,
rms_eps, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value,
sinkhorn_repeat, n_splits=1, tile_n=1,
):
hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
outer_shape = residual.shape[:-2]
return (
torch.empty_like(residual),
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
)
def _hc_head_fused_kernel(
hs_flat: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
out: torch.Tensor,
hidden_size: int,
rms_eps: float,
hc_eps: float,
hc_mult: int,
) -> None:
if hs_flat.shape[0] == 0:
return
x_flat = hs_flat.reshape(hs_flat.shape[0], hc_mult * hidden_size).to(torch.float32)
mixes = torch.matmul(x_flat, fn.t())
sqrsum = x_flat.square().sum(dim=-1, keepdim=True)
rsqrt = torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps)
pre_mix = torch.sigmoid(mixes * rsqrt * hc_scale[0] + hc_base) + hc_eps
result = torch.sum(pre_mix.unsqueeze(-1) * hs_flat.to(torch.float32), dim=1).to(out.dtype)
out.copy_(result)
# ── Register as torch custom ops ──────────────────────────────────────
# Each call registers one of the pure-PyTorch functions above under `op_name`.
# The three functional ops declare no mutated arguments and supply a fake
# implementation that only produces output shapes/dtypes.
direct_register_custom_op(
    op_name="mhc_pre",
    op_func=mhc_pre,
    mutates_args=[],
    fake_impl=_mhc_pre_fake,
)
direct_register_custom_op(
    op_name="mhc_post",
    op_func=mhc_post,
    mutates_args=[],
    fake_impl=_mhc_post_fake,
)
direct_register_custom_op(
    op_name="mhc_fused_post_pre",
    op_func=mhc_fused_post_pre,
    mutates_args=[],
    fake_impl=_mhc_fused_post_pre_fake,
)
# This op writes its result into `out` and returns None, so `out` is declared
# as mutated and no fake implementation is passed.
direct_register_custom_op(
    op_name="hc_head_fused_kernel",
    op_func=_hc_head_fused_kernel,
    mutates_args=["out"],
)

View File

@@ -1,48 +0,0 @@
#!/usr/bin/env python3
"""Patch DeepseekCompressor cache for Blackwell: remove FlashMLA alignment."""
import sys
def patch(path):
    """Rewrite the file at `path` so the compressor cache spec drops the
    FlashMLA alignment on Blackwell.

    Does nothing if the file already contains the CLAWMINE_PATCH_COMPRESSOR
    marker; exits with status 1 if the expected snippet is not found.
    """
    with open(path, 'r') as src:
        text = src.read()

    # Idempotence guard: the marker is part of the replacement text.
    if "CLAWMINE_PATCH_COMPRESSOR" in text:
        print("Already patched, skipping")
        return

    old = """ return SlidingWindowMLASpec( # only has one vector instead of K + V
block_size=self.block_size,
num_kv_heads=1,
head_size=self.state_dim,
dtype=self.dtype,
sliding_window=self.sliding_window,
alignment=576, # NOTE: FlashMLA requires 576B alignment
)"""
    new = """ # CLAWMINE_PATCH_COMPRESSOR: No FlashMLA alignment on Blackwell
from vllm.platforms import current_platform
_is_blackwell = (
current_platform.get_device_capability() is not None
and current_platform.get_device_capability().major >= 10
)
return SlidingWindowMLASpec( # only has one vector instead of K + V
block_size=self.block_size,
num_kv_heads=1,
head_size=self.state_dim,
dtype=self.dtype,
sliding_window=self.sliding_window,
alignment=None if _is_blackwell else 576,
)"""

    if old not in text:
        print("ERROR: Could not find the code to patch in " + path)
        sys.exit(1)

    patched = text.replace(old, new)
    with open(path, 'w') as dst:
        dst.write(patched)
    print("Patched DeepseekCompressor for Blackwell")
# Script entry point: the file to patch is the first command-line argument.
if __name__ == "__main__":
    patch(sys.argv[1])

View File

@@ -1,39 +0,0 @@
#!/usr/bin/env python3
"""Patch _allocate_kv_cache_tensors to print the layer name mismatch."""
import sys
def patch(path):
with open(path, 'r') as f:
content = f.read()
if "CLAWMINE_DEBUG_LAYERS" in content:
print("Already patched, skipping")
return
old = """ assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)"""
new = """ # CLAWMINE_DEBUG_LAYERS: print mismatch instead of asserting
missing = layer_names - set(kv_cache_raw_tensors.keys())
extra = set(kv_cache_raw_tensors.keys()) - layer_names
if missing or extra:
print(f"CLAWMINE DEBUG: missing layers ({len(missing)}): {sorted(missing)[:20]}")
print(f"CLAWMINE DEBUG: extra layers ({len(extra)}): {sorted(extra)[:20]}")
print(f"CLAWMINE DEBUG: expected ({len(layer_names)}), got ({len(kv_cache_raw_tensors.keys())})")
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)"""
if old not in content:
print("ERROR: Could not find the code to patch")
sys.exit(1)
content = content.replace(old, new)
with open(path, 'w') as f:
f.write(content)
print("Patched gpu_model_runner.py for debug layer names")
if __name__ == "__main__":
patch(sys.argv[1])

View File

@@ -1,58 +0,0 @@
#!/usr/bin/env python3
"""Patch DeepseekV4IndexerCache on Blackwell: remove FlashMLA alignment.
Same as patch_swa_cache but for the indexer cache class.
"""
import sys
def patch(path):
    """Make the indexer cache's ``MLAAttentionSpec`` alignment capability-dependent.

    The file at *path* is searched for the exact ``old`` snippet (a
    ``get_kv_cache_spec`` method returning ``MLAAttentionSpec`` with
    ``alignment=576``). It is replaced by ``new``, which derives
    ``_is_blackwell`` from ``current_platform.get_device_capability()``
    (major >= 10) and passes ``alignment=None`` in that case, 576 otherwise.

    The file is read and written as UTF-8 so the outcome does not depend on
    the locale of the build environment, and the failure message goes to
    stderr.

    Args:
        path: Path of the source file to rewrite in place.

    Raises:
        SystemExit: With status 1 when ``old`` is not present in the file.
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    # `new` carries this marker, so finding it means an earlier run already
    # rewrote the file.
    if "CLAWMINE_PATCH_INDEXER_CACHE" in content:
        print("Already patched, skipping")
        return

    # Patch the indexer cache's get_kv_cache_spec to remove FlashMLA alignment on Blackwell.
    # NOTE(review): matching is an exact substring test, so the whitespace of
    # these snippets must mirror the target file byte-for-byte — confirm
    # against the installed vLLM source.
    old = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# head_dim already carries the fp8 scale padding
# compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout.
return MLAAttentionSpec(
block_size=self.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
compress_ratio=self.compress_ratio,
# DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with
# the indexer's compressor state cache. V3.2 keeps the legacy layout.
alignment=576,
)"""
    new = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# CLAWMINE_PATCH_INDEXER_CACHE: No FlashMLA alignment on Blackwell
from vllm.platforms import current_platform
_is_blackwell = (
current_platform.get_device_capability() is not None
and current_platform.get_device_capability().major >= 10
)
return MLAAttentionSpec(
block_size=self.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
compress_ratio=self.compress_ratio,
alignment=None if _is_blackwell else 576,
)"""

    if old not in content:
        # Nothing has been written yet, so the target file is left untouched.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("Patched DeepseekV4IndexerCache for Blackwell")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Give a usage hint instead of an IndexError traceback.
        print(f"usage: {sys.argv[0]} PATH", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])

View File

@@ -1,135 +0,0 @@
#!/usr/bin/env python3
"""Patch vLLM kv_cache_utils.py to handle DeepseekV4 SWA page sizes.
DeepseekV4 has three cache types:
- C128A (HCA): compress_ratio=128, very small page size
- C4A (CSA): compress_ratio=4, medium page size
- SWA: compress_ratio=1, large page size
The upstream code assumes SWA page sizes <= MLA page sizes and pads
SWA pages to match MLA. This breaks when SWA pages are LARGER than
MLA pages (which is always the case for DeepseekV4).
Our fix: when SWA pages exceed MLA pages, put them in their own
separate cache group without padding.
"""
import sys
def patch(path):
    """Replace the page-size padding block with one that tolerates larger pages.

    The file at *path* is searched for the exact ``old`` snippet, which
    asserts ``max(sm_page_sizes) <= max(all_page_sizes)`` and then pads and
    groups layers. It is replaced by ``new``, which computes the same
    comparison as ``can_pad``: when it holds, the original padding and
    grouping statements run; otherwise every entry of
    ``sm_spec.kv_cache_specs`` is appended as its own single-layer
    ``KVCacheGroupSpec``.

    The file is read and written as UTF-8 so the outcome does not depend on
    the locale of the build environment, and the failure message goes to
    stderr.

    Args:
        path: Path of the source file to rewrite in place.

    Raises:
        SystemExit: With status 1 when ``old`` is not present in the file.
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    # `new` carries this marker, so finding it means an earlier run already
    # rewrote the file.
    if "CLAWMINE_PATCH_KV_CACHE" in content:
        print("Already patched, skipping")
        return

    # The old code: asserts SWA pages <= MLA pages, then pads SWA to MLA.
    # NOTE(review): matching is an exact substring test, so the whitespace of
    # these snippets must mirror the target file byte-for-byte — confirm
    # against the installed vLLM source.
    old = """ assert max(sm_page_sizes) <= max(all_page_sizes)
# Unify page size by padding layers' page_size to the nearest larger page_size.
# Compute candidate (nearest larger page_size) for each unique page size.
size_to_candidate: dict[int, int] = {}
for ps in sm_page_sizes:
size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
# Pad and collect layer names per page size.
for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
current_size = layer_spec.page_size_bytes
candidate = size_to_candidate[current_size]
if current_size < candidate:
object.__setattr__(layer_spec, "page_size_padded", candidate)
layers_per_size[candidate].append(layer_name)
# NOTE(yifan): for now, inside a UniformKV group, each page_size should
# have the same number of layers. This also means we don't need to pad layers
# inside a partial-full layer tuple.
assert len(set(len(layers) for layers in layers_per_size.values())) == 1
num_layers_per_size = len(next(iter(layers_per_size.values())))
# Split layers inside each UniformKV group for aligned #(layers).
# See `_get_kv_cache_groups_uniform_page_size` for more details.
num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
layer_tuples = list(zip(*layers_per_size.values()))
for i in range(num_tuple_groups):
group_layer_tuples = layer_tuples[i::num_tuple_groups]
# Flatten tuples and build dict for from_specs
group_layer_names = [
name for layer_tuple in group_layer_tuples for name in layer_tuple
]
group_layer_specs = {
name: sm_spec.kv_cache_specs[name] for name in group_layer_names
}
sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
assert sub_sm_spec is not None
swa_mla_groups.append(
KVCacheGroupSpec(
layer_names=group_layer_names,
kv_cache_spec=sub_sm_spec,
)
)"""
    # The new code: handle both cases.
    new = """ # CLAWMINE_PATCH_KV_CACHE: Handle DeepseekV4 where SWA page sizes
# can be larger than MLA page sizes. Two cases:
# 1. All SWA pages <= some MLA page: original padding logic
# 2. Some SWA pages > all MLA pages: separate cache group, no padding
max_mla_page = max(all_page_sizes)
can_pad = max(sm_page_sizes) <= max_mla_page
if can_pad:
# Original logic: pad SWA pages to nearest MLA page
size_to_candidate: dict[int, int] = {}
for ps in sm_page_sizes:
size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
current_size = layer_spec.page_size_bytes
candidate = size_to_candidate[current_size]
if current_size < candidate:
object.__setattr__(layer_spec, "page_size_padded", candidate)
layers_per_size[candidate].append(layer_name)
assert len(set(len(layers) for layers in layers_per_size.values())) == 1
num_layers_per_size = len(next(iter(layers_per_size.values())))
num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
layer_tuples = list(zip(*layers_per_size.values()))
for i in range(num_tuple_groups):
group_layer_tuples = layer_tuples[i::num_tuple_groups]
group_layer_names = [
name for layer_tuple in group_layer_tuples for name in layer_tuple
]
group_layer_specs = {
name: sm_spec.kv_cache_specs[name] for name in group_layer_names
}
sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
assert sub_sm_spec is not None
swa_mla_groups.append(
KVCacheGroupSpec(
layer_names=group_layer_names,
kv_cache_spec=sub_sm_spec,
)
)
else:
# SWA pages are larger than MLA pages.
# Put each SWA layer in its own cache group (no padding needed).
# This is the DeepseekV4 Blackwell case where compress_ratio=1
# layers have much larger pages than compressed layers.
for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
group_layer_specs = {layer_name: layer_spec}
sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
if sub_sm_spec is not None:
swa_mla_groups.append(
KVCacheGroupSpec(
layer_names=[layer_name],
kv_cache_spec=sub_sm_spec,
)
)"""

    if old not in content:
        # Nothing has been written yet, so the target file is left untouched.
        print("ERROR: Could not find the code to patch", file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("Patched kv_cache_utils.py for DeepseekV4 SWA page sizes")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Give a usage hint instead of an IndexError traceback.
        print(f"usage: {sys.argv[0]} PATH", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])

View File

@@ -1,71 +0,0 @@
#!/usr/bin/env python3
"""Patch DeepseekV4SWACache on Blackwell: remove FlashMLA alignment and model_version.
On Blackwell (SM100+), FlashMLA doesn't work. We use our own CSA/SDPA attention.
The SWA cache should use standard fp8 format (not fp8_ds_mla) and no FlashMLA alignment.
"""
import sys
def patch(path):
    """Add a capability-dependent branch to the SWA cache's ``get_kv_cache_spec``.

    The file at *path* is searched for the exact ``old`` snippet (a
    ``get_kv_cache_spec`` method returning ``SlidingWindowMLASpec`` with
    ``alignment=576`` and ``model_version="deepseek_v4"``). It is replaced by
    ``new``, which derives ``_is_blackwell`` from
    ``current_platform.get_device_capability()`` (major >= 10). In that case
    it returns the spec with ``alignment=None`` and ``model_version=None``;
    otherwise it returns the original spec unchanged.

    The file is read and written as UTF-8 so the outcome does not depend on
    the locale of the build environment, and the failure message goes to
    stderr.

    Args:
        path: Path of the source file to rewrite in place.

    Raises:
        SystemExit: With status 1 when ``old`` is not present in the file.
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    # `new` carries this marker, so finding it means an earlier run already
    # rewrote the file.
    if "CLAWMINE_PATCH_SWA_CACHE" in content:
        print("Already patched, skipping")
        return

    # Patch the get_kv_cache_spec method to remove FlashMLA-specific values on Blackwell.
    # NOTE(review): matching is an exact substring test, so the whitespace of
    # these snippets must mirror the target file byte-for-byte — confirm
    # against the installed vLLM source.
    old = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return SlidingWindowMLASpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
sliding_window=self.window_size,
cache_dtype_str=self.cache_config.cache_dtype,
alignment=576, # NOTE: FlashMLA requires 576B alignment
model_version="deepseek_v4",
)"""
    new = """ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# CLAWMINE_PATCH_SWA_CACHE: On Blackwell, no FlashMLA = no 576B alignment
# Use standard fp8 format (not fp8_ds_mla), no model_version override
from vllm.platforms import current_platform
_is_blackwell = (
current_platform.get_device_capability() is not None
and current_platform.get_device_capability().major >= 10
)
if _is_blackwell:
return SlidingWindowMLASpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
sliding_window=self.window_size,
cache_dtype_str=self.cache_config.cache_dtype,
alignment=None, # No FlashMLA alignment on Blackwell
model_version=None, # Don't use 584B deepseek_v4 format
)
return SlidingWindowMLASpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
sliding_window=self.window_size,
cache_dtype_str=self.cache_config.cache_dtype,
alignment=576, # NOTE: FlashMLA requires 576B alignment
model_version="deepseek_v4",
)"""

    if old not in content:
        # Nothing has been written yet, so the target file is left untouched.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("Patched DeepseekV4SWACache for Blackwell")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Give a usage hint instead of an IndexError traceback.
        print(f"usage: {sys.argv[0]} PATH", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env python3
# Patch vLLM's linear kernel __init__.py to register the CuTeDSL NVFP4 kernel.
# This inserts our kernel at the TOP of the _POSSIBLE_NVFP4_KERNELS list,
# so it gets selected first on Blackwell GPUs.
import sys
def patch_init(path):
    """Register ``CuTeDSLNvFp4LinearKernel`` in the kernel ``__init__.py`` at *path*.

    Three textual edits are prepared in memory and written only if the
    mandatory ones succeed:

    1. The cutedsl import is inserted at the first blank line that follows
       the marlin import marker (mandatory).
    2. ``CuTeDSLNvFp4LinearKernel`` is placed ahead of
       ``FlashInferCutlassNvFp4LinearKernel`` in the ``PlatformEnum.CUDA``
       list (mandatory).
    3. A ``"cutedsl"`` entry is added after the ``"emulation"`` entry
       (optional; a warning is printed when the anchor is missing).

    Edits 1 and 2 must both apply. Writing the import alone would leave the
    kernel unregistered, and because the skip check only looks for the name
    ``CuTeDSLNvFp4LinearKernel``, every later run would then report
    "already registered".

    The file is read and written as UTF-8 so the outcome does not depend on
    the locale of the build environment; errors and warnings go to stderr.

    Args:
        path: Path of the ``__init__.py`` to rewrite in place.

    Raises:
        SystemExit: With status 1 when the marlin import marker, the blank
            line after it, or the CUDA kernel-list anchor cannot be found.
            The file is left untouched in all three cases.
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    if "CuTeDSLNvFp4LinearKernel" in content:
        print("CuTeDSL kernel already registered, skipping")
        return

    # Import to add after the existing marlin import block.
    import_line = (
        "from vllm.model_executor.kernels.linear.nvfp4.cutedsl import (\n"
        " CuTeDSLNvFp4LinearKernel,\n"
        ")\n"
    )
    marker = "from vllm.model_executor.kernels.linear.nvfp4.marlin import ("
    idx = content.find(marker)
    if idx == -1:
        print("ERROR: Could not find marlin import marker", file=sys.stderr)
        sys.exit(1)

    # The first blank line after the marker is taken as the end of the marlin
    # import block. str.find returns -1 when there is none; using that as a
    # slice index would splice the import in front of the file's last
    # character, so it is rejected explicitly.
    end = content.find("\n\n", idx)
    if end == -1:
        print("ERROR: Could not find end of marlin import block", file=sys.stderr)
        sys.exit(1)
    content = content[:end] + "\n" + import_line + content[end:]

    # Insert CuTeDSLNvFp4LinearKernel at TOP of _POSSIBLE_NVFP4_KERNELS CUDA list.
    old = " PlatformEnum.CUDA: [\n FlashInferCutlassNvFp4LinearKernel,"
    new = " PlatformEnum.CUDA: [\n CuTeDSLNvFp4LinearKernel,\n FlashInferCutlassNvFp4LinearKernel,"
    if old not in content:
        print(
            "ERROR: Could not find the CUDA entry of _POSSIBLE_NVFP4_KERNELS",
            file=sys.stderr,
        )
        sys.exit(1)
    content = content.replace(old, new)

    # Also add to _NVFP4_BACKEND_TO_KERNEL so VLLM_NVFP4_GEMM_BACKEND=cutedsl works.
    # This stays best-effort, but a missing anchor is now reported.
    old_backend = ' "emulation": EmulationNvFp4LinearKernel,\n}'
    new_backend = ' "emulation": EmulationNvFp4LinearKernel,\n "cutedsl": CuTeDSLNvFp4LinearKernel,\n}'
    if old_backend in content:
        content = content.replace(old_backend, new_backend)
    else:
        print(
            'WARNING: Could not find the "emulation" backend entry; '
            '"cutedsl" backend name not registered',
            file=sys.stderr,
        )

    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("Patched CuTeDSL NVFP4 kernel into", path)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Give a usage hint instead of an IndexError traceback.
        print(f"usage: {sys.argv[0]} PATH", file=sys.stderr)
        sys.exit(2)
    patch_init(sys.argv[1])