From ed89e678bef9c11c8b1b7a975bc66f15aca8b6fb Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Wed, 20 May 2026 03:10:56 +0000
Subject: [PATCH] wip: add run_fused_swiglu_grouped_gemm bridge + step1 test

---
 cutedsl/bridge.py         | 109 ++++++++++++++++++++++++++++++++++++++
 tests/test_fused_step1.py |  88 ++++++++++++++++++++++++++++++
 tests/test_silu_step1.py  |  93 ++++++++++++++++++++++++++++++++
 3 files changed, 290 insertions(+)
 create mode 100644 tests/test_fused_step1.py
 create mode 100644 tests/test_silu_step1.py

diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py
index 3d5f9f9d..b6b3254c 100644
--- a/cutedsl/bridge.py
+++ b/cutedsl/bridge.py
@@ -666,3 +666,112 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
     del global_scale_a, global_scale_b
     del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c
     torch.cuda.empty_cache()
+
+
+def run_fused_swiglu_grouped_gemm(
+    mat_a,  # (tokens_sum, K_packed) float4_e2m1fn_x2
+    mat_b,  # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
+    scale_a,  # assembled 2D side (padded + swizzled)
+    scale_b,  # assembled 3D side (padded + swizzled)
+    expert_offsets,  # (experts,) int32 cumulative token offsets
+    global_scale_a=None,  # (experts,) float32
+    global_scale_b=None,  # (experts,) float32
+    swiglu_limit=0.0,
+    mma_tiler_mn=(128, 128),
+    cluster_shape_mn=(1, 1),
+):
+    """Run the fused SwiGLU NVFP4 scaled grouped GEMM.
+
+    Stage 1: SiLU is applied to the full accumulator in registers,
+    then written as BF16 to C. Gate/up pairing is not yet implemented.
+
+    The kernel is compiled lazily, once per cache key (expert count, device,
+    tile and cluster shape, ``mat_a.shape[1]``, ``mat_b.shape[2]`` and
+    ``swiglu_limit``). The compiled kernel and its workspace buffer are kept
+    in ``_fused_kernel_cache`` and reused by later calls with the same key.
+
+    Args:
+        mat_a, mat_b, scale_a, scale_b, expert_offsets: Kernel operands; the
+            expected layouts are noted next to each parameter above.
+        global_scale_a, global_scale_b: Optional global scales; ``None`` is
+            passed to the kernel unchanged.
+        swiglu_limit: Forwarded to the kernel constructor.
+        mma_tiler_mn: (M, N) of the MMA tiler; 256 is appended for K.
+        cluster_shape_mn: (M, N) of the cluster shape; 1 is appended for K.
+
+    Returns:
+        A new bfloat16 tensor of shape ``(mat_a.shape[0], mat_b.shape[2])``
+        on ``mat_a.device``; it is handed to the kernel and then returned.
+    """
+    from cutedsl.kernel.moe.fused_swiglu_grouped_mm import FusedSwiGLUScaledGroupedGemmKernel
+    import cuda.bindings.driver as cuda
+
+    def to_cute(t):
+        # Optional operands (the global scales) may be absent.
+        if t is None:
+            return None
+        ct = cutlass_torch.from_dlpack(t)
+        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
+
+    num_experts = mat_b.shape[0]
+    n_dim = mat_b.shape[2]
+    tokens_sum = mat_a.shape[0]
+
+    out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
+
+    # Wrap each operand once; the same wrappers serve compilation (on a
+    # cache miss) and the launch.
+    a_c = to_cute(mat_a)
+    b_c = to_cute(mat_b)
+    sfa_c = to_cute(scale_a)
+    sfb_c = to_cute(scale_b)
+    c_c = to_cute(out)
+    offs_c = to_cute(expert_offsets)
+    gsa_c = to_cute(global_scale_a)
+    gsb_c = to_cute(global_scale_b)
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    cache_key = ('fused', num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn,
+                 mat_a.shape[1], mat_b.shape[2], swiglu_limit)
+
+    if cache_key not in _fused_kernel_cache:
+        # Lazy compilation
+        kernel = FusedSwiGLUScaledGroupedGemmKernel(
+            scenario="2Dx3D",
+            sf_vec_size=SF_VEC_SIZE,
+            accumulate_on_output=False,
+            separate_tensormap_init=True,
+            consistent_token_padding=False,
+            mma_tiler_mnk=(*mma_tiler_mn, 256),
+            cluster_shape_mnk=(*cluster_shape_mn, 1),
+            fused_swiglu=True,
+            swiglu_limit=swiglu_limit,
+        )
+        workspace_size = kernel.get_workspace_size(num_experts)
+        workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
+
+        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
+        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
+
+        compiled = cute.compile(
+            kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, to_cute(workspace),
+            max_active_clusters, stream,
+            global_scale_a=gsa_c, global_scale_b=gsb_c,
+        )
+        # No launch here: the call below runs the kernel exactly once.
+        _fused_kernel_cache[cache_key] = {
+            'compiled': compiled,
+            'workspace': workspace,
+            'workspace_size': workspace_size,
+        }
+
+    entry = _fused_kernel_cache[cache_key]
+    ws_c = to_cute(entry['workspace'])
+
+    entry['compiled'](
+        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
+        global_scale_a=gsa_c, global_scale_b=gsb_c,
+    )
+
+    return out
+
diff --git a/tests/test_fused_step1.py b/tests/test_fused_step1.py
new file mode 100644
index 00000000..5c7d513d
--- /dev/null
+++ b/tests/test_fused_step1.py
@@ -0,0 +1,88 @@
+"""Test: Validate SiLU in registers (Step 1 of fused SwiGLU).
+
+Compiles the fused kernel with fused_swiglu=True, runs it, and compares
+the BF16 output with PyTorch SiLU applied to the standard L1 GEMM output.
+"""
+import torch
+import sys
+sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
+
+from cutedsl.bridge import (
+    quantize_weight_to_nvfp4,
+    quantize_activation_nvfp4,
+    make_b_k_major,
+    assemble_scales_2d_side,
+    assemble_scales_3d_side,
+    run_nvfp4_grouped_gemm,
+    run_fused_swiglu_grouped_gemm,
+    warmup_compilation,
+)
+
+
+def test_silu_step1():
+    """Check fused in-register SiLU against SiLU of the unfused L1 GEMM output."""
+    device = "cuda"
+    num_experts = 4
+    hidden = 512
+    intermediate = 256
+    num_tokens = 32
+
+    torch.manual_seed(42)
+    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
+    l1_w = torch.randn(num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16, device=device)
+
+    # Quantize each expert's transposed L1 weight and regroup the results by kind.
+    quantized = [quantize_weight_to_nvfp4(l1_w[e].T) for e in range(num_experts)]
+    l1_fp4_list, l1_sf_list, l1_gs_list = (list(part) for part in zip(*quantized))
+
+    l1_mat_b = make_b_k_major(torch.stack(l1_fp4_list))
+    l1_scale_b = assemble_scales_3d_side(l1_sf_list)
+    l1_gs = torch.tensor(l1_gs_list, dtype=torch.float32, device=device)
+
+    gs_val = x.abs().max().item() / (6.0 * 448.0)
+    x_fp4, x_sf = quantize_activation_nvfp4(x, gs_val)
+    tokens_per_expert = [num_tokens // num_experts] * num_experts
+    scale_a = assemble_scales_2d_side([x_sf[i*tpe:(i+1)*tpe] for i, tpe in enumerate(tokens_per_expert)])
+    # expert_offsets holds the inclusive running token totals per expert.
+    offsets = [sum(tokens_per_expert[:e+1]) for e in range(num_experts)]
+    expert_offsets = torch.tensor(offsets, dtype=torch.int32, device=device)
+    global_scale_a = torch.full((num_experts,), gs_val, dtype=torch.float32, device=device)
+
+    warmup_compilation(num_experts, hidden // 2, (2 * intermediate) // 2, device)
+
+    # 1. Standard L1 GEMM (no SiLU)
+    out_bf16 = run_nvfp4_grouped_gemm(
+        mat_a=x_fp4, mat_b=l1_mat_b,
+        scale_a=scale_a, scale_b=l1_scale_b,
+        expert_offsets=expert_offsets,
+        global_scale_a=global_scale_a, global_scale_b=l1_gs,
+    )
+    silu_ref = torch.nn.functional.silu(out_bf16)
+    print(f"Standard L1 output: shape={out_bf16.shape}, amax={out_bf16.abs().amax().item():.4f}")
+    print(f"PyTorch SiLU ref: amax={silu_ref.abs().amax().item():.4f}")
+
+    # 2. Fused kernel with SiLU in registers
+    print("\nCompiling fused kernel (first time, may take a while)...")
+    out_fused = run_fused_swiglu_grouped_gemm(
+        mat_a=x_fp4, mat_b=l1_mat_b,
+        scale_a=scale_a, scale_b=l1_scale_b,
+        expert_offsets=expert_offsets,
+        global_scale_a=global_scale_a, global_scale_b=l1_gs,
+    )
+    print(f"Fused SiLU output: shape={out_fused.shape}, amax={out_fused.abs().amax().item():.4f}")
+
+    # 3. Compare
+    diff = (out_fused - silu_ref).float()
+    rel_err = diff.norm() / silu_ref.float().norm()
+    max_err = diff.abs().max()
+    print("\n=== Results ===")
+    print(f"Relative error: {rel_err.item():.6f}")
+    print(f"Max abs error: {max_err.item():.6f}")
+    passed = rel_err.item() < 0.1
+    print("PASS" if passed else "FAIL (tolerance: 0.1 for NVFP4 quant noise)")
+    # Raise as well as print so that pytest actually reports a failing comparison.
+    assert passed, f"relative error {rel_err.item():.6f} exceeds tolerance 0.1"
+
+
+if __name__ == "__main__":
+    test_silu_step1()
diff --git a/tests/test_silu_step1.py b/tests/test_silu_step1.py
new file mode 100644
index 00000000..9f54eca6
--- /dev/null
+++ b/tests/test_silu_step1.py
@@ -0,0 +1,93 @@
+"""Test: PyTorch SiLU reference for Step 1 of the fused SwiGLU validation.
+
+Step 1 applies SiLU to the full accumulator in the fused epilogue (no
+gate/up pairing yet). The fused kernel is expected to demonstrate that:
+1. cute.exp works on register tensors
+2. The element-wise SiLU (x / (1+exp(-x))) produces correct values
+3. The register tensor can be converted to BF16 and stored to C
+
+This script does not launch the fused kernel: it runs the standard L1 GEMM
+and applies SiLU in PyTorch to the BF16 output to produce the reference.
+"""
+import torch
+import sys
+sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
+
+from cutedsl.bridge import (
+    quantize_weight_to_nvfp4,
+    quantize_activation_nvfp4,
+    make_b_k_major,
+    assemble_scales_2d_side,
+    assemble_scales_3d_side,
+    run_nvfp4_grouped_gemm,
+    warmup_compilation,
+)
+
+
+def test_silu_in_registers():
+    """Build the PyTorch SiLU reference for the Step 1 fused-kernel check.
+
+    Runs the standard (unfused) NVFP4 L1 grouped GEMM, applies SiLU to its
+    output in PyTorch and prints summary statistics. The fused kernel is
+    not launched here.
+    """
+    device = "cuda"
+    num_experts = 4
+    hidden = 512
+    intermediate = 256
+    num_tokens = 32
+
+    torch.manual_seed(42)
+    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
+
+    # Create and quantize L1 weights (gate+up fused)
+    l1_w = torch.randn(num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16, device=device)
+    quantized = [quantize_weight_to_nvfp4(l1_w[e].T) for e in range(num_experts)]
+    l1_fp4_list, l1_sf_list, l1_gs_list = (list(part) for part in zip(*quantized))
+
+    l1_mat_b = make_b_k_major(torch.stack(l1_fp4_list))
+    l1_scale_b = assemble_scales_3d_side(l1_sf_list)
+    l1_gs = torch.tensor(l1_gs_list, dtype=torch.float32, device=device)
+
+    gs_val = x.abs().max().item() / (6.0 * 448.0)
+    x_fp4, x_sf = quantize_activation_nvfp4(x, gs_val)
+
+    tokens_per_expert = [num_tokens // num_experts] * num_experts
+    scale_a = assemble_scales_2d_side([x_sf[i*tpe:(i+1)*tpe] for i, tpe in enumerate(tokens_per_expert)])
+    offsets = [sum(tokens_per_expert[:e+1]) for e in range(num_experts)]
+    expert_offsets = torch.tensor(offsets, dtype=torch.int32, device=device)
+    global_scale_a = torch.full((num_experts,), gs_val, dtype=torch.float32, device=device)
+
+    # Warmup standard GEMM
+    warmup_compilation(num_experts, hidden // 2, (2 * intermediate) // 2, device)
+
+    # Run standard L1 GEMM (no SiLU)
+    out_bf16 = run_nvfp4_grouped_gemm(
+        mat_a=x_fp4, mat_b=l1_mat_b,
+        scale_a=scale_a, scale_b=l1_scale_b,
+        expert_offsets=expert_offsets,
+        global_scale_a=global_scale_a, global_scale_b=l1_gs,
+    )
+
+    # Apply SiLU in PyTorch (reference)
+    silu_ref = torch.nn.functional.silu(out_bf16)
+    assert torch.isfinite(silu_ref).all().item(), "SiLU reference contains NaN/Inf"
+
+    print(f"L1 BF16 output shape: {out_bf16.shape}")
+    print(f"SiLU reference shape: {silu_ref.shape}")
+    print(f"L1 output amax: {out_bf16.abs().amax().item():.4f}")
+    print(f"SiLU reference amax: {silu_ref.abs().amax().item():.4f}")
+    print()
+    print("Step 1 validation: SiLU in PyTorch on BF16 GEMM output")
+    print("Next step: Run fused kernel with SiLU in registers and compare")
+    print()
+    print("NOTE: The fused kernel with SiLU on the full acc_vec should produce")
+    print("the same result as torch.nn.functional.silu on the BF16 output,")
+    print("within NVFP4 quantization tolerance (~5e-2).")
+    print()
+    print("This test validates the SiLU math. The gate/up pairing (Step 2)")
+    print("will change which values get SiLU applied (gate only, not up).")
+
+
+if __name__ == "__main__":
+    test_silu_in_registers()