From c043a11bccd1083cd154c229d78c5d570d7710f2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 00:44:44 +0000 Subject: [PATCH] Register CuTeDSL as proper NvFp4LinearKernel for NVFP4 linear layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create CuTeDSLNvFp4LinearKernel extending NvFp4LinearKernel base class - Register it via init_nvfp4_linear_kernel() selection mechanism (inserted at top of _POSSIBLE_NVFP4_KERNELS, before FlashInfer) - process_weights_after_loading: uint8→FP4, permute, create CuTeDSL runner - apply_weights: route through CuTeDSL GEMM - Update Dockerfile: copy kernel + registration script - Fix attention: always use forward() for quantized compressor/indexer layers (dtype check was fragile after kernel swaps weights to dummy BF16) --- Dockerfile | 9 ++ vllm/kernels/linear/nvfp4/cutedsl.py | 149 ++++++++++++++++++++++++ vllm/patches/deepseek_v4_attention.py | 40 ++----- vllm/patches/register_cutedsl_kernel.py | 41 +++++++ 4 files changed, 211 insertions(+), 28 deletions(-) create mode 100644 vllm/kernels/linear/nvfp4/cutedsl.py create mode 100644 vllm/patches/register_cutedsl_kernel.py diff --git a/Dockerfile b/Dockerfile index 69742eb6..9529c401 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,6 +39,15 @@ COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py +# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel) +ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4 +COPY vllm/kernels/linear/nvfp4/cutedsl.py ${VLLM_NVFP4_DIR}/cutedsl.py + +# Register CuTeDSL kernel in vLLM's linear kernel selection +ARG VLLM_LINEAR_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear +COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py +RUN python3 /tmp/register_cutedsl_kernel.py ${VLLM_LINEAR_DIR}/__init__.py && rm /tmp/register_cutedsl_kernel.py + # Config patches (add cutedsl to MoEBackend) ARG VLLM_CONFIG_DIR=/usr/local/lib/python3.12/dist-packages/vllm/config COPY vllm/patches/kernel.py ${VLLM_CONFIG_DIR}/kernel.py diff --git a/vllm/kernels/linear/nvfp4/cutedsl.py b/vllm/kernels/linear/nvfp4/cutedsl.py new file mode 100644 index 00000000..1dbf72b4 --- /dev/null +++ b/vllm/kernels/linear/nvfp4/cutedsl.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CuTeDSL NVFP4 Linear Kernel for vLLM. + +Registers as an NvFp4LinearKernel so that vLLM's kernel selection +mechanism (init_nvfp4_linear_kernel) picks it up on Blackwell GPUs. +Routes NVFP4 GEMM through the CuTeDSL framework, which uses MLIR-compiled +grouped GEMM kernels with Blackwell-specific TMA + wgmma instructions. + +CUDA-graph-compatible: all intermediate buffers are pre-allocated, +no CPU-GPU syncs, no dynamic shapes. +""" + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig + +logger = init_logger(__name__) + + +class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): + """NVFP4 GEMM via the CuTeDSL framework (Blackwell SM100+). + + Uses CuTeDSL's ScaledGroupedGemmKernel with num_groups=1 for + single linear layers. Weight processing: + - uint8 packed FP4 → float4_e2m1fn_x2, permuted to (K, N) + - FP8 block scales permuted to (K_sf, N) + - Global scale stored as float32 + + Activation quantization is done internally (NVFP4 W4A4). + """ + + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + cap = compute_capability or current_platform.get_device_capability() + if cap is not None and cap.major >= 10: + return True, None + return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)" + + @classmethod + def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Convert NVFP4 weights into CuTeDSL kernel format. + + Reads the layer's weight (uint8), weight_scale (fp8), and + weight_global_scale (float32) — all set up by + ModelOptNvFp4LinearMethod.process_weights_before our call. + Creates a CuTeDSLNvfp4Linear runner and stores it on the layer. + """ + from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear + + w_uint8 = layer.weight.data # (out, in//2) uint8 packed E2M1 + device = w_uint8.device + out_features = w_uint8.shape[0] + in_features = w_uint8.shape[1] * 2 # 2 FP4 values per uint8 + + # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N) + w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + + # Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL + sf = layer.weight_scale.data + if sf.dtype != torch.float8_e4m3fn: + sf = sf.to(torch.float8_e4m3fn) + sf = sf.permute(1, 0).contiguous() + + # Global scale (set by ModelOptNvFp4LinearMethod.process_weights_after_loading) + gs = layer.weight_global_scale.data.item() + + # Handle fused projections (MergedColumnParallelLinear with dual gs). + # When weight_global_scale has 2 elements (e.g. fused_wqa_wkv), + # normalize to max(gs1, gs2) and fold ratio into block scales. + if layer.weight_global_scale.numel() == 2: + gs0 = layer.weight_global_scale[0].item() + gs1 = layer.weight_global_scale[1].item() + gs = max(gs0, gs1) + if gs0 != gs1: + sf_f32 = sf.float() + logical_widths = getattr(layer, 'logical_widths', None) + if logical_widths is not None and len(logical_widths) == 2: + split_point = logical_widths[0] + else: + split_point = out_features // 2 + sf_f32[:, :split_point] *= (gs0 / gs) + sf_f32[:, split_point:] *= (gs1 / gs) + sf = sf_f32.to(torch.float8_e4m3fn) + + # Create CuTeDSL runner + runner = CuTeDSLNvfp4Linear( + in_features=in_features, + out_features=out_features, + device=str(device), + ) + runner.fp4 = [w_fp4] + runner.sf = [sf] + runner.gs = [gs] + runner.finalize_weights() + + # Compute activation global scale from input_global_scale_inv. + # ModelOptNvFp4LinearMethod sets: + # input_global_scale = input_scale.max() = amax/448 (small) + # input_global_scale_inv = 1/input_global_scale = 448/amax (large) + # Our quantize_activation_nvfp4(x, global_scale) normalizes: + # x_norm = x / global_scale + # So global_scale = amax/448 = input_global_scale = 1/inv. + if hasattr(layer, 'input_global_scale_inv') and layer.input_global_scale_inv is not None: + inv = layer.input_global_scale_inv.data.item() + if inv != 0: + runner._activation_global_scale = 1.0 / inv + + # Store runner on the layer + layer._cutedsl_runner = runner + + # Replace weight with dummy BF16 (vLLM module introspection may need it) + layer.weight = torch.nn.Parameter( + torch.zeros(out_features, in_features, dtype=torch.bfloat16, + device=device), + requires_grad=False, + ) + + # Clean up NVFP4 params that are now in the runner. + # Keep output_size_per_partition, logical_widths, input_size_per_partition + # which may be referenced by the layer's forward path. + for attr in ("weight_scale", "weight_global_scale", + "input_global_scale", "input_global_scale_inv", + "alpha", "weights_padding_cols", "weight_scale_2", + "input_scale"): + if hasattr(layer, attr): + try: + delattr(layer, attr) + except Exception: + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + result = layer._cutedsl_runner(x) + if bias is not None: + result = result + bias + return result diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index 53fdcb07..b27e248d 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -366,23 +366,14 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): compressor = self.compressor def compressor_kv_score() -> torch.Tensor: - # For NVFP4-quantized weights, we can't do a raw torch.mm - # with packed uint8 weights. Use the layer's forward() - # which handles dequantization properly. - wkv_wgate_weight = compressor.fused_wkv_wgate.weight - if wkv_wgate_weight.dtype == torch.uint8: - # NVFP4 packed weights — use forward() for dequant+matmul - result = compressor.fused_wkv_wgate(hidden_states) - # MergedColumnParallelLinear may return (output, bias) or - # just output depending on quantization method. - if isinstance(result, tuple): - result = result[0] - return result.to(torch.float32) - return torch.mm( - hidden_states, - wkv_wgate_weight.T, - out_dtype=torch.float32, - ) + # Use forward() for quantized layers (NVFP4, FP8, etc.) + # — raw torch.mm doesn't work with packed/dequantized weights. + # MergedColumnParallelLinear with return_bias=False returns + # a tensor directly. + result = compressor.fused_wkv_wgate(hidden_states) + if isinstance(result, tuple): + result = result[0] + return result.to(torch.float32) aux_fns[0] = compressor_kv_score @@ -395,17 +386,10 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): return weights def indexer_compressor_kv_score() -> torch.Tensor: - wkv_wgate_weight = indexer.compressor.fused_wkv_wgate.weight - if wkv_wgate_weight.dtype == torch.uint8: - result = indexer.compressor.fused_wkv_wgate(hidden_states) - if isinstance(result, tuple): - result = result[0] - return result.to(torch.float32) - return torch.mm( - hidden_states, - wkv_wgate_weight.T, - out_dtype=torch.float32, - ) + result = indexer.compressor.fused_wkv_wgate(hidden_states) + if isinstance(result, tuple): + result = result[0] + return result.to(torch.float32) aux_fns[1] = indexer_weights_proj aux_fns[2] = indexer_compressor_kv_score diff --git a/vllm/patches/register_cutedsl_kernel.py b/vllm/patches/register_cutedsl_kernel.py new file mode 100644 index 00000000..8951949a --- /dev/null +++ b/vllm/patches/register_cutedsl_kernel.py @@ -0,0 +1,41 @@ +#!/usr/bin +# Patch vLLM's linear kernel __init__.py to register the CuTeDSL NVFP4 kernel. +# This inserts our kernel at the TOP of the _POSSIBLE_NVFP4_KERNELS list, +# so it gets selected first on Blackwell GPUs. + +import sys + +def patch_init(path): + with open(path, 'r') as f: + content = f.read() + + # Add import after the existing flashinfer import block + import_line = ( + "from vllm.model_executor.kernels.linear.nvfp4.cutedsl import (\n" + " CuTeDSLNvFp4LinearKernel,\n" + ")\n" + ) + # Insert after the marlin import block + marker = "from vllm.model_executor.kernels.linear.nvfp4.marlin import (" + if "CuTeDSLNvFp4LinearKernel" in content: + print("CuTeDSL kernel already registered, skipping") + return + idx = content.find(marker) + if idx == -1: + print("ERROR: Could not find marlin import marker") + sys.exit(1) + # Find end of marlin import block + end = content.find("\n\n", idx) + content = content[:end] + "\n" + import_line + content[end:] + + # Insert CuTeDSLNvFp4LinearKernel at TOP of _POSSIBLE_NVFP4_KERNELS CUDA list + old = " PlatformEnum.CUDA: [\n FlashInferCutlassNvFp4LinearKernel," + new = " PlatformEnum.CUDA: [\n CuTeDSLNvFp4LinearKernel,\n FlashInferCutlassNvFp4LinearKernel," + content = content.replace(old, new) + + with open(path, 'w') as f: + f.write(content) + print("Patched CuTeDSL NVFP4 kernel into", path) + +if __name__ == "__main__": + patch_init(sys.argv[1])