diff --git a/vllm/kernels/linear/nvfp4/cutedsl.py b/vllm/kernels/linear/nvfp4/cutedsl.py index 45aa119f..51d737e9 100644 --- a/vllm/kernels/linear/nvfp4/cutedsl.py +++ b/vllm/kernels/linear/nvfp4/cutedsl.py @@ -6,8 +6,8 @@ Registers as an NvFp4LinearKernel so that vLLM kernel selection (init_nvfp4_linear_kernel) picks it up on Blackwell GPUs. Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM. -The GEMM is registered as a torch.library.custom_op so that -torch.compile/Dynamo treats it as opaque (CuTeDSL internals use +Uses torch.autograd.Function + torch._dynamo.allow_in_graph to make +the GEMM opaque to torch.compile/Dynamo (CuTeDSL internals use Path.cwd, JIT compilation, etc. which Dynamo cannot trace). """ @@ -21,88 +21,25 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig logger = init_logger(__name__) -@torch.library.custom_op("cutedsl::nvfp4_linear", mutates_args=()) -def _cutedsl_nvfp4_linear( - x: torch.Tensor, - mat_b: torch.Tensor, - scale_b: torch.Tensor, - global_scale_b: torch.Tensor, - activation_global_scale: float, -) -> torch.Tensor: - """Run a single-group NVFP4 GEMM via CuTeDSL. +class _CuTeDSLNvfp4LinearFn(torch.autograd.Function): + """Custom autograd function to make CuTeDSL NVFP4 GEMM opaque to Dynamo. - This is registered as a torch.library.custom_op so Dynamo treats - it as opaque. The real implementation calls into CuTeDSL's - pre-assembled weight tensors and grouped GEMM kernel. - - Args: - x: BF16 input (M, K) - mat_b: Pre-assembled FP4 weight (1, K_padded, N_packed) float4_e2m1fn_x2 - scale_b: Pre-assembled block scales (1, K_sf_padded, N_sf_packed) fp8 - global_scale_b: Weight global scale (1,) float32 - activation_global_scale: Activation global scale (float) + Without this, Dynamo in fullgraph mode tries to trace through + CuTeDSL internals (JIT compilation uses Path.cwd, etc.) and crashes. + allow_in_graph tells Dynamo this is a known, opaque kernel call. """ - from cutedsl.bridge import (quantize_activation_nvfp4, - run_nvfp4_grouped_gemm) - from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single - from cutedsl.nvfp4_linear import cutedsl_ceil_div + @staticmethod + def forward(ctx, x, runner): + return runner._run_impl(x) - num_tokens = x.shape[0] - in_features = x.shape[1] - out_features = mat_b.shape[2] # packed N in float4 elements - - # Quantize activation - x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale) - - # Pad activation to 128-row alignment for TMA - padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 - if num_tokens < padded_rows: - x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1], - dtype=torch.uint8, device=x.device - ).view(torch.float4_e2m1fn_x2) - x_fp4_padded[:num_tokens] = x_fp4 - else: - x_fp4_padded = x_fp4 - - # Assemble A-side scales: pad + swizzle for single group - num_rows, num_cols = x_sf.shape - padded_rows_sf = cutedsl_ceil_div(num_rows, 128) * 128 - padded_cols_sf = cutedsl_ceil_div(num_cols, 4) * 4 - buf = torch.zeros(padded_rows_sf, padded_cols_sf, dtype=torch.float8_e4m3fn, device=x_sf.device) - buf[:num_rows, :num_cols] = x_sf - scale_a = pad_and_swizzle_single(buf).unsqueeze(0) # (1, padded_rows, padded_cols) - - # Expert offsets for 1 group (int32 — CuTeDSL expects int32, not int64) - expert_offsets = torch.tensor([padded_rows], dtype=torch.int32, device=x.device) - - # Global scale for activation - global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device) - - # Run the CuTeDSL grouped GEMM (1 group) - out = run_nvfp4_grouped_gemm( - mat_a=x_fp4_padded, - mat_b=mat_b, - scale_a=scale_a, - scale_b=scale_b, - expert_offsets=expert_offsets, - global_scale_a=global_scale_a, - global_scale_b=global_scale_b, - ) - - return out[:num_tokens] + @staticmethod + def backward(ctx, grad_output): + raise NotImplementedError( + "CuTeDSL NVFP4 linear does not support backward") -@_cutedsl_nvfp4_linear.register_fake -def _cutedsl_nvfp4_linear_fake( - x: torch.Tensor, - mat_b: torch.Tensor, - scale_b: torch.Tensor, - global_scale_b: torch.Tensor, - activation_global_scale: float, -) -> torch.Tensor: - out_features = mat_b.shape[2] - return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16, - device=x.device) +# Tell Dynamo: this autograd function is an opaque op, don't trace inside. +torch._dynamo.allow_in_graph(_CuTeDSLNvfp4LinearFn) class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): @@ -122,7 +59,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)" @classmethod - def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[ bool, str | None]: + def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -174,20 +111,18 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): runner.finalize_weights() # Compute activation global scale from input_global_scale_inv. + # quantize_activation_nvfp4(x, global_scale) normalizes: + # x_norm = x / global_scale + # global_scale = amax/448 = input_global_scale = 1/inv. activation_global_scale = 1.0 / 2688.0 # default fallback if hasattr(layer, 'input_global_scale_inv') and layer.input_global_scale_inv is not None: inv = layer.input_global_scale_inv.data.item() if inv != 0: activation_global_scale = 1.0 / inv + runner._activation_global_scale = activation_global_scale - # Store pre-assembled weight tensors on the layer for the custom op. - # mat_b shape: (1, K_padded, N_packed) float4_e2m1fn_x2 - # scale_b shape: (1, K_sf_padded, N_sf_packed) fp8 - # gsb shape: (1,) float32 - layer._cutedsl_mat_b = runner._mat_b - layer._cutedsl_scale_b = runner._scale_b - layer._cutedsl_global_scale_b = runner._gsb - layer._cutedsl_activation_global_scale = activation_global_scale + # Store runner on the layer + layer._cutedsl_runner = runner # Replace weight with dummy BF16 (vLLM module introspection may need it) layer.weight = torch.nn.Parameter( @@ -196,7 +131,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): requires_grad=False, ) - # Clean up NVFP4 params that are now handled by our custom op. + # Clean up NVFP4 params that are now handled by the runner. for attr in ("weight_scale", "weight_global_scale", "input_global_scale", "input_global_scale_inv", "alpha", "weights_padding_cols", "weight_scale_2", @@ -213,13 +148,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - result = torch.ops.cutedsl.nvfp4_linear( - x, - layer._cutedsl_mat_b, - layer._cutedsl_scale_b, - layer._cutedsl_global_scale_b, - layer._cutedsl_activation_global_scale, - ) + result = _CuTeDSLNvfp4LinearFn.apply(x, layer._cutedsl_runner) if bias is not None: result = result + bias return result