Register CuTeDSL as proper NvFp4LinearKernel for NVFP4 linear layers

- Create CuTeDSLNvFp4LinearKernel extending NvFp4LinearKernel base class
- Register it via init_nvfp4_linear_kernel() selection mechanism
  (inserted at top of _POSSIBLE_NVFP4_KERNELS, before FlashInfer)
- process_weights_after_loading: uint8→FP4, permute, create CuTeDSL runner
- apply_weights: route through CuTeDSL GEMM
- Update Dockerfile: copy kernel + registration script
- Fix attention: always use forward() for quantized compressor/indexer
  layers (dtype check was fragile after kernel swaps weights to dummy BF16)
This commit is contained in:
2026-05-19 00:44:44 +00:00
parent 358830925a
commit c043a11bcc
4 changed files with 211 additions and 28 deletions

View File

@@ -39,6 +39,15 @@ COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py
# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel)
ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4
COPY vllm/kernels/linear/nvfp4/cutedsl.py ${VLLM_NVFP4_DIR}/cutedsl.py
# Register CuTeDSL kernel in vLLM's linear kernel selection
ARG VLLM_LINEAR_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear
COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py
RUN python3 /tmp/register_cutedsl_kernel.py ${VLLM_LINEAR_DIR}/__init__.py && rm /tmp/register_cutedsl_kernel.py
# Config patches (add cutedsl to MoEBackend)
ARG VLLM_CONFIG_DIR=/usr/local/lib/python3.12/dist-packages/vllm/config
COPY vllm/patches/kernel.py ${VLLM_CONFIG_DIR}/kernel.py

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CuTeDSL NVFP4 Linear Kernel for vLLM.
Registers as an NvFp4LinearKernel so that vLLM's kernel selection
mechanism (init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
Routes NVFP4 GEMM through the CuTeDSL framework, which uses MLIR-compiled
grouped GEMM kernels with Blackwell-specific TMA + wgmma instructions.
CUDA-graph-compatible: all intermediate buffers are pre-allocated,
no CPU-GPU syncs, no dynamic shapes.
"""
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
logger = init_logger(__name__)
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via the CuTeDSL framework (Blackwell SM100+).

    Uses CuTeDSL's ScaledGroupedGemmKernel with num_groups=1 for
    single linear layers. Weight processing:

    - uint8 packed FP4 → float4_e2m1fn_x2, permuted to (K, N)
    - FP8 block scales permuted to (K_sf, N)
    - Global scale stored as float32

    Activation quantization is done internally (NVFP4 W4A4).
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the given/current device.

        Args:
            compute_capability: Capability encoded as ``major * 10 + minor``
                (e.g. ``100`` for SM100), as the ``int`` annotation implies.
                An object exposing ``.major`` / ``.minor`` is accepted too.
                When ``None``, the current platform is queried.

        Returns:
            ``(True, None)`` when the capability is SM100 or newer, otherwise
            ``(False, reason)``.
        """
        cap = compute_capability
        if cap is None:
            device_cap = current_platform.get_device_capability()
            if device_cap is not None:
                cap = device_cap.major * 10 + device_cap.minor
        elif not isinstance(cap, int):
            # Tolerate callers that hand over a capability object instead of
            # the integer encoding.
            cap = cap.major * 10 + cap.minor

        if cap is not None and cap >= 100:
            return True, None
        return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)"

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config; no config-specific checks are made."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert NVFP4 weights into CuTeDSL kernel format.

        Reads the layer's weight (uint8), weight_scale (fp8), and
        weight_global_scale (float32) — all set up by
        ModelOptNvFp4LinearMethod.process_weights_after_loading before
        our call.

        Creates a CuTeDSLNvfp4Linear runner, stores it on the layer as
        ``_cutedsl_runner``, replaces ``layer.weight`` with a BF16 placeholder
        and removes the NVFP4 parameters that now live in the runner.
        """
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        w_uint8 = layer.weight.data  # (out, in//2) uint8 packed E2M1
        device = w_uint8.device
        out_features = w_uint8.shape[0]
        in_features = w_uint8.shape[1] * 2  # 2 FP4 values per uint8

        # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N)
        w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL
        sf = layer.weight_scale.data
        if sf.dtype != torch.float8_e4m3fn:
            sf = sf.to(torch.float8_e4m3fn)
        sf = sf.permute(1, 0).contiguous()

        # Global scale (set by
        # ModelOptNvFp4LinearMethod.process_weights_after_loading).
        # The element count is checked BEFORE calling .item(): .item() only
        # works on single-element tensors, so reading it first would raise on
        # the fused two-scale case handled below.
        gs_tensor = layer.weight_global_scale.data
        if gs_tensor.numel() == 2:
            # Fused projections (MergedColumnParallelLinear with dual gs,
            # e.g. fused_wqa_wkv): normalize to max(gs0, gs1) and fold the
            # per-half ratio into the block scales.
            gs0, gs1 = gs_tensor.flatten().tolist()
            gs = max(gs0, gs1)
            if gs0 != gs1:
                sf_f32 = sf.float()
                logical_widths = getattr(layer, 'logical_widths', None)
                if logical_widths is not None and len(logical_widths) == 2:
                    split_point = logical_widths[0]
                else:
                    split_point = out_features // 2
                # sf is (K_sf, N) here, so the output split is along dim 1.
                sf_f32[:, :split_point] *= (gs0 / gs)
                sf_f32[:, split_point:] *= (gs1 / gs)
                sf = sf_f32.to(torch.float8_e4m3fn)
        else:
            gs = gs_tensor.item()

        # Create CuTeDSL runner
        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=str(device),
        )
        # Single-entry lists: one group per linear layer.
        runner.fp4 = [w_fp4]
        runner.sf = [sf]
        runner.gs = [gs]
        runner.finalize_weights()

        # Compute activation global scale from input_global_scale_inv.
        # ModelOptNvFp4LinearMethod sets:
        #   input_global_scale     = input_scale.max() = amax/448  (small)
        #   input_global_scale_inv = 1/input_global_scale = 448/amax (large)
        # Our quantize_activation_nvfp4(x, global_scale) normalizes:
        #   x_norm = x / global_scale
        # So global_scale = amax/448 = input_global_scale = 1/inv.
        inv_param = getattr(layer, 'input_global_scale_inv', None)
        if inv_param is not None:
            inv = inv_param.data.item()
            if inv != 0:
                runner._activation_global_scale = 1.0 / inv

        # Store runner on the layer
        layer._cutedsl_runner = runner

        # Replace weight with dummy BF16 (vLLM module introspection may need it)
        # NOTE(review): this placeholder is a dense (out, in) BF16 tensor,
        # i.e. 4x the bytes of the packed uint8 weight it replaces — confirm
        # a full-shape tensor is really required by downstream code.
        layer.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.bfloat16,
                        device=device),
            requires_grad=False,
        )

        # Clean up NVFP4 params that are now in the runner.
        # Keep output_size_per_partition, logical_widths, input_size_per_partition
        # which may be referenced by the layer's forward path.
        for attr in ("weight_scale", "weight_global_scale",
                     "input_global_scale", "input_global_scale_inv",
                     "alpha", "weights_padding_cols", "weight_scale_2",
                     "input_scale"):
            if hasattr(layer, attr):
                try:
                    delattr(layer, attr)
                except Exception as exc:
                    # Best-effort cleanup: a leftover attribute is harmless,
                    # but record it instead of swallowing the error silently.
                    logger.debug(
                        "Could not delete %s from layer after CuTeDSL "
                        "weight processing: %s", attr, exc)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the layer's CuTeDSL runner on ``x`` and add ``bias`` if given.

        Expects ``layer._cutedsl_runner`` to have been set by
        ``process_weights_after_loading``.
        """
        result = layer._cutedsl_runner(x)
        if bias is not None:
            result = result + bias
        return result

View File

@@ -366,23 +366,14 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
compressor = self.compressor
def compressor_kv_score() -> torch.Tensor:
# For NVFP4-quantized weights, we can't do a raw torch.mm
# with packed uint8 weights. Use the layer's forward()
# which handles dequantization properly.
wkv_wgate_weight = compressor.fused_wkv_wgate.weight
if wkv_wgate_weight.dtype == torch.uint8:
# NVFP4 packed weights — use forward() for dequant+matmul
result = compressor.fused_wkv_wgate(hidden_states)
# MergedColumnParallelLinear may return (output, bias) or
# just output depending on quantization method.
if isinstance(result, tuple):
result = result[0]
return result.to(torch.float32)
return torch.mm(
hidden_states,
wkv_wgate_weight.T,
out_dtype=torch.float32,
)
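# Always use forward() for quantized layers (NVFP4, FP8, etc.):
# raw torch.mm cannot consume packed weights, nor the dummy BF16
# placeholder a kernel may leave in `weight` after processing.
# MergedColumnParallelLinear with return_bias=False returns a
# tensor directly; the tuple check below covers the other case.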
result = compressor.fused_wkv_wgate(hidden_states)
if isinstance(result, tuple):
result = result[0]
return result.to(torch.float32)
aux_fns[0] = compressor_kv_score
@@ -395,17 +386,10 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
return weights
def indexer_compressor_kv_score() -> torch.Tensor:
wkv_wgate_weight = indexer.compressor.fused_wkv_wgate.weight
if wkv_wgate_weight.dtype == torch.uint8:
result = indexer.compressor.fused_wkv_wgate(hidden_states)
if isinstance(result, tuple):
result = result[0]
return result.to(torch.float32)
return torch.mm(
hidden_states,
wkv_wgate_weight.T,
out_dtype=torch.float32,
)
result = indexer.compressor.fused_wkv_wgate(hidden_states)
if isinstance(result, tuple):
result = result[0]
return result.to(torch.float32)
aux_fns[1] = indexer_weights_proj
aux_fns[2] = indexer_compressor_kv_score

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# Patch vLLM's linear kernel __init__.py to register the CuTeDSL NVFP4 kernel.
# This inserts our kernel at the TOP of the _POSSIBLE_NVFP4_KERNELS list,
# so it gets selected first on Blackwell GPUs.
import sys
def patch_init(path):
    """Register CuTeDSLNvFp4LinearKernel in the linear-kernel ``__init__.py``.

    Rewrites the file at *path* in place with two changes:

    1. the CuTeDSL import is added right after the marlin NVFP4 import block;
    2. ``CuTeDSLNvFp4LinearKernel`` is inserted as the first entry of the
       ``PlatformEnum.CUDA`` list, ahead of
       ``FlashInferCutlassNvFp4LinearKernel``.

    If the file already mentions ``CuTeDSLNvFp4LinearKernel`` nothing is
    changed. If either anchor is missing, an error is printed, the process
    exits with status 1 and the file is left untouched, so a build cannot
    "succeed" with a kernel that is imported but never listed.
    """
    import re  # only needed here; kept local so the block is self-contained

    with open(path, 'r') as f:
        content = f.read()

    if "CuTeDSLNvFp4LinearKernel" in content:
        print("CuTeDSL kernel already registered, skipping")
        return

    import_line = (
        "from vllm.model_executor.kernels.linear.nvfp4.cutedsl import (\n"
        " CuTeDSLNvFp4LinearKernel,\n"
        ")\n"
    )

    # Insert after the marlin import block
    marker = "from vllm.model_executor.kernels.linear.nvfp4.marlin import ("
    idx = content.find(marker)
    if idx == -1:
        print("ERROR: Could not find marlin import marker", file=sys.stderr)
        sys.exit(1)

    # Find end of marlin import block: the parenthesis that closes it.
    # (Searching for a blank line instead can return -1, which would splice
    # the import in front of the file's last character.)
    close = content.find(")", idx + len(marker))
    if close == -1:
        print("ERROR: Could not find end of marlin import block",
              file=sys.stderr)
        sys.exit(1)
    end = close + 1
    content = content[:end] + "\n" + import_line.rstrip("\n") + content[end:]

    # Insert CuTeDSLNvFp4LinearKernel at TOP of _POSSIBLE_NVFP4_KERNELS CUDA
    # list. The match tolerates any indentation and reuses the indentation of
    # the existing first entry for the inserted line.
    list_head = re.compile(
        r"(PlatformEnum\.CUDA:[ \t]*\[[ \t]*\n)([ \t]*)"
        r"(FlashInferCutlassNvFp4LinearKernel,)"
    )
    content, hits = list_head.subn(
        r"\1\2CuTeDSLNvFp4LinearKernel,\n\2\3", content)
    if hits == 0:
        print("ERROR: Could not find FlashInferCutlassNvFp4LinearKernel at "
              "the top of the PlatformEnum.CUDA kernel list", file=sys.stderr)
        sys.exit(1)

    # Write only after both edits succeeded.
    with open(path, 'w') as f:
        f.write(content)
    print("Patched CuTeDSL NVFP4 kernel into", path)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("usage: register_cutedsl_kernel.py <path/to/linear/__init__.py>",
              file=sys.stderr)
        sys.exit(2)
    patch_init(sys.argv[1])