2026-05-18 20:14:03 +00:00
|
|
|
"""CuTeDSL NVFP4 Linear (single GEMM)
|
|
|
|
|
|
|
|
|
|
Generic NVFP4 GEMM runner for attention projections and any single
|
|
|
|
|
linear layer. Uses ScaledGroupedGemmKernel with num_groups=1.
|
|
|
|
|
|
|
|
|
|
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from cutedsl.bridge import (
|
|
|
|
|
quantize_activation_nvfp4,
|
|
|
|
|
quantize_to_nvfp4,
|
|
|
|
|
make_b_k_major,
|
|
|
|
|
assemble_scales_3d_side,
|
|
|
|
|
run_nvfp4_grouped_gemm,
|
|
|
|
|
)
|
|
|
|
|
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
|
|
|
|
ceil_div as cutedsl_ceil_div,
|
|
|
|
|
pad_and_swizzle_single,
|
|
|
|
|
)
|
Replace autograd.Function with torch.library.custom_op for Dynamo compat
Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(cute.compile, JIT, etc.). The autograd.Function approach was unreliable
with fullgraph mode — Dynamo would still try to trace through it.
Fix: torch.library.custom_op makes Dynamo treat our GEMM as an opaque
black box. No reimplementing the kernel — just route through the existing
runner via a registry pattern:
- Runners registered in global dict with integer IDs
- Custom op takes (tensors, runner_id, shape_hint) -> tensor
- Dynamo calls fake impl for shape inference, never touches the runner
- At execution time, real impl looks up runner and calls _run_impl
Changes:
- New: cutedsl/custom_ops.py (custom op definitions + registry)
- New: tests/test_custom_op.py (local unit tests, no GPU needed)
- Removed: _Nvfp4LinearApply, _MoEApply (autograd.Function classes)
- Updated: nvfp4_linear.py, runner.py, cutedsl.py, nvfp4_cutedsl.py
to use custom ops instead of autograd.Function
- Updated: cutedsl_quant_method.py to use custom op + registry
2026-05-19 01:54:48 +00:00
|
|
|
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
|
2026-05-18 21:38:28 +00:00
|
|
|
|
|
|
|
|
|
2026-05-18 20:14:03 +00:00
|
|
|
class CuTeDSLNvfp4Linear:
    """Single NVFP4 GEMM using CuTeDSL (num_groups=1).

    Handles any (K, N) weight matrix in NVFP4 format.
    Simple: quantize activation → GEMM → BF16 output.
    No SiLU, no fusion, no routing.

    Lifecycle:
      1. Construct, then assign ``fp4`` / ``sf`` / ``gs`` (each a one-element
         list holding the weight, its scale factors and its global scale).
      2. Optionally call ``finalize_weights`` and
         ``compute_activation_global_scale``; if ``finalize_weights`` is not
         called explicitly, the first forward calls it.
      3. Call the instance (or ``run``) on activations.

    CUDA-graph-compatible: the padded activation buffer, the group-offset
    buffer and the activation global-scale buffer are allocated once at
    ``max_num_tokens`` capacity and reused; no CPU-GPU syncs. (The A-side
    scale tensor is still built per call in ``_assemble_scales_single_group``.)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.max_num_tokens = max_num_tokens
        self.device = device

        # Raw weights: set after construction, then call finalize_weights
        # (or let the first forward do it). Released once processed.
        self.fp4 = None  # list of 1 tensor
        self.sf = None  # list of 1 tensor
        self.gs = None  # list of 1 float

        # Processed weights (populated by finalize_weights).
        self._mat_b = None
        self._scale_b = None
        self._gsb = None

        # Activation global scale. Default until compute_activation_global_scale
        # overrides it; 6.0 and 448.0 are presumably the max magnitudes of
        # FP4 E2M1 and FP8 E4M3 respectively.
        self._activation_global_scale = 1.0 / (6.0 * 448.0)

        # Pre-allocated buffers (populated lazily by _allocate_buffers).
        self._padded_x_fp4_buf = None
        self._expert_offsets_buf = None
        self._gsa_buf = None
        self._buffers_allocated = False

        # Id in the custom-op runner registry; assigned on the first run().
        self._runner_id = None

    def finalize_weights(self):
        """Process weights for CuTeDSL GEMM.

        Builds the K-major B operand, the 3-D B-side scale tensor and the
        float32 B global-scale tensor, then drops the raw ``fp4`` / ``sf`` /
        ``gs`` references so their memory can be reclaimed.

        Raises:
            RuntimeError: if the raw weights were never assigned, or were
                already consumed by a previous call.
        """
        if self.fp4 is None or self.sf is None or self.gs is None:
            raise RuntimeError(
                "CuTeDSLNvfp4Linear: fp4/sf/gs must be assigned before "
                "finalize_weights() (they are released after the first call)"
            )
        self._mat_b = make_b_k_major(torch.stack(self.fp4))  # (1, K_packed, N_packed)
        self._scale_b = assemble_scales_3d_side(self.sf)
        self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)

        # Free raw weights
        self.fp4 = None
        self.sf = None
        self.gs = None

    def _allocate_buffers(self):
        """Pre-allocate buffers at max size for cudagraph compatibility."""
        # Same 128-row rounding that _run_impl applies to each batch.
        max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128

        # float4_e2m1fn_x2 packs two FP4 values per byte, hence K // 2 bytes.
        self._padded_x_fp4_buf = torch.zeros(
            max_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)

        self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
        self._gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Finalize weights and allocate buffers if not done yet."""
        if self._mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf):
        """Assemble 2D-side activation scales for num_groups=1.

        Zero-pads ``x_sf`` to (multiple of 128 rows, multiple of 4 columns),
        swizzles it with ``pad_and_swizzle_single`` and returns the result
        reshaped to that padded 2-D shape. The tensor is allocated per call.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
        padded_cols = cutedsl_ceil_div(num_cols, 4) * 4

        # All-zero bytes reinterpreted as FP8 E4M3 are +0.0, so this yields the
        # same zero padding as casting an fp16 zero tensor, without the extra
        # fp16 allocation and conversion kernel.
        buf = torch.zeros(
            padded_rows, padded_cols, dtype=torch.uint8, device=x_sf.device
        ).view(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scale(self, hidden_states_sample):
        """Compute activation global scale from a warmup forward.

        Quantizes ``hidden_states_sample`` with ``quantize_to_nvfp4`` and keeps
        its global scale for all subsequent forwards.
        """
        self._ensure_initialized()
        with torch.no_grad():
            _, _, gs = quantize_to_nvfp4(hidden_states_sample)
        self._activation_global_scale = gs

    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward: BF16 input → NVFP4 GEMM → BF16 output.

        Uses torch.library.custom_op (nvfp4::linear_gemm) so torch.compile
        treats this as an opaque op. The custom op calls _run_impl internally.
        """
        # The custom op finds this instance by registry id, so register it
        # once, on first use.
        if self._runner_id is None:
            self._runner_id = register_runner(self)
        return nvfp4_linear_gemm(
            hidden_states, self._runner_id, self.out_features,
        )

    def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Eager implementation invoked by the nvfp4 linear custom op.

        Quantizes ``hidden_states`` to NVFP4, copies it into the pre-allocated
        padded A buffer, builds the A-side scales and runs the single-group
        CuTeDSL GEMM. Returns the first ``num_tokens`` rows of the output.

        Raises:
            ValueError: if ``hidden_states`` has more rows than the
                pre-allocated activation buffer holds.
        """
        self._ensure_initialized()

        num_tokens = hidden_states.shape[0]
        # Check against the real buffer size (max_num_tokens rounded up to
        # 128), so every batch that fits keeps working.
        capacity = self._padded_x_fp4_buf.shape[0]
        if num_tokens > capacity:
            raise ValueError(
                f"CuTeDSLNvfp4Linear: got {num_tokens} tokens but buffers hold "
                f"at most {capacity} rows (max_num_tokens={self.max_num_tokens})"
            )
        padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128

        # Quantize activation
        x_fp4, x_sf = quantize_activation_nvfp4(
            hidden_states, self._activation_global_scale
        )

        # Scatter x_fp4 into the padded buffer. The whole buffer is zeroed
        # first so padding rows carry no stale data from a larger earlier batch.
        padded_x_fp4 = self._padded_x_fp4_buf
        padded_x_fp4.view(torch.uint8).zero_()
        padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8)

        # Assemble A-side scales
        scale_a = self._assemble_scales_single_group(x_sf)

        # Expert offsets: [padded_rows] for 1 group
        expert_offsets = self._expert_offsets_buf
        expert_offsets.fill_(padded_rows)

        # Global scales
        gsa = self._gsa_buf.fill_(self._activation_global_scale)

        # Run GEMM
        out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=self._mat_b,
            scale_a=scale_a,
            scale_b=self._scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._gsb,
        )

        # Drop the padding rows.
        return out[:num_tokens]

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Alias for ``run``."""
        return self.run(hidden_states)
|