diff --git a/vllm/kernels/linear/nvfp4/cutedsl.py b/vllm/kernels/linear/nvfp4/cutedsl.py index 72f55a52..53fb83c3 100644 --- a/vllm/kernels/linear/nvfp4/cutedsl.py +++ b/vllm/kernels/linear/nvfp4/cutedsl.py @@ -21,79 +21,73 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig logger = init_logger(__name__) -def _cutedsl_nvfp4_gemm_impl( - x: torch.Tensor, - mat_b: torch.Tensor, - scale_b: torch.Tensor, - global_scale_b: torch.Tensor, - in_features: int, - out_features: int, - activation_global_scale: float, -) -> torch.Tensor: - """Run a single-group NVFP4 GEMM via CuTeDSL.""" - from cutedsl.bridge import (pad_and_swizzle_single, - quantize_activation_nvfp4, - run_nvfp4_grouped_gemm) - from cutedsl.nvfp4_linear import cutedsl_ceil_div - - num_tokens = x.shape[0] - padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 - - # Quantize activation - x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale) - - # Pad activation to 128-row alignment for TMA - if num_tokens < padded_rows: - x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1], - dtype=x_fp4.dtype, device=x.device) - x_fp4_padded[:num_tokens] = x_fp4 - else: - x_fp4_padded = x_fp4 - - # Assemble A-side scales in CuTeDSL layout - scale_a = pad_and_swizzle_single(x_sf, num_tokens, x_sf.shape[1], x.device) - - # Expert offsets for 1 group (all tokens in one group) - expert_offsets = torch.tensor([padded_rows], dtype=torch.int64, device=x.device) - - # Global scale for activation - global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device) - - # Run the CuTeDSL grouped GEMM (1 group) - out = run_nvfp4_grouped_gemm( - mat_a=x_fp4_padded, - mat_b=mat_b, - scale_a=scale_a, - scale_b=scale_b, - expert_offsets=expert_offsets, - global_scale_a=global_scale_a, - global_scale_b=global_scale_b, - ) - - return out[:num_tokens] - - -def _cutedsl_nvfp4_gemm_fake( - x: torch.Tensor, - mat_b: torch.Tensor, - scale_b: torch.Tensor, - global_scale_b: torch.Tensor, - in_features: int, - out_features: int, - activation_global_scale: float, -) -> torch.Tensor: - return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16, - device=x.device) - - # Register custom op (idempotent — safe to import multiple times) if not hasattr(torch.ops, 'cutedsl') or not hasattr(torch.ops.cutedsl, 'nvfp4_gemm'): - _CUTEDSL_NVFP4_GEMM = torch.library.custom_op( - "cutedsl::nvfp4_gemm", - _cutedsl_nvfp4_gemm_impl, - mutates_args=(), - )(_cutedsl_nvfp4_gemm_impl) - _CUTEDSL_NVFP4_GEMM.register_fake(_cutedsl_nvfp4_gemm_fake) + + @torch.library.custom_op("cutedsl::nvfp4_gemm", mutates_args=()) + def _cutedsl_nvfp4_gemm( + x: torch.Tensor, + mat_b: torch.Tensor, + scale_b: torch.Tensor, + global_scale_b: torch.Tensor, + in_features: int, + out_features: int, + activation_global_scale: float, + ) -> torch.Tensor: + """Run a single-group NVFP4 GEMM via CuTeDSL.""" + from cutedsl.bridge import (pad_and_swizzle_single, + quantize_activation_nvfp4, + run_nvfp4_grouped_gemm) + from cutedsl.nvfp4_linear import cutedsl_ceil_div + + num_tokens = x.shape[0] + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + + # Quantize activation + x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale) + + # Pad activation to 128-row alignment for TMA + if num_tokens < padded_rows: + x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1], + dtype=x_fp4.dtype, device=x.device) + x_fp4_padded[:num_tokens] = x_fp4 + else: + x_fp4_padded = x_fp4 + + # Assemble A-side scales in CuTeDSL layout + scale_a = pad_and_swizzle_single(x_sf, num_tokens, x_sf.shape[1], x.device) + + # Expert offsets for 1 group (all tokens in one group) + expert_offsets = torch.tensor([padded_rows], dtype=torch.int64, device=x.device) + + # Global scale for activation + global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device) + + # Run the CuTeDSL grouped GEMM (1 group) + out = run_nvfp4_grouped_gemm( + mat_a=x_fp4_padded, + mat_b=mat_b, + scale_a=scale_a, + scale_b=scale_b, + expert_offsets=expert_offsets, + global_scale_a=global_scale_a, + global_scale_b=global_scale_b, + ) + + return out[:num_tokens] + + @_cutedsl_nvfp4_gemm.register_fake + def _cutedsl_nvfp4_gemm_fake( + x: torch.Tensor, + mat_b: torch.Tensor, + scale_b: torch.Tensor, + global_scale_b: torch.Tensor, + in_features: int, + out_features: int, + activation_global_scale: float, + ) -> torch.Tensor: + return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16, + device=x.device) class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):