Add NVFP4 CuTeDSL compilation test (verify MmaMXF4NVF4Op compiles)
This commit is contained in:
165
tests/unit/test_nvfp4_cutedsl_compile.py
Normal file
165
tests/unit/test_nvfp4_cutedsl_compile.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Test: Verify NVFP4 CuTeDSL compilation with MmaMXF4NVF4Op (sf_vec_size=16).
|
||||
|
||||
This test does NOT run the kernel — it only verifies that the CuTeDSL JIT
|
||||
compiler can handle the NVF4 block-scaled GEMM with proper pipeline abstractions.
|
||||
If this compiles, we can add the custom epilogue.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4, quantize_activation_nvfp4
|
||||
from dsv4.ops.layouts import make_b_k_major, assemble_raw_scales_2d3d_3d_side
|
||||
|
||||
|
||||
def test_nvfp4_cutedsl_compilation():
    """Test that NVFP4 block-scaled GEMM compiles with CuTeDSL.

    The test walks through the host-side setup step by step, printing a
    progress line after each stage so a failure can be localized:

    1. Quantize a random activation and a random weight with the project
       helpers and lay the weight / scales out for the B operand.
    2. Wrap the torch tensors as CuTe tensors.
    3. Build the block-scaled tiled MMA objects (``sf_vec_size=16``).
    4. Build staged SMEM layouts for A, B, SFA and SFB.
    5. Build the TMA atoms for A, B, SFA and SFB.
    6. Construct the dense block-scaled GEMM kernel object.

    No numerical result is checked; the test passes if none of the steps
    raises.

    Raises:
        unittest.SkipTest: if no CUDA device is available (pytest reports
            this as a skip rather than an error).
    """
    # Local import: only needed for the skip guard below.
    import unittest

    # Every tensor below is allocated on "cuda:0"; skip cleanly on hosts
    # without a GPU instead of erroring inside torch.randn.
    if not torch.cuda.is_available():
        raise unittest.SkipTest(
            "CUDA device required for the NVFP4 CuTeDSL compilation test")

    device = "cuda:0"
    M, N, K = 1, 384, 7168

    # Quantize
    # NOTE(review): 6.0 and 448.0 look like the FP4 (E2M1) and FP8 (E4M3)
    # maximum magnitudes -- confirm against quantize_activation_nvfp4.
    gsa = 1.0 / (6.0 * 448.0)
    hs = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    x_fp4, x_sf = quantize_activation_nvfp4(hs, gsa)

    W = torch.randn(K, N, dtype=torch.bfloat16, device=device)
    # The third return value is not needed for this setup-only test.
    w_fp4, w_sf, _w_gs = quantize_weight_to_nvfp4(W)
    # Add a leading axis of size 1 and swap the last two axes before handing
    # the weight to make_b_k_major.
    stacked = torch.stack([w_fp4]).permute(0, 2, 1).contiguous()
    mat_b = make_b_k_major(stacked)
    scale_b = assemble_raw_scales_2d3d_3d_side([w_sf.T.contiguous()])

    print(f"x_fp4: {x_fp4.shape}, dtype={x_fp4.dtype}")
    print(f"x_sf: {x_sf.shape}, dtype={x_sf.dtype}")
    print(f"mat_b: {mat_b.shape}, dtype={mat_b.dtype}")
    print(f"scale_b: {scale_b.shape}, dtype={scale_b.dtype}")

    # Convert to CuTe tensors
    a_tensor = cutlass_torch.from_dlpack(x_fp4)
    a_tensor = a_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(x_fp4))

    b_tensor = cutlass_torch.from_dlpack(mat_b)
    b_tensor = b_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(mat_b))

    sfa_tensor = cutlass_torch.from_dlpack(x_sf)
    sfa_tensor = sfa_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(x_sf))

    sfb_tensor = cutlass_torch.from_dlpack(scale_b)
    sfb_tensor = sfb_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(scale_b))

    c_tensor = cutlass_torch.make_tensor(
        torch.empty(M, N, dtype=torch.bfloat16, device=device))

    print("CuTe tensors created OK")

    # ---- Setup exactly like dense.py ----
    sf_vec_size = 16  # NVF4
    a_dtype = cutlass.Float4E2M1FN
    b_dtype = cutlass.Float4E2M1FN
    sf_dtype = cutlass.Float8E4M3FN
    c_dtype = cutlass.BFloat16

    mma_tiler_mn = (128, 128)
    cluster_shape_mn = (1, 1)
    use_2cta = False
    cta_group = tcgen05.CtaGroup.ONE

    a_major = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode()
    b_major = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode()

    # Tile shape used for the SFB-side tiled MMA: M is halved in 2-CTA mode
    # and N is rounded up to a multiple of 128.
    mma_inst_shape_mn_sfb = (
        mma_tiler_mn[0] // (2 if use_2cta else 1),
        cute.round_up(mma_tiler_mn[1], 128),
    )

    print(f"Creating tiled_mma with sf_vec_size={sf_vec_size}...", flush=True)
    tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
        a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
        cta_group, mma_tiler_mn)
    print(f"tiled_mma OK: shape_mnk={tiled_mma.shape_mnk}", flush=True)

    tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
        a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
        tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb)
    print("tiled_mma_sfb OK", flush=True)

    # MMA tiler
    inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
    inst_tile_k = 4
    k_tile = inst_shape_k * inst_tile_k
    mma_tiler = (cutlass.Int32(mma_tiler_mn[0]),
                 cutlass.Int32(mma_tiler_mn[1]),
                 cutlass.Int32(k_tile))

    # Not consumed below; kept so the expression itself is still exercised.
    cta_tile_shape_mnk = (
        mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
        mma_tiler[1],
        mma_tiler[2],
    )

    cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((*cluster_shape_mn, 1)),
        (tiled_mma.thr_id.shape,))

    # SMEM layouts
    num_ab_stages = 2
    print("Creating SMEM layouts...", flush=True)
    a_smem_staged = sm100_utils.make_smem_layout_a(
        tiled_mma, mma_tiler, a_dtype, num_ab_stages)
    b_smem_staged = sm100_utils.make_smem_layout_b(
        tiled_mma, mma_tiler, b_dtype, num_ab_stages)
    sfa_smem_staged = blockscaled_utils.make_smem_layout_sfa(
        tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
    sfb_smem_staged = blockscaled_utils.make_smem_layout_sfb(
        tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
    print("SMEM layouts OK", flush=True)

    # TMA
    # Take index 0 of the last mode of each staged layout (presumably the
    # stage mode -- confirm against the make_smem_layout_* helpers).
    a_smem0 = cute.slice_(a_smem_staged, (None, None, None, 0))
    b_smem0 = cute.slice_(b_smem_staged, (None, None, None, 0))
    sfa_smem0 = cute.slice_(sfa_smem_staged, (None, None, None, 0))
    sfb_smem0 = cute.slice_(sfb_smem_staged, (None, None, None, 0))

    print("Creating TMA atoms...", flush=True)
    a_op = sm100_utils.cluster_shape_to_tma_atom_A(
        cluster_shape_mn, tiled_mma.thr_id)
    tma_a, gA = cute.nvgpu.make_tiled_tma_atom_A(
        a_op, a_tensor, a_smem0, mma_tiler, tiled_mma,
        cluster_layout_vmnk.shape)
    print("TMA A OK", flush=True)

    b_op = sm100_utils.cluster_shape_to_tma_atom_B(
        cluster_shape_mn, tiled_mma.thr_id)
    tma_b, gB = cute.nvgpu.make_tiled_tma_atom_B(
        b_op, b_tensor, b_smem0, mma_tiler, tiled_mma,
        cluster_layout_vmnk.shape)
    print("TMA B OK", flush=True)

    # The scale-factor atoms reuse the A/B builders with a 16-bit internal
    # type.
    tma_sfa, gSFA = cute.nvgpu.make_tiled_tma_atom_A(
        a_op, sfa_tensor, sfa_smem0, mma_tiler, tiled_mma,
        cluster_layout_vmnk.shape, internal_type=cutlass.Int16)
    print("TMA SFA OK", flush=True)

    mma_tiler_sfb = (cutlass.Int32(mma_inst_shape_mn_sfb[0]),
                     cutlass.Int32(mma_inst_shape_mn_sfb[1]),
                     cutlass.Int32(k_tile))
    cluster_layout_sfb_vmnk = cute.tiled_divide(
        cute.make_layout((*cluster_shape_mn, 1)),
        (tiled_mma_sfb.thr_id.shape,))
    sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
        cluster_shape_mn, tiled_mma.thr_id)
    tma_sfb, gSFB = cute.nvgpu.make_tiled_tma_atom_B(
        sfb_op, sfb_tensor, sfb_smem0, mma_tiler_sfb, tiled_mma_sfb,
        cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)
    print("TMA SFB OK", flush=True)

    # Now try compiling the dense GEMM kernel (no custom epilogue)
    # NOTE(review): only the kernel object is constructed here; nothing is
    # compiled or launched explicitly. Confirm that construction alone
    # triggers the JIT compilation this test is meant to verify.
    print("Compiling dense_blockscaled GEMM with NVF4...", flush=True)
    kernel = sm100_utils.Sm100BlockScaledPersistentDenseGemmKernel(
        a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor,
        acc_dtype=cutlass.Float32,
        mma_tiler_mn=mma_tiler_mn,
        cluster_shape_mn=cluster_shape_mn,
        sf_vec_size=sf_vec_size,
    )
    print("COMPILATION SUCCEEDED! NVF4 CuTeDSL path works.", flush=True)
|
||||
|
||||
|
||||
# Script entry point: lets the compilation check be run directly with
# `python`, without going through pytest.
if __name__ == "__main__":
    test_nvfp4_cutedsl_compilation()
|
||||
Reference in New Issue
Block a user