Register CuTeDSL as proper NvFp4LinearKernel for NVFP4 linear layers
- Create CuTeDSLNvFp4LinearKernel extending NvFp4LinearKernel base class - Register it via init_nvfp4_linear_kernel() selection mechanism (inserted at top of _POSSIBLE_NVFP4_KERNELS, before FlashInfer) - process_weights_after_loading: uint8→FP4, permute, create CuTeDSL runner - apply_weights: route through CuTeDSL GEMM - Update Dockerfile: copy kernel + registration script - Fix attention: always use forward() for quantized compressor/indexer layers (dtype check was fragile after kernel swaps weights to dummy BF16)
This commit is contained in:
@@ -39,6 +39,15 @@ COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
|
||||
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
|
||||
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py
|
||||
|
||||
# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel)
|
||||
ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4
|
||||
COPY vllm/kernels/linear/nvfp4/cutedsl.py ${VLLM_NVFP4_DIR}/cutedsl.py
|
||||
|
||||
# Register CuTeDSL kernel in vLLM's linear kernel selection
|
||||
ARG VLLM_LINEAR_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear
|
||||
COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py
|
||||
RUN python3 /tmp/register_cutedsl_kernel.py ${VLLM_LINEAR_DIR}/__init__.py && rm /tmp/register_cutedsl_kernel.py
|
||||
|
||||
# Config patches (add cutedsl to MoEBackend)
|
||||
ARG VLLM_CONFIG_DIR=/usr/local/lib/python3.12/dist-packages/vllm/config
|
||||
COPY vllm/patches/kernel.py ${VLLM_CONFIG_DIR}/kernel.py
|
||||
|
||||
149
vllm/kernels/linear/nvfp4/cutedsl.py
Normal file
149
vllm/kernels/linear/nvfp4/cutedsl.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CuTeDSL NVFP4 Linear Kernel for vLLM.
|
||||
|
||||
Registers as an NvFp4LinearKernel so that vLLM's kernel selection
|
||||
mechanism (init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
|
||||
Routes NVFP4 GEMM through the CuTeDSL framework, which uses MLIR-compiled
|
||||
grouped GEMM kernels with Blackwell-specific TMA + wgmma instructions.
|
||||
|
||||
CUDA-graph-compatible: all intermediate buffers are pre-allocated,
|
||||
no CPU-GPU syncs, no dynamic shapes.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
"""NVFP4 GEMM via the CuTeDSL framework (Blackwell SM100+).
|
||||
|
||||
Uses CuTeDSL's ScaledGroupedGemmKernel with num_groups=1 for
|
||||
single linear layers. Weight processing:
|
||||
- uint8 packed FP4 → float4_e2m1fn_x2, permuted to (K, N)
|
||||
- FP8 block scales permuted to (K_sf, N)
|
||||
- Global scale stored as float32
|
||||
|
||||
Activation quantization is done internally (NVFP4 W4A4).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
cap = compute_capability or current_platform.get_device_capability()
|
||||
if cap is not None and cap.major >= 10:
|
||||
return True, None
|
||||
return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)"
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Convert NVFP4 weights into CuTeDSL kernel format.
|
||||
|
||||
Reads the layer's weight (uint8), weight_scale (fp8), and
|
||||
weight_global_scale (float32) — all set up by
|
||||
ModelOptNvFp4LinearMethod.process_weights_before our call.
|
||||
Creates a CuTeDSLNvfp4Linear runner and stores it on the layer.
|
||||
"""
|
||||
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
|
||||
|
||||
w_uint8 = layer.weight.data # (out, in//2) uint8 packed E2M1
|
||||
device = w_uint8.device
|
||||
out_features = w_uint8.shape[0]
|
||||
in_features = w_uint8.shape[1] * 2 # 2 FP4 values per uint8
|
||||
|
||||
# Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N)
|
||||
w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
|
||||
# Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL
|
||||
sf = layer.weight_scale.data
|
||||
if sf.dtype != torch.float8_e4m3fn:
|
||||
sf = sf.to(torch.float8_e4m3fn)
|
||||
sf = sf.permute(1, 0).contiguous()
|
||||
|
||||
# Global scale (set by ModelOptNvFp4LinearMethod.process_weights_after_loading)
|
||||
gs = layer.weight_global_scale.data.item()
|
||||
|
||||
# Handle fused projections (MergedColumnParallelLinear with dual gs).
|
||||
# When weight_global_scale has 2 elements (e.g. fused_wqa_wkv),
|
||||
# normalize to max(gs1, gs2) and fold ratio into block scales.
|
||||
if layer.weight_global_scale.numel() == 2:
|
||||
gs0 = layer.weight_global_scale[0].item()
|
||||
gs1 = layer.weight_global_scale[1].item()
|
||||
gs = max(gs0, gs1)
|
||||
if gs0 != gs1:
|
||||
sf_f32 = sf.float()
|
||||
logical_widths = getattr(layer, 'logical_widths', None)
|
||||
if logical_widths is not None and len(logical_widths) == 2:
|
||||
split_point = logical_widths[0]
|
||||
else:
|
||||
split_point = out_features // 2
|
||||
sf_f32[:, :split_point] *= (gs0 / gs)
|
||||
sf_f32[:, split_point:] *= (gs1 / gs)
|
||||
sf = sf_f32.to(torch.float8_e4m3fn)
|
||||
|
||||
# Create CuTeDSL runner
|
||||
runner = CuTeDSLNvfp4Linear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
device=str(device),
|
||||
)
|
||||
runner.fp4 = [w_fp4]
|
||||
runner.sf = [sf]
|
||||
runner.gs = [gs]
|
||||
runner.finalize_weights()
|
||||
|
||||
# Compute activation global scale from input_global_scale_inv.
|
||||
# ModelOptNvFp4LinearMethod sets:
|
||||
# input_global_scale = input_scale.max() = amax/448 (small)
|
||||
# input_global_scale_inv = 1/input_global_scale = 448/amax (large)
|
||||
# Our quantize_activation_nvfp4(x, global_scale) normalizes:
|
||||
# x_norm = x / global_scale
|
||||
# So global_scale = amax/448 = input_global_scale = 1/inv.
|
||||
if hasattr(layer, 'input_global_scale_inv') and layer.input_global_scale_inv is not None:
|
||||
inv = layer.input_global_scale_inv.data.item()
|
||||
if inv != 0:
|
||||
runner._activation_global_scale = 1.0 / inv
|
||||
|
||||
# Store runner on the layer
|
||||
layer._cutedsl_runner = runner
|
||||
|
||||
# Replace weight with dummy BF16 (vLLM module introspection may need it)
|
||||
layer.weight = torch.nn.Parameter(
|
||||
torch.zeros(out_features, in_features, dtype=torch.bfloat16,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Clean up NVFP4 params that are now in the runner.
|
||||
# Keep output_size_per_partition, logical_widths, input_size_per_partition
|
||||
# which may be referenced by the layer's forward path.
|
||||
for attr in ("weight_scale", "weight_global_scale",
|
||||
"input_global_scale", "input_global_scale_inv",
|
||||
"alpha", "weights_padding_cols", "weight_scale_2",
|
||||
"input_scale"):
|
||||
if hasattr(layer, attr):
|
||||
try:
|
||||
delattr(layer, attr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
result = layer._cutedsl_runner(x)
|
||||
if bias is not None:
|
||||
result = result + bias
|
||||
return result
|
||||
@@ -366,23 +366,14 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
compressor = self.compressor
|
||||
|
||||
def compressor_kv_score() -> torch.Tensor:
|
||||
# For NVFP4-quantized weights, we can't do a raw torch.mm
|
||||
# with packed uint8 weights. Use the layer's forward()
|
||||
# which handles dequantization properly.
|
||||
wkv_wgate_weight = compressor.fused_wkv_wgate.weight
|
||||
if wkv_wgate_weight.dtype == torch.uint8:
|
||||
# NVFP4 packed weights — use forward() for dequant+matmul
|
||||
result = compressor.fused_wkv_wgate(hidden_states)
|
||||
# MergedColumnParallelLinear may return (output, bias) or
|
||||
# just output depending on quantization method.
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
return result.to(torch.float32)
|
||||
return torch.mm(
|
||||
hidden_states,
|
||||
wkv_wgate_weight.T,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
# Use forward() for quantized layers (NVFP4, FP8, etc.)
|
||||
# — raw torch.mm doesn't work with packed/dequantized weights.
|
||||
# MergedColumnParallelLinear with return_bias=False returns
|
||||
# a tensor directly.
|
||||
result = compressor.fused_wkv_wgate(hidden_states)
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
return result.to(torch.float32)
|
||||
|
||||
aux_fns[0] = compressor_kv_score
|
||||
|
||||
@@ -395,17 +386,10 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
return weights
|
||||
|
||||
def indexer_compressor_kv_score() -> torch.Tensor:
|
||||
wkv_wgate_weight = indexer.compressor.fused_wkv_wgate.weight
|
||||
if wkv_wgate_weight.dtype == torch.uint8:
|
||||
result = indexer.compressor.fused_wkv_wgate(hidden_states)
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
return result.to(torch.float32)
|
||||
return torch.mm(
|
||||
hidden_states,
|
||||
wkv_wgate_weight.T,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
result = indexer.compressor.fused_wkv_wgate(hidden_states)
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
return result.to(torch.float32)
|
||||
|
||||
aux_fns[1] = indexer_weights_proj
|
||||
aux_fns[2] = indexer_compressor_kv_score
|
||||
|
||||
41
vllm/patches/register_cutedsl_kernel.py
Normal file
41
vllm/patches/register_cutedsl_kernel.py
Normal file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin
|
||||
# Patch vLLM's linear kernel __init__.py to register the CuTeDSL NVFP4 kernel.
|
||||
# This inserts our kernel at the TOP of the _POSSIBLE_NVFP4_KERNELS list,
|
||||
# so it gets selected first on Blackwell GPUs.
|
||||
|
||||
import sys
|
||||
|
||||
def patch_init(path):
|
||||
with open(path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Add import after the existing flashinfer import block
|
||||
import_line = (
|
||||
"from vllm.model_executor.kernels.linear.nvfp4.cutedsl import (\n"
|
||||
" CuTeDSLNvFp4LinearKernel,\n"
|
||||
")\n"
|
||||
)
|
||||
# Insert after the marlin import block
|
||||
marker = "from vllm.model_executor.kernels.linear.nvfp4.marlin import ("
|
||||
if "CuTeDSLNvFp4LinearKernel" in content:
|
||||
print("CuTeDSL kernel already registered, skipping")
|
||||
return
|
||||
idx = content.find(marker)
|
||||
if idx == -1:
|
||||
print("ERROR: Could not find marlin import marker")
|
||||
sys.exit(1)
|
||||
# Find end of marlin import block
|
||||
end = content.find("\n\n", idx)
|
||||
content = content[:end] + "\n" + import_line + content[end:]
|
||||
|
||||
# Insert CuTeDSLNvFp4LinearKernel at TOP of _POSSIBLE_NVFP4_KERNELS CUDA list
|
||||
old = " PlatformEnum.CUDA: [\n FlashInferCutlassNvFp4LinearKernel,"
|
||||
new = " PlatformEnum.CUDA: [\n CuTeDSLNvFp4LinearKernel,\n FlashInferCutlassNvFp4LinearKernel,"
|
||||
content = content.replace(old, new)
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write(content)
|
||||
print("Patched CuTeDSL NVFP4 kernel into", path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
patch_init(sys.argv[1])
|
||||
Reference in New Issue
Block a user