Go back to torch.library.custom_op with correct GEMM impl
allow_in_graph doesn't work — Dynamo can't create proxies for Python objects (the runner). The custom op approach requires only tensor args. This time the GEMM impl correctly: - Uses quantize_activation_nvfp4 for activation quantization - Pads x_fp4 via uint8 + view(float4) for torch.zeros compat - Assembles A-side scales with pad + swizzle - Uses int32 expert_offsets (CuTeDSL requirement) - Passes runner's pre-assembled mat_b, scale_b, gsb tensors
This commit is contained in:
@@ -6,9 +6,14 @@ Registers as an NvFp4LinearKernel so that vLLM kernel selection
|
||||
(init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
|
||||
Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM.
|
||||
|
||||
Uses torch.autograd.Function + torch._dynamo.allow_in_graph to make
|
||||
the GEMM opaque to torch.compile/Dynamo (CuTeDSL internals use
|
||||
The GEMM is registered as a torch.library.custom_op so that
|
||||
torch.compile/Dynamo treats it as opaque (CuTeDSL internals use
|
||||
Path.cwd, JIT compilation, etc. which Dynamo cannot trace).
|
||||
|
||||
The custom op only takes tensor arguments. The runner's pre-assembled
|
||||
weight tensors (mat_b, scale_b, global_scale_b) are stored on the
|
||||
layer and passed directly. Activation quantization and scale assembly
|
||||
are done inside the custom op.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -21,25 +26,84 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _CuTeDSLNvfp4LinearFn(torch.autograd.Function):
|
||||
"""Custom autograd function to make CuTeDSL NVFP4 GEMM opaque to Dynamo.
|
||||
@torch.library.custom_op("cutedsl::nvfp4_linear", mutates_args=())
|
||||
def _cutedsl_nvfp4_linear(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Run a single-group NVFP4 GEMM via CuTeDSL.
|
||||
|
||||
Without this, Dynamo in fullgraph mode tries to trace through
|
||||
CuTeDSL internals (JIT compilation uses Path.cwd, etc.) and crashes.
|
||||
allow_in_graph tells Dynamo this is a known, opaque kernel call.
|
||||
All args are tensors (or scalars) — Dynamo-compatible.
|
||||
The weight tensors come from the runner's finalize_weights:
|
||||
mat_b: (1, K_padded, N_packed) float4_e2m1fn_x2
|
||||
scale_b: (1, K_sf_padded, N_sf_packed) fp8
|
||||
global_scale_b: (1,) float32
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, x, runner):
|
||||
return runner._run_impl(x)
|
||||
from cutedsl.bridge import (quantize_activation_nvfp4,
|
||||
run_nvfp4_grouped_gemm)
|
||||
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear, cutedsl_ceil_div
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise NotImplementedError(
|
||||
"CuTeDSL NVFP4 linear does not support backward")
|
||||
num_tokens = x.shape[0]
|
||||
out_features = mat_b.shape[2] # packed N in float4 elements
|
||||
|
||||
# Quantize activation: x → (x_fp4, x_sf)
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale)
|
||||
|
||||
# Pad activation to 128-row alignment for TMA
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
if num_tokens < padded_rows:
|
||||
# Can't torch.zeros with float4 dtype — allocate as uint8 then view
|
||||
x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1],
|
||||
dtype=torch.uint8, device=x.device
|
||||
).view(torch.float4_e2m1fn_x2)
|
||||
x_fp4_padded[:num_tokens] = x_fp4
|
||||
else:
|
||||
x_fp4_padded = x_fp4
|
||||
|
||||
# Assemble A-side scales: pad + swizzle for CuTeDSL layout
|
||||
num_rows_sf, num_cols_sf = x_sf.shape
|
||||
padded_rows_sf = cutedsl_ceil_div(num_rows_sf, 128) * 128
|
||||
padded_cols_sf = cutedsl_ceil_div(num_cols_sf, 4) * 4
|
||||
sf_buf = torch.zeros(padded_rows_sf, padded_cols_sf,
|
||||
dtype=torch.float8_e4m3fn, device=x_sf.device)
|
||||
sf_buf[:num_rows_sf, :num_cols_sf] = x_sf
|
||||
scale_a = pad_and_swizzle_single(sf_buf).unsqueeze(0) # (1, ...)
|
||||
|
||||
# Expert offsets for 1 group (int32 — CuTeDSL requires int32)
|
||||
expert_offsets = torch.tensor([padded_rows], dtype=torch.int32, device=x.device)
|
||||
|
||||
# Global scale for activation
|
||||
global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device)
|
||||
|
||||
# Run the CuTeDSL grouped GEMM (1 group)
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4_padded,
|
||||
mat_b=mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=global_scale_a,
|
||||
global_scale_b=global_scale_b,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
|
||||
# Tell Dynamo: this autograd function is an opaque op, don't trace inside.
|
||||
torch._dynamo.allow_in_graph(_CuTeDSLNvfp4LinearFn)
|
||||
@_cutedsl_nvfp4_linear.register_fake
|
||||
def _cutedsl_nvfp4_linear_fake(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
out_features = mat_b.shape[2]
|
||||
return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16,
|
||||
device=x.device)
|
||||
|
||||
|
||||
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
@@ -119,10 +183,12 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
inv = layer.input_global_scale_inv.data.item()
|
||||
if inv != 0:
|
||||
activation_global_scale = 1.0 / inv
|
||||
runner._activation_global_scale = activation_global_scale
|
||||
|
||||
# Store runner on the layer
|
||||
layer._cutedsl_runner = runner
|
||||
# Store pre-assembled weight tensors on the layer for the custom op.
|
||||
layer._cutedsl_mat_b = runner._mat_b
|
||||
layer._cutedsl_scale_b = runner._scale_b
|
||||
layer._cutedsl_global_scale_b = runner._gsb
|
||||
layer._cutedsl_activation_global_scale = activation_global_scale
|
||||
|
||||
# Replace weight with dummy BF16 (vLLM module introspection may need it)
|
||||
layer.weight = torch.nn.Parameter(
|
||||
@@ -131,7 +197,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Clean up NVFP4 params that are now handled by the runner.
|
||||
# Clean up NVFP4 params that are now handled by the custom op.
|
||||
for attr in ("weight_scale", "weight_global_scale",
|
||||
"input_global_scale", "input_global_scale_inv",
|
||||
"alpha", "weights_padding_cols", "weight_scale_2",
|
||||
@@ -148,7 +214,13 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
result = _CuTeDSLNvfp4LinearFn.apply(x, layer._cutedsl_runner)
|
||||
result = torch.ops.cutedsl.nvfp4_linear(
|
||||
x,
|
||||
layer._cutedsl_mat_b,
|
||||
layer._cutedsl_scale_b,
|
||||
layer._cutedsl_global_scale_b,
|
||||
layer._cutedsl_activation_global_scale,
|
||||
)
|
||||
if bias is not None:
|
||||
result = result + bias
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user