Go back to torch.library.custom_op with correct GEMM impl

allow_in_graph doesn't work — Dynamo can't create proxies for Python
objects (the runner). The custom op approach requires only tensor args.

This time the GEMM impl correctly:
- Uses quantize_activation_nvfp4 for activation quantization
- Pads x_fp4 via uint8 + view(float4) for torch.zeros compat
- Assembles A-side scales with pad + swizzle
- Uses int32 expert_offsets (CuTeDSL requirement)
- Passes runner's pre-assembled mat_b, scale_b, gsb tensors
This commit is contained in:
2026-05-19 01:24:41 +00:00
parent 02c500bbb1
commit 98153002c0

View File

@@ -6,9 +6,14 @@ Registers as an NvFp4LinearKernel so that vLLM kernel selection
(init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM.
Uses torch.autograd.Function + torch._dynamo.allow_in_graph to make
the GEMM opaque to torch.compile/Dynamo (CuTeDSL internals use
The GEMM is registered as a torch.library.custom_op so that
torch.compile/Dynamo treats it as opaque (CuTeDSL internals use
Path.cwd, JIT compilation, etc. which Dynamo cannot trace).
The custom op only takes tensor arguments. The runner's pre-assembled
weight tensors (mat_b, scale_b, global_scale_b) are stored on the
layer and passed directly. Activation quantization and scale assembly
are done inside the custom op.
"""
import torch
@@ -21,25 +26,84 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
logger = init_logger(__name__)
class _CuTeDSLNvfp4LinearFn(torch.autograd.Function):
"""Custom autograd function to make CuTeDSL NVFP4 GEMM opaque to Dynamo.
@torch.library.custom_op("cutedsl::nvfp4_linear", mutates_args=())
def _cutedsl_nvfp4_linear(
x: torch.Tensor,
mat_b: torch.Tensor,
scale_b: torch.Tensor,
global_scale_b: torch.Tensor,
activation_global_scale: float,
) -> torch.Tensor:
"""Run a single-group NVFP4 GEMM via CuTeDSL.
Without this, Dynamo in fullgraph mode tries to trace through
CuTeDSL internals (JIT compilation uses Path.cwd, etc.) and crashes.
allow_in_graph tells Dynamo this is a known, opaque kernel call.
All args are tensors (or scalars) — Dynamo-compatible.
The weight tensors come from the runner's finalize_weights:
mat_b: (1, K_padded, N_packed) float4_e2m1fn_x2
scale_b: (1, K_sf_padded, N_sf_packed) fp8
global_scale_b: (1,) float32
"""
@staticmethod
def forward(ctx, x, runner):
return runner._run_impl(x)
from cutedsl.bridge import (quantize_activation_nvfp4,
run_nvfp4_grouped_gemm)
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear, cutedsl_ceil_div
from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single
@staticmethod
def backward(ctx, grad_output):
raise NotImplementedError(
"CuTeDSL NVFP4 linear does not support backward")
num_tokens = x.shape[0]
out_features = mat_b.shape[2] # packed N in float4 elements
# Quantize activation: x → (x_fp4, x_sf)
x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale)
# Pad activation to 128-row alignment for TMA
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
if num_tokens < padded_rows:
# Can't torch.zeros with float4 dtype — allocate as uint8 then view
x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1],
dtype=torch.uint8, device=x.device
).view(torch.float4_e2m1fn_x2)
x_fp4_padded[:num_tokens] = x_fp4
else:
x_fp4_padded = x_fp4
# Assemble A-side scales: pad + swizzle for CuTeDSL layout
num_rows_sf, num_cols_sf = x_sf.shape
padded_rows_sf = cutedsl_ceil_div(num_rows_sf, 128) * 128
padded_cols_sf = cutedsl_ceil_div(num_cols_sf, 4) * 4
sf_buf = torch.zeros(padded_rows_sf, padded_cols_sf,
dtype=torch.float8_e4m3fn, device=x_sf.device)
sf_buf[:num_rows_sf, :num_cols_sf] = x_sf
scale_a = pad_and_swizzle_single(sf_buf).unsqueeze(0) # (1, ...)
# Expert offsets for 1 group (int32 — CuTeDSL requires int32)
expert_offsets = torch.tensor([padded_rows], dtype=torch.int32, device=x.device)
# Global scale for activation
global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device)
# Run the CuTeDSL grouped GEMM (1 group)
out = run_nvfp4_grouped_gemm(
mat_a=x_fp4_padded,
mat_b=mat_b,
scale_a=scale_a,
scale_b=scale_b,
expert_offsets=expert_offsets,
global_scale_a=global_scale_a,
global_scale_b=global_scale_b,
)
return out[:num_tokens]
# Tell Dynamo: this autograd function is an opaque op, don't trace inside.
torch._dynamo.allow_in_graph(_CuTeDSLNvfp4LinearFn)
@_cutedsl_nvfp4_linear.register_fake
def _cutedsl_nvfp4_linear_fake(
x: torch.Tensor,
mat_b: torch.Tensor,
scale_b: torch.Tensor,
global_scale_b: torch.Tensor,
activation_global_scale: float,
) -> torch.Tensor:
out_features = mat_b.shape[2]
return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16,
device=x.device)
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
@@ -119,10 +183,12 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
inv = layer.input_global_scale_inv.data.item()
if inv != 0:
activation_global_scale = 1.0 / inv
runner._activation_global_scale = activation_global_scale
# Store runner on the layer
layer._cutedsl_runner = runner
# Store pre-assembled weight tensors on the layer for the custom op.
layer._cutedsl_mat_b = runner._mat_b
layer._cutedsl_scale_b = runner._scale_b
layer._cutedsl_global_scale_b = runner._gsb
layer._cutedsl_activation_global_scale = activation_global_scale
# Replace weight with dummy BF16 (vLLM module introspection may need it)
layer.weight = torch.nn.Parameter(
@@ -131,7 +197,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
requires_grad=False,
)
# Clean up NVFP4 params that are now handled by the runner.
# Clean up NVFP4 params that are now handled by the custom op.
for attr in ("weight_scale", "weight_global_scale",
"input_global_scale", "input_global_scale_inv",
"alpha", "weights_padding_cols", "weight_scale_2",
@@ -148,7 +214,13 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
result = _CuTeDSLNvfp4LinearFn.apply(x, layer._cutedsl_runner)
result = torch.ops.cutedsl.nvfp4_linear(
x,
layer._cutedsl_mat_b,
layer._cutedsl_scale_b,
layer._cutedsl_global_scale_b,
layer._cutedsl_activation_global_scale,
)
if bias is not None:
result = result + bias
return result