"""Test: Verify NVFP4 CuTeDSL compilation with MmaMXF4NVF4Op (sf_vec_size=16).

This test does NOT run the kernel — it only verifies that the CuTeDSL JIT
compiler can handle the NVF4 block-scaled GEMM with proper pipeline abstractions.
If this compiles, we can add the custom epilogue.
"""

import unittest

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
import cutlass.torch as cutlass_torch

from dsv4.ops.quantize import quantize_weight_to_nvfp4, quantize_activation_nvfp4
from dsv4.ops.layouts import make_b_k_major, assemble_raw_scales_2d3d_3d_side


def test_nvfp4_cutedsl_compilation():
    """Test that NVFP4 block-scaled GEMM compiles with CuTeDSL.

    The test walks through the setup stages one at a time and prints a
    progress marker after each, so a failure can be attributed to a stage:

    1. quantize a random activation and weight with the project helpers,
    2. wrap the resulting torch tensors as CuTe tensors,
    3. build the block-scaled tiled MMA objects (``sf_vec_size=16``),
    4. build the staged shared-memory layouts,
    5. build the TMA atoms for A, B, SFA and SFB,
    6. instantiate the dense block-scaled GEMM kernel class.

    Nothing is asserted explicitly; the test passes when no stage raises.

    Raises:
        unittest.SkipTest: if no CUDA device is available (pytest reports
            this as a skip rather than a failure).
    """
    # Every tensor below is allocated on the GPU, so bail out early and
    # cleanly on machines without one instead of failing inside torch.
    if not torch.cuda.is_available():
        raise unittest.SkipTest(
            "CUDA device required for the NVFP4 CuTeDSL compilation test")

    device = "cuda:0"
    M, N, K = 1, 384, 7168

    # Quantize
    # NOTE(review): 6.0 and 448.0 look like the FP4 (E2M1) and FP8 (E4M3)
    # maximum magnitudes — confirm against quantize_activation_nvfp4.
    gsa = 1.0 / (6.0 * 448.0)
    hs = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    x_fp4, x_sf = quantize_activation_nvfp4(hs, gsa)

    W = torch.randn(K, N, dtype=torch.bfloat16, device=device)
    # w_gs is returned by the helper but not needed for this check.
    w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(W)
    # Stack into a single-entry batch and swap the last two dims before
    # handing the weight to the layout helper.
    stacked = torch.stack([w_fp4]).permute(0, 2, 1).contiguous()
    mat_b = make_b_k_major(stacked)
    scale_b = assemble_raw_scales_2d3d_3d_side([w_sf.T.contiguous()])

    print(f"x_fp4: {x_fp4.shape}, dtype={x_fp4.dtype}")
    print(f"x_sf: {x_sf.shape}, dtype={x_sf.dtype}")
    print(f"mat_b: {mat_b.shape}, dtype={mat_b.dtype}")
    print(f"scale_b: {scale_b.shape}, dtype={scale_b.dtype}")

    # Convert to CuTe tensors
    a_tensor = cutlass_torch.from_dlpack(x_fp4)
    a_tensor = a_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(x_fp4))

    b_tensor = cutlass_torch.from_dlpack(mat_b)
    b_tensor = b_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(mat_b))

    sfa_tensor = cutlass_torch.from_dlpack(x_sf)
    sfa_tensor = sfa_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(x_sf))

    sfb_tensor = cutlass_torch.from_dlpack(scale_b)
    sfb_tensor = sfb_tensor.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(scale_b))

    c_tensor = cutlass_torch.make_tensor(
        torch.empty(M, N, dtype=torch.bfloat16, device=device))

    print("CuTe tensors created OK")

    # ---- Setup exactly like dense.py ----
    sf_vec_size = 16  # NVF4
    a_dtype = cutlass.Float4E2M1FN
    b_dtype = cutlass.Float4E2M1FN
    sf_dtype = cutlass.Float8E4M3FN
    # Not consumed below; kept so the setup mirrors dense.py one-to-one.
    c_dtype = cutlass.BFloat16

    mma_tiler_mn = (128, 128)
    cluster_shape_mn = (1, 1)
    use_2cta = False
    cta_group = tcgen05.CtaGroup.ONE

    a_major = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode()
    b_major = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode()

    # Instruction shape used for the SFB tiled MMA: M is halved in 2-CTA
    # mode, N is rounded up to a multiple of 128.
    mma_inst_shape_mn_sfb = (
        mma_tiler_mn[0] // (2 if use_2cta else 1),
        cute.round_up(mma_tiler_mn[1], 128),
    )

    print(f"Creating tiled_mma with sf_vec_size={sf_vec_size}...", flush=True)
    tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
        a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
        cta_group, mma_tiler_mn)
    print(f"tiled_mma OK: shape_mnk={tiled_mma.shape_mnk}", flush=True)

    tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
        a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
        tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb)
    print("tiled_mma_sfb OK", flush=True)

    # MMA tiler
    inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
    inst_tile_k = 4
    k_tile = inst_shape_k * inst_tile_k
    mma_tiler = (cutlass.Int32(mma_tiler_mn[0]),
                 cutlass.Int32(mma_tiler_mn[1]),
                 cutlass.Int32(k_tile))

    # The result is unused, but evaluating it exercises the same
    # expression the real kernel setup would evaluate.
    cta_tile_shape_mnk = (
        mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
        mma_tiler[1],
        mma_tiler[2],
    )

    cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((*cluster_shape_mn, 1)),
        (tiled_mma.thr_id.shape,))

    # SMEM layouts
    num_ab_stages = 2
    print("Creating SMEM layouts...", flush=True)
    a_smem_staged = sm100_utils.make_smem_layout_a(
        tiled_mma, mma_tiler, a_dtype, num_ab_stages)
    b_smem_staged = sm100_utils.make_smem_layout_b(
        tiled_mma, mma_tiler, b_dtype, num_ab_stages)
    sfa_smem_staged = blockscaled_utils.make_smem_layout_sfa(
        tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
    sfb_smem_staged = blockscaled_utils.make_smem_layout_sfb(
        tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
    print("SMEM layouts OK", flush=True)

    # TMA
    # Take index 0 of the last mode of each staged layout.
    a_smem0 = cute.slice_(a_smem_staged, (None, None, None, 0))
    b_smem0 = cute.slice_(b_smem_staged, (None, None, None, 0))
    sfa_smem0 = cute.slice_(sfa_smem_staged, (None, None, None, 0))
    sfb_smem0 = cute.slice_(sfb_smem_staged, (None, None, None, 0))

    print("Creating TMA atoms...", flush=True)
    a_op = sm100_utils.cluster_shape_to_tma_atom_A(
        cluster_shape_mn, tiled_mma.thr_id)
    tma_a, gA = cute.nvgpu.make_tiled_tma_atom_A(
        a_op, a_tensor, a_smem0, mma_tiler, tiled_mma,
        cluster_layout_vmnk.shape)
    print("TMA A OK", flush=True)

    b_op = sm100_utils.cluster_shape_to_tma_atom_B(
        cluster_shape_mn, tiled_mma.thr_id)
    tma_b, gB = cute.nvgpu.make_tiled_tma_atom_B(
        b_op, b_tensor, b_smem0, mma_tiler, tiled_mma,
        cluster_layout_vmnk.shape)
    print("TMA B OK", flush=True)

    # The scale-factor A atom reuses the A-operand op (a_op).
    tma_sfa, gSFA = cute.nvgpu.make_tiled_tma_atom_A(
        a_op, sfa_tensor, sfa_smem0, mma_tiler, tiled_mma,
        cluster_layout_vmnk.shape, internal_type=cutlass.Int16)
    print("TMA SFA OK", flush=True)

    mma_tiler_sfb = (cutlass.Int32(mma_inst_shape_mn_sfb[0]),
                     cutlass.Int32(mma_inst_shape_mn_sfb[1]),
                     cutlass.Int32(k_tile))
    cluster_layout_sfb_vmnk = cute.tiled_divide(
        cute.make_layout((*cluster_shape_mn, 1)),
        (tiled_mma_sfb.thr_id.shape,))
    # The SFB op is derived from tiled_mma.thr_id (not tiled_mma_sfb),
    # while the atom itself is built with tiled_mma_sfb.
    sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
        cluster_shape_mn, tiled_mma.thr_id)
    tma_sfb, gSFB = cute.nvgpu.make_tiled_tma_atom_B(
        sfb_op, sfb_tensor, sfb_smem0, mma_tiler_sfb, tiled_mma_sfb,
        cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)
    print("TMA SFB OK", flush=True)

    # Now try compiling the dense GEMM kernel (no custom epilogue)
    # NOTE(review): this statement only constructs the kernel object; no
    # explicit compile call is visible here. Confirm that construction is
    # what triggers JIT compilation, and that this class is exported from
    # cutlass.utils.blackwell_helpers with this argument list.
    print("Compiling dense_blockscaled GEMM with NVF4...", flush=True)
    kernel = sm100_utils.Sm100BlockScaledPersistentDenseGemmKernel(
        a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor,
        acc_dtype=cutlass.Float32,
        mma_tiler_mn=mma_tiler_mn,
        cluster_shape_mn=cluster_shape_mn,
        sf_vec_size=sf_vec_size,
    )
    print("COMPILATION SUCCEEDED! NVF4 CuTeDSL path works.", flush=True)


if __name__ == "__main__":
    test_nvfp4_cutedsl_compilation()