Switch to allow_in_graph for Dynamo opacity instead of custom op
The custom op approach required reimplementing the GEMM (wrong scale assembly, wrong tensor formats, cudaErrorIllegalAddress). Instead, use torch.autograd.Function + torch._dynamo.allow_in_graph which tells Dynamo to treat the function as an opaque kernel call, while still using the runner's battle-tested _run_impl for the actual GEMM. allow_in_graph is the proper way to register opaque ops for Dynamo without reimplementing the computation.
This commit is contained in:
@@ -6,8 +6,8 @@ Registers as an NvFp4LinearKernel so that vLLM kernel selection
|
||||
(init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
|
||||
Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM.
|
||||
|
||||
The GEMM is registered as a torch.library.custom_op so that
|
||||
torch.compile/Dynamo treats it as opaque (CuTeDSL internals use
|
||||
Uses torch.autograd.Function + torch._dynamo.allow_in_graph to make
|
||||
the GEMM opaque to torch.compile/Dynamo (CuTeDSL internals use
|
||||
Path.cwd, JIT compilation, etc. which Dynamo cannot trace).
|
||||
"""
|
||||
|
||||
@@ -21,88 +21,25 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@torch.library.custom_op("cutedsl::nvfp4_linear", mutates_args=())
|
||||
def _cutedsl_nvfp4_linear(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Run a single-group NVFP4 GEMM via CuTeDSL.
|
||||
class _CuTeDSLNvfp4LinearFn(torch.autograd.Function):
|
||||
"""Custom autograd function to make CuTeDSL NVFP4 GEMM opaque to Dynamo.
|
||||
|
||||
This is registered as a torch.library.custom_op so Dynamo treats
|
||||
it as opaque. The real implementation calls into CuTeDSL's
|
||||
pre-assembled weight tensors and grouped GEMM kernel.
|
||||
|
||||
Args:
|
||||
x: BF16 input (M, K)
|
||||
mat_b: Pre-assembled FP4 weight (1, K_padded, N_packed) float4_e2m1fn_x2
|
||||
scale_b: Pre-assembled block scales (1, K_sf_padded, N_sf_packed) fp8
|
||||
global_scale_b: Weight global scale (1,) float32
|
||||
activation_global_scale: Activation global scale (float)
|
||||
Without this, Dynamo in fullgraph mode tries to trace through
|
||||
CuTeDSL internals (JIT compilation uses Path.cwd, etc.) and crashes.
|
||||
allow_in_graph tells Dynamo this is a known, opaque kernel call.
|
||||
"""
|
||||
from cutedsl.bridge import (quantize_activation_nvfp4,
|
||||
run_nvfp4_grouped_gemm)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single
|
||||
from cutedsl.nvfp4_linear import cutedsl_ceil_div
|
||||
@staticmethod
|
||||
def forward(ctx, x, runner):
|
||||
return runner._run_impl(x)
|
||||
|
||||
num_tokens = x.shape[0]
|
||||
in_features = x.shape[1]
|
||||
out_features = mat_b.shape[2] # packed N in float4 elements
|
||||
|
||||
# Quantize activation
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale)
|
||||
|
||||
# Pad activation to 128-row alignment for TMA
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
if num_tokens < padded_rows:
|
||||
x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1],
|
||||
dtype=torch.uint8, device=x.device
|
||||
).view(torch.float4_e2m1fn_x2)
|
||||
x_fp4_padded[:num_tokens] = x_fp4
|
||||
else:
|
||||
x_fp4_padded = x_fp4
|
||||
|
||||
# Assemble A-side scales: pad + swizzle for single group
|
||||
num_rows, num_cols = x_sf.shape
|
||||
padded_rows_sf = cutedsl_ceil_div(num_rows, 128) * 128
|
||||
padded_cols_sf = cutedsl_ceil_div(num_cols, 4) * 4
|
||||
buf = torch.zeros(padded_rows_sf, padded_cols_sf, dtype=torch.float8_e4m3fn, device=x_sf.device)
|
||||
buf[:num_rows, :num_cols] = x_sf
|
||||
scale_a = pad_and_swizzle_single(buf).unsqueeze(0) # (1, padded_rows, padded_cols)
|
||||
|
||||
# Expert offsets for 1 group (int32 — CuTeDSL expects int32, not int64)
|
||||
expert_offsets = torch.tensor([padded_rows], dtype=torch.int32, device=x.device)
|
||||
|
||||
# Global scale for activation
|
||||
global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device)
|
||||
|
||||
# Run the CuTeDSL grouped GEMM (1 group)
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4_padded,
|
||||
mat_b=mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=global_scale_a,
|
||||
global_scale_b=global_scale_b,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise NotImplementedError(
|
||||
"CuTeDSL NVFP4 linear does not support backward")
|
||||
|
||||
|
||||
@_cutedsl_nvfp4_linear.register_fake
|
||||
def _cutedsl_nvfp4_linear_fake(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
out_features = mat_b.shape[2]
|
||||
return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16,
|
||||
device=x.device)
|
||||
# Tell Dynamo: this autograd function is an opaque op, don't trace inside.
|
||||
torch._dynamo.allow_in_graph(_CuTeDSLNvfp4LinearFn)
|
||||
|
||||
|
||||
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
@@ -122,7 +59,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)"
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[ bool, str | None]:
|
||||
def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@@ -174,20 +111,18 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
runner.finalize_weights()
|
||||
|
||||
# Compute activation global scale from input_global_scale_inv.
|
||||
# quantize_activation_nvfp4(x, global_scale) normalizes:
|
||||
# x_norm = x / global_scale
|
||||
# global_scale = amax/448 = input_global_scale = 1/inv.
|
||||
activation_global_scale = 1.0 / 2688.0 # default fallback
|
||||
if hasattr(layer, 'input_global_scale_inv') and layer.input_global_scale_inv is not None:
|
||||
inv = layer.input_global_scale_inv.data.item()
|
||||
if inv != 0:
|
||||
activation_global_scale = 1.0 / inv
|
||||
runner._activation_global_scale = activation_global_scale
|
||||
|
||||
# Store pre-assembled weight tensors on the layer for the custom op.
|
||||
# mat_b shape: (1, K_padded, N_packed) float4_e2m1fn_x2
|
||||
# scale_b shape: (1, K_sf_padded, N_sf_packed) fp8
|
||||
# gsb shape: (1,) float32
|
||||
layer._cutedsl_mat_b = runner._mat_b
|
||||
layer._cutedsl_scale_b = runner._scale_b
|
||||
layer._cutedsl_global_scale_b = runner._gsb
|
||||
layer._cutedsl_activation_global_scale = activation_global_scale
|
||||
# Store runner on the layer
|
||||
layer._cutedsl_runner = runner
|
||||
|
||||
# Replace weight with dummy BF16 (vLLM module introspection may need it)
|
||||
layer.weight = torch.nn.Parameter(
|
||||
@@ -196,7 +131,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Clean up NVFP4 params that are now handled by our custom op.
|
||||
# Clean up NVFP4 params that are now handled by the runner.
|
||||
for attr in ("weight_scale", "weight_global_scale",
|
||||
"input_global_scale", "input_global_scale_inv",
|
||||
"alpha", "weights_padding_cols", "weight_scale_2",
|
||||
@@ -213,13 +148,7 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
result = torch.ops.cutedsl.nvfp4_linear(
|
||||
x,
|
||||
layer._cutedsl_mat_b,
|
||||
layer._cutedsl_scale_b,
|
||||
layer._cutedsl_global_scale_b,
|
||||
layer._cutedsl_activation_global_scale,
|
||||
)
|
||||
result = _CuTeDSLNvfp4LinearFn.apply(x, layer._cutedsl_runner)
|
||||
if bias is not None:
|
||||
result = result + bias
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user