diff --git a/vllm/kernels/linear/nvfp4/cutedsl.py b/vllm/kernels/linear/nvfp4/cutedsl.py index 51d737e9..d013e39e 100644 --- a/vllm/kernels/linear/nvfp4/cutedsl.py +++ b/vllm/kernels/linear/nvfp4/cutedsl.py @@ -6,9 +6,14 @@ Registers as an NvFp4LinearKernel so that vLLM kernel selection (init_nvfp4_linear_kernel) picks it up on Blackwell GPUs. Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM. -Uses torch.autograd.Function + torch._dynamo.allow_in_graph to make -the GEMM opaque to torch.compile/Dynamo (CuTeDSL internals use +The GEMM is registered as a torch.library.custom_op so that +torch.compile/Dynamo treats it as opaque (CuTeDSL internals use Path.cwd, JIT compilation, etc. which Dynamo cannot trace). + +The custom op takes only tensors and scalars. The runner's pre-assembled +weight tensors (mat_b, scale_b, global_scale_b) are stored on the +layer and passed directly. Activation quantization and scale assembly +are done inside the custom op. """ import torch @@ -21,25 +26,84 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig logger = init_logger(__name__) -class _CuTeDSLNvfp4LinearFn(torch.autograd.Function): - """Custom autograd function to make CuTeDSL NVFP4 GEMM opaque to Dynamo. +@torch.library.custom_op("cutedsl::nvfp4_linear", mutates_args=()) +def _cutedsl_nvfp4_linear( + x: torch.Tensor, + mat_b: torch.Tensor, + scale_b: torch.Tensor, + global_scale_b: torch.Tensor, + activation_global_scale: float, +) -> torch.Tensor: + """Run a single-group NVFP4 GEMM via CuTeDSL. - Without this, Dynamo in fullgraph mode tries to trace through - CuTeDSL internals (JIT compilation uses Path.cwd, etc.) and crashes. - allow_in_graph tells Dynamo this is a known, opaque kernel call. + All args are tensors (or scalars) — Dynamo-compatible. 
+ The weight tensors come from the runner's finalize_weights: + mat_b: (1, K_padded, N_packed) float4_e2m1fn_x2 + scale_b: (1, K_sf_padded, N_sf_packed) fp8 + global_scale_b: (1,) float32 """ - @staticmethod - def forward(ctx, x, runner): - return runner._run_impl(x) + from cutedsl.bridge import (quantize_activation_nvfp4, + run_nvfp4_grouped_gemm) + from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear, cutedsl_ceil_div + from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single - @staticmethod - def backward(ctx, grad_output): - raise NotImplementedError( - "CuTeDSL NVFP4 linear does not support backward") + num_tokens = x.shape[0] + out_features = mat_b.shape[2] # packed N in float4 elements + + # Quantize activation: x → (x_fp4, x_sf) + x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale) + + # Pad activation to 128-row alignment for TMA + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + if num_tokens < padded_rows: + # Can't torch.zeros with float4 dtype — allocate as uint8 then view + x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1], + dtype=torch.uint8, device=x.device + ).view(torch.float4_e2m1fn_x2) + x_fp4_padded[:num_tokens] = x_fp4 + else: + x_fp4_padded = x_fp4 + + # Assemble A-side scales: pad + swizzle for CuTeDSL layout + num_rows_sf, num_cols_sf = x_sf.shape + padded_rows_sf = cutedsl_ceil_div(num_rows_sf, 128) * 128 + padded_cols_sf = cutedsl_ceil_div(num_cols_sf, 4) * 4 + sf_buf = torch.zeros(padded_rows_sf, padded_cols_sf, + dtype=torch.float8_e4m3fn, device=x_sf.device) + sf_buf[:num_rows_sf, :num_cols_sf] = x_sf + scale_a = pad_and_swizzle_single(sf_buf).unsqueeze(0) # (1, ...) 
+ + # Expert offsets for 1 group (int32 — CuTeDSL requires int32) + expert_offsets = torch.tensor([padded_rows], dtype=torch.int32, device=x.device) + + # Global scale for activation + global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device) + + # Run the CuTeDSL grouped GEMM (1 group) + out = run_nvfp4_grouped_gemm( + mat_a=x_fp4_padded, + mat_b=mat_b, + scale_a=scale_a, + scale_b=scale_b, + expert_offsets=expert_offsets, + global_scale_a=global_scale_a, + global_scale_b=global_scale_b, + ) + + return out[:num_tokens] -# Tell Dynamo: this autograd function is an opaque op, don't trace inside. -torch._dynamo.allow_in_graph(_CuTeDSLNvfp4LinearFn) +@_cutedsl_nvfp4_linear.register_fake +def _cutedsl_nvfp4_linear_fake( + x: torch.Tensor, + mat_b: torch.Tensor, + scale_b: torch.Tensor, + global_scale_b: torch.Tensor, + activation_global_scale: float, +) -> torch.Tensor: + out_features = mat_b.shape[2] + return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16, + device=x.device) class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): @@ -119,10 +183,12 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): inv = layer.input_global_scale_inv.data.item() if inv != 0: activation_global_scale = 1.0 / inv - runner._activation_global_scale = activation_global_scale - # Store runner on the layer - layer._cutedsl_runner = runner + # Store pre-assembled weight tensors on the layer for the custom op. + layer._cutedsl_mat_b = runner._mat_b + layer._cutedsl_scale_b = runner._scale_b + layer._cutedsl_global_scale_b = runner._gsb + layer._cutedsl_activation_global_scale = activation_global_scale # Replace weight with dummy BF16 (vLLM module introspection may need it) layer.weight = torch.nn.Parameter( @@ -131,7 +197,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): requires_grad=False, ) - # Clean up NVFP4 params that are now handled by the runner. + # Clean up NVFP4 params that are now handled by the custom op. 
for attr in ("weight_scale", "weight_global_scale", "input_global_scale", "input_global_scale_inv", "alpha", "weights_padding_cols", "weight_scale_2", @@ -148,7 +214,13 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - result = _CuTeDSLNvfp4LinearFn.apply(x, layer._cutedsl_runner) + result = torch.ops.cutedsl.nvfp4_linear( + x, + layer._cutedsl_mat_b, + layer._cutedsl_scale_b, + layer._cutedsl_global_scale_b, + layer._cutedsl_activation_global_scale, + ) if bias is not None: result = result + bias return result