Add NVFP4 CuTeDSL compilation test (verify MmaMXF4NVF4Op compiles)

This commit is contained in:
2026-06-01 07:53:43 +00:00
parent fa6dbd4aa2
commit 5ea71ebd78

View File

@@ -0,0 +1,165 @@
"""Test: Verify NVFP4 CuTeDSL compilation with MmaMXF4NVF4Op (sf_vec_size=16).
This test does NOT run the kernel — it only verifies that the CuTeDSL JIT
compiler can handle the NVF4 block-scaled GEMM with proper pipeline abstractions.
If this compiles, we can add the custom epilogue.
"""
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
import cutlass.torch as cutlass_torch
from dsv4.ops.quantize import quantize_weight_to_nvfp4, quantize_activation_nvfp4
from dsv4.ops.layouts import make_b_k_major, assemble_raw_scales_2d3d_3d_side
def test_nvfp4_cutedsl_compilation():
"""Test that NVFP4 block-scaled GEMM compiles with CuTeDSL."""
device = "cuda:0"
M, N, K = 1, 384, 7168
top_k = 6
# Quantize
gsa = 1.0 / (6.0 * 448.0)
hs = torch.randn(M, K, dtype=torch.bfloat16, device=device)
x_fp4, x_sf = quantize_activation_nvfp4(hs, gsa)
W = torch.randn(K, N, dtype=torch.bfloat16, device=device)
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(W)
stacked = torch.stack([w_fp4]).permute(0, 2, 1).contiguous()
mat_b = make_b_k_major(stacked)
scale_b = assemble_raw_scales_2d3d_3d_side([w_sf.T.contiguous()])
print(f"x_fp4: {x_fp4.shape}, dtype={x_fp4.dtype}")
print(f"x_sf: {x_sf.shape}, dtype={x_sf.dtype}")
print(f"mat_b: {mat_b.shape}, dtype={mat_b.dtype}")
print(f"scale_b: {scale_b.shape}, dtype={scale_b.dtype}")
# Convert to CuTe tensors
a_tensor = cutlass_torch.from_dlpack(x_fp4)
a_tensor = a_tensor.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(x_fp4))
b_tensor = cutlass_torch.from_dlpack(mat_b)
b_tensor = b_tensor.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(mat_b))
sfa_tensor = cutlass_torch.from_dlpack(x_sf)
sfa_tensor = sfa_tensor.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(x_sf))
sfb_tensor = cutlass_torch.from_dlpack(scale_b)
sfb_tensor = sfb_tensor.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(scale_b))
c_tensor = cutlass_torch.make_tensor(
torch.empty(M, N, dtype=torch.bfloat16, device=device))
print("CuTe tensors created OK")
# ---- Setup exactly like dense.py ----
sf_vec_size = 16 # NVF4
a_dtype = cutlass.Float4E2M1FN
b_dtype = cutlass.Float4E2M1FN
sf_dtype = cutlass.Float8E4M3FN
c_dtype = cutlass.BFloat16
mma_tiler_mn = (128, 128)
cluster_shape_mn = (1, 1)
use_2cta = False
cta_group = tcgen05.CtaGroup.ONE
a_major = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode()
b_major = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode()
mma_inst_shape_mn_sfb = (
mma_tiler_mn[0] // (2 if use_2cta else 1),
cute.round_up(mma_tiler_mn[1], 128),
)
print(f"Creating tiled_mma with sf_vec_size={sf_vec_size}...", flush=True)
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
cta_group, mma_tiler_mn)
print(f"tiled_mma OK: shape_mnk={tiled_mma.shape_mnk}", flush=True)
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb)
print(f"tiled_mma_sfb OK", flush=True)
# MMA tiler
inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
inst_tile_k = 4
k_tile = inst_shape_k * inst_tile_k
mma_tiler = (cutlass.Int32(mma_tiler_mn[0]),
cutlass.Int32(mma_tiler_mn[1]),
cutlass.Int32(k_tile))
cta_tile_shape_mnk = (
mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
mma_tiler[1],
mma_tiler[2],
)
cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*cluster_shape_mn, 1)),
(tiled_mma.thr_id.shape,))
# SMEM layouts
num_ab_stages = 2
print("Creating SMEM layouts...", flush=True)
a_smem_staged = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler, a_dtype, num_ab_stages)
b_smem_staged = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler, b_dtype, num_ab_stages)
sfa_smem_staged = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
sfb_smem_staged = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
print("SMEM layouts OK", flush=True)
# TMA
a_smem0 = cute.slice_(a_smem_staged, (None, None, None, 0))
b_smem0 = cute.slice_(b_smem_staged, (None, None, None, 0))
sfa_smem0 = cute.slice_(sfa_smem_staged, (None, None, None, 0))
sfb_smem0 = cute.slice_(sfb_smem_staged, (None, None, None, 0))
print("Creating TMA atoms...", flush=True)
a_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id)
tma_a, gA = cute.nvgpu.make_tiled_tma_atom_A(a_op, a_tensor, a_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape)
print("TMA A OK", flush=True)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(cluster_shape_mn, tiled_mma.thr_id)
tma_b, gB = cute.nvgpu.make_tiled_tma_atom_B(b_op, b_tensor, b_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape)
print("TMA B OK", flush=True)
tma_sfa, gSFA = cute.nvgpu.make_tiled_tma_atom_A(
a_op, sfa_tensor, sfa_smem0, mma_tiler, tiled_mma,
cluster_layout_vmnk.shape, internal_type=cutlass.Int16)
print("TMA SFA OK", flush=True)
mma_tiler_sfb = (cutlass.Int32(mma_inst_shape_mn_sfb[0]),
cutlass.Int32(mma_inst_shape_mn_sfb[1]),
cutlass.Int32(k_tile))
cluster_layout_sfb_vmnk = cute.tiled_divide(
cute.make_layout((*cluster_shape_mn, 1)),
(tiled_mma_sfb.thr_id.shape,))
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(cluster_shape_mn, tiled_mma.thr_id)
tma_sfb, gSFB = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op, sfb_tensor, sfb_smem0, mma_tiler_sfb, tiled_mma_sfb,
cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)
print("TMA SFB OK", flush=True)
# Now try compiling the dense GEMM kernel (no custom epilogue)
print("Compiling dense_blockscaled GEMM with NVF4...", flush=True)
kernel = sm100_utils.Sm100BlockScaledPersistentDenseGemmKernel(
a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor,
acc_dtype=cutlass.Float32,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
sf_vec_size=sf_vec_size,
)
print("COMPILATION SUCCEEDED! NVF4 CuTeDSL path works.", flush=True)
if __name__ == "__main__":
test_nvfp4_cutedsl_compilation()