Files
nvfp4-megamoe-kernel/cutedsl/nvfp4_linear.py
biondizzle 35fab6cff3 Replace autograd.Function with torch.library.custom_op for Dynamo compat
Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(cute.compile, JIT, etc.). The autograd.Function approach was unreliable
with fullgraph mode — Dynamo would still try to trace through it.

Fix: torch.library.custom_op makes Dynamo treat our GEMM as an opaque
black box. No reimplementing the kernel — just route through the existing
runner via a registry pattern:
  - Runners registered in global dict with integer IDs
  - Custom op takes (tensors, runner_id, shape_hint) -> tensor
  - Dynamo calls fake impl for shape inference, never touches the runner
  - At execution time, real impl looks up runner and calls _run_impl

Changes:
  - New: cutedsl/custom_ops.py (custom op definitions + registry)
  - New: tests/test_custom_op.py (local unit tests, no GPU needed)
  - Removed: _Nvfp4LinearApply, _MoEApply (autograd.Function classes)
  - Updated: nvfp4_linear.py, runner.py, cutedsl.py, nvfp4_cutedsl.py
    to use custom ops instead of autograd.Function
  - Updated: cutedsl_quant_method.py to use custom op + registry
2026-05-19 01:54:48 +00:00

168 lines
5.6 KiB
Python

"""CuTeDSL NVFP4 Linear (single GEMM)
Generic NVFP4 GEMM runner for attention projections and any single
linear layer. Uses ScaledGroupedGemmKernel with num_groups=1.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
"""
import torch
from cutedsl.bridge import (
quantize_activation_nvfp4,
quantize_to_nvfp4,
make_b_k_major,
assemble_scales_3d_side,
run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
class CuTeDSLNvfp4Linear:
    """Single NVFP4 GEMM using CuTeDSL (num_groups=1).

    Handles any (K, N) weight matrix in NVFP4 format.
    Simple: quantize activation → GEMM → BF16 output.
    No SiLU, no fusion, no routing.

    Intended to be CUDA-graph-compatible: the activation, expert-offset and
    global-scale buffers are allocated once at maximum size and reused.

    Typical lifecycle:
        1. Construct the runner.
        2. Assign ``fp4``, ``sf`` and ``gs`` (each a one-element list).
        3. Call ``finalize_weights()`` (or let the first forward do it).
        4. Optionally call ``compute_activation_global_scale()``.
        5. Call the instance / ``run()`` for each forward pass.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        """Record the layer geometry; no tensors are allocated here.

        Args:
            in_features: Size of the reduction dimension (K) of the input.
            out_features: Size of the output dimension (N).
            max_num_tokens: Largest batch (row count) the pre-allocated
                buffers must be able to hold.
            device: Device on which weights and buffers are created.
        """
        self.in_features = in_features
        self.out_features = out_features
        self.max_num_tokens = max_num_tokens
        self.device = device

        # Raw weights (set after construction, then call finalize_weights).
        self.fp4 = None  # list of 1 tensor
        self.sf = None  # list of 1 tensor
        self.gs = None  # list of 1 float

        # Processed weights, produced by finalize_weights().
        self._mat_b = None
        self._scale_b = None
        self._gsb = None

        # Default activation global scale.
        # NOTE(review): 6.0 and 448.0 are presumably the FP4 (E2M1) and
        # FP8 (E4M3) maximum magnitudes — confirm against the quantizer.
        self._activation_global_scale = 1.0 / (6.0 * 448.0)

        # Buffers created lazily by _allocate_buffers().
        self._padded_x_fp4_buf = None
        self._expert_offsets_buf = None
        self._gsa_buf = None
        self._buffers_allocated = False

        # NOTE: ``_runner_id`` is intentionally NOT created here; run() adds
        # it on first use and relies on its absence to detect "unregistered".

    def finalize_weights(self):
        """Convert the raw weights into the layout the GEMM consumes.

        Reads ``fp4``, ``sf`` and ``gs``, stores the processed results in
        ``_mat_b``, ``_scale_b`` and ``_gsb``, then drops the raw references.

        Raises:
            RuntimeError: If any of ``fp4``, ``sf`` or ``gs`` has not been
                assigned (including a second call after the raw weights
                were already released).
        """
        missing = [
            name for name in ("fp4", "sf", "gs") if getattr(self, name) is None
        ]
        if missing:
            raise RuntimeError(
                "CuTeDSLNvfp4Linear weights are not set: assign "
                f"{', '.join(missing)} before calling finalize_weights()"
            )

        self._mat_b = make_b_k_major(torch.stack(self.fp4))  # (1, K_packed, N_packed)
        self._scale_b = assemble_scales_3d_side(self.sf)
        self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)

        # Free raw weights
        self.fp4 = None
        self.sf = None
        self.gs = None

    def _allocate_buffers(self):
        """Pre-allocate buffers at max size for cudagraph compatibility."""
        # Row count is rounded up to a multiple of 128.
        max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
        # Allocated as bytes (in_features // 2 per row), then reinterpreted
        # as the packed FP4 dtype.
        self._padded_x_fp4_buf = torch.zeros(
            max_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)
        self._expert_offsets_buf = torch.zeros(
            1, dtype=torch.int32, device=self.device
        )
        self._gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Lazily finalize the weights and allocate buffers if needed."""
        if self._mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf):
        """Assemble 2D-side activation scales for num_groups=1.

        Copies ``x_sf`` into a zero-filled FP8 (e4m3fn) buffer whose rows are
        padded to a multiple of 128 and columns to a multiple of 4, passes it
        through ``pad_and_swizzle_single`` and reshapes the result back to
        the padded 2D shape.

        NOTE(review): this allocates a fresh buffer on every call, unlike the
        buffers created in ``_allocate_buffers`` — confirm this is acceptable
        for the CUDA-graph use case.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
        padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
        # Zeros are created as float16 and then converted to FP8.
        buf = torch.zeros(
            padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device
        ).to(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scale(self, hidden_states_sample):
        """Compute activation global scale from a warmup forward.

        Runs ``quantize_to_nvfp4`` on the sample under ``torch.no_grad()``
        and stores the third returned value as the activation global scale
        used by subsequent forwards.
        """
        self._ensure_initialized()
        with torch.no_grad():
            _, _, gs = quantize_to_nvfp4(hidden_states_sample)
        self._activation_global_scale = gs

    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward: BF16 input → NVFP4 GEMM → BF16 output.

        Uses torch.library.custom_op (nvfp4::linear_gemm) so torch.compile
        treats this as an opaque op. The custom op calls _run_impl internally.

        On first use the instance registers itself via ``register_runner``
        and caches the returned ID in ``_runner_id``.
        """
        if not hasattr(self, '_runner_id'):
            self._runner_id = register_runner(self)
        return nvfp4_linear_gemm(
            hidden_states, self._runner_id, self.out_features,
        )

    def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Actual implementation of the forward pass.

        ``run`` dispatches here through the ``nvfp4_linear_gemm`` custom op
        so that torch.compile does not trace into it.

        Args:
            hidden_states: Activations; the first dimension is the token
                count.

        Returns:
            The first ``num_tokens`` rows of the GEMM output.

        Raises:
            RuntimeError: If the token count exceeds the number of rows in
                the pre-allocated activation buffer.
        """
        self._ensure_initialized()
        num_tokens = hidden_states.shape[0]

        # Fail early with an actionable message instead of the opaque
        # shape-mismatch error the buffer copy below would raise.
        padded_x_fp4 = self._padded_x_fp4_buf
        capacity = padded_x_fp4.shape[0]
        if num_tokens > capacity:
            raise RuntimeError(
                f"CuTeDSLNvfp4Linear received {num_tokens} tokens but its "
                f"pre-allocated buffer holds at most {capacity} rows "
                f"(max_num_tokens={self.max_num_tokens}); construct the "
                "runner with a larger max_num_tokens"
            )

        padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128

        # Quantize activation
        x_fp4, x_sf = quantize_activation_nvfp4(
            hidden_states, self._activation_global_scale
        )

        # Scatter x_fp4 into padded buffer; clear it first so rows beyond
        # num_tokens do not carry data from a previous call.
        padded_x_fp4.view(torch.uint8).zero_()
        padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8)

        # Assemble A-side scales
        scale_a = self._assemble_scales_single_group(x_sf)

        # Expert offsets: [padded_rows] for 1 group
        expert_offsets = self._expert_offsets_buf
        expert_offsets.fill_(padded_rows)

        # Global scales (fill_ returns the same pre-allocated buffer)
        gsa = self._gsa_buf.fill_(self._activation_global_scale)

        # Run GEMM
        out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=self._mat_b,
            scale_a=scale_a,
            scale_b=self._scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._gsb,
        )
        # Drop the padding rows.
        return out[:num_tokens]

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Alias for ``run`` so the runner can be used like a module."""
        return self.run(hidden_states)