Files
nvfp4-megamoe-kernel/vllm/kernels/linear/nvfp4/cutedsl.py
biondizzle 35fab6cff3 Replace autograd.Function with torch.library.custom_op for Dynamo compat
Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(cute.compile, JIT, etc.). The autograd.Function approach was unreliable
with fullgraph mode — Dynamo would still try to trace through it.

Fix: torch.library.custom_op makes Dynamo treat our GEMM as an opaque
black box. No reimplementing the kernel — just route through the existing
runner via a registry pattern:
  - Runners registered in global dict with integer IDs
  - Custom op takes (tensors, runner_id, shape_hint) -> tensor
  - Dynamo calls fake impl for shape inference, never touches the runner
  - At execution time, real impl looks up runner and calls _run_impl

Changes:
  - New: cutedsl/custom_ops.py (custom op definitions + registry)
  - New: tests/test_custom_op.py (local unit tests, no GPU needed)
  - Removed: _Nvfp4LinearApply, _MoEApply (autograd.Function classes)
  - Updated: nvfp4_linear.py, runner.py, cutedsl.py, nvfp4_cutedsl.py
    to use custom ops instead of autograd.Function
  - Updated: cutedsl_quant_method.py to use custom op + registry
2026-05-19 01:54:48 +00:00

124 lines
4.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CuTeDSL NVFP4 Linear Kernel for vLLM.
Registers as an NvFp4LinearKernel so that vLLM kernel selection
(init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM.
Uses torch.library.custom_op to make Dynamo (torch.compile) treat the
GEMM as opaque. The runner's _run_impl is already cudagraph-safe.
"""
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
logger = init_logger(__name__)
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
"""NVFP4 GEMM via the CuTeDSL framework (Blackwell SM100+)."""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
cap = compute_capability or current_platform.get_device_capability()
if cap is not None and cap.major >= 10:
return True, None
return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)"
@classmethod
def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Convert NVFP4 weights into CuTeDSL kernel format."""
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
w_uint8 = layer.weight.data
device = w_uint8.device
out_features = w_uint8.shape[0]
in_features = w_uint8.shape[1] * 2
w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
sf = layer.weight_scale.data
if sf.dtype != torch.float8_e4m3fn:
sf = sf.to(torch.float8_e4m3fn)
sf = sf.permute(1, 0).contiguous()
gs = layer.weight_global_scale.data.item()
if layer.weight_global_scale.numel() == 2:
gs0 = layer.weight_global_scale[0].item()
gs1 = layer.weight_global_scale[1].item()
gs = max(gs0, gs1)
if gs0 != gs1:
sf_f32 = sf.float()
logical_widths = getattr(layer, 'logical_widths', None)
if logical_widths is not None and len(logical_widths) == 2:
split_point = logical_widths[0]
else:
split_point = out_features // 2
sf_f32[:, :split_point] *= (gs0 / gs)
sf_f32[:, split_point:] *= (gs1 / gs)
sf = sf_f32.to(torch.float8_e4m3fn)
runner = CuTeDSLNvfp4Linear(
in_features=in_features,
out_features=out_features,
device=str(device),
)
runner.fp4 = [w_fp4]
runner.sf = [sf]
runner.gs = [gs]
runner.finalize_weights()
activation_global_scale = 1.0 / 2688.0
if hasattr(layer, 'input_global_scale_inv') and layer.input_global_scale_inv is not None:
inv = layer.input_global_scale_inv.data.item()
if inv != 0:
activation_global_scale = 1.0 / inv
runner._activation_global_scale = activation_global_scale
# Register the runner and store the ID (not the runner itself)
layer._cutedsl_runner_id = register_runner(runner)
layer._cutedsl_out_features = out_features
layer.weight = torch.nn.Parameter(
torch.zeros(out_features, in_features, dtype=torch.bfloat16,
device=device),
requires_grad=False,
)
for attr in ("weight_scale", "weight_global_scale",
"input_global_scale", "input_global_scale_inv",
"alpha", "weights_padding_cols", "weight_scale_2",
"input_scale"):
if hasattr(layer, attr):
try:
delattr(layer, attr)
except Exception:
pass
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
result = nvfp4_linear_gemm(
x,
layer._cutedsl_runner_id,
layer._cutedsl_out_features,
)
if bias is not None:
result = result + bias
return result