diff --git a/analyze_layout.py b/analyze_layout.py new file mode 100644 index 00000000..1567f610 --- /dev/null +++ b/analyze_layout.py @@ -0,0 +1,98 @@ +"""Detailed register layout analysis for the fused SwiGLU epilogue. + +Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying +input (each token has a different scale). The fused output at each (M, N) +position tells us the value. By checking multiple positions, we can determine +which register positions map to which (M, N) addresses. + +With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols. +The TMA store writes in (M, N) order, so the GMEM output is in row-major order. +The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x). + +For 128 epilogue threads and (128, 8) subtiles: + 128 * 8 = 1024 values per subtile + 1024 / 128 = 8 values per thread per subtile + + Possible layouts: + a) 8 N-cols × 1 M-row per thread (contiguous along N) + b) 1 N-col × 8 M-rows per thread (contiguous along M) + c) 4 N-cols × 2 M-rows per thread + d) 2 N-cols × 4 M-rows per thread +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +import torch +from cutedsl.bridge import ( + quantize_weight_to_nvfp4, quantize_activation_nvfp4, + make_b_k_major, interleave_l1_weights, deinterleave_l1_weights, + run_fused_swiglu_grouped_gemm, assemble_scales_2d_side, +) +from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( + ceil_div, assemble_raw_scales_2d3d_3d_side, +) + +torch.manual_seed(42) +device = "cuda" +hidden = 7168 +intermediate = 3072 +K_packed = hidden // 2 + +# gate=1.0, up=3.0 — distinct from silu scaling +gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) +up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 3.0 +l1_w = torch.cat([gate_w, up_w], dim=1) +l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w) + +l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0)) +l1_mat_b = make_b_k_major(l1_ekn) +l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0)) +l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()]) +l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device) + +# Input: 128 tokens with VARYING scales (each row has a unique value) +n_tokens = 128 +hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.01 +# But we want deterministic, so use a known pattern: +# Row i has value i/128 * 0.1 +for i in range(n_tokens): + hidden_states[i] = (i / 128.0) * 0.1 + +gs_a = 1.0 / 2688.0 +x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a) +expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device) +l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device) +l1_scale_a = assemble_scales_2d_side([x_sf]) + +fused_out = run_fused_swiglu_grouped_gemm( + mat_a=x_fp4, mat_b=l1_mat_b, + scale_a=l1_scale_a, scale_b=l1_scale_b, + expert_offsets=expert_offsets, + global_scale_a=l1_gsa, global_scale_b=l1_gsb, +) + +print(f"Fused output shape: {fused_out.shape}") + +# The output should be proportional to the input value. +# Row i has input ≈ i/128 * 0.1, so the GEMM output is proportional to i. +# Gate (cols 0-7, 16-23, ...): silu(gate) ≈ c * i +# Up (cols 8-15, 24-31, ...): silu(gate)*up ≈ 3c * i (since up=3.0) + +# Check the first subtile (cols 0-7, should be gate) +# and second subtile (cols 8-15, should be up) +# For M-rows 0, 1, 2, ... +print("\nM-row | Gate (col 0) | Up (col 8) | Ratio") +for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]: + g = fused_out[m, 0].item() + u = fused_out[m, 8].item() + ratio = u / g if abs(g) > 0.01 else float('inf') + print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}") + +# Check if values within a subtile are uniform (same value for all 8 N-cols) +print("\nRow 0, first 16 values (2 subtiles):") +print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}") +print(f"Row 1, first 16 values:") +print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}") + +# If values within a subtile are uniform (all 8 N-cols have the same value), +# the register layout has 8 N-cols per thread (layout a). +# If they differ across M-rows but same N-col, it's layout b. diff --git a/analyze_output.py b/analyze_output.py new file mode 100644 index 00000000..32b1323f --- /dev/null +++ b/analyze_output.py @@ -0,0 +1,114 @@ +"""Empirical register layout analysis for the fused SwiGLU epilogue. + +With epi_tile=(128, 8), each subtile covers 8 BF16 N-columns and 128 M-rows. +The TMA store writes the BF16 output to GMEM in a deterministic order. + +By running the fused kernel with gate=1.0/up=2.0 interleaved weights, +the output at odd 8-column groups should be silu(gate)*up ≈ 2*silu(1.0), +and even groups should be silu(gate) ≈ silu(1.0). + +This script analyzes the GMEM output to understand: +1. Which 8-column groups are gate vs up (verify interleaving) +2. The BF16 values at specific (M, N) positions +3. Whether the subtile pairing is correct +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +import torch +import torch.nn.functional as F +from cutedsl.bridge import ( + quantize_weight_to_nvfp4, quantize_activation_nvfp4, + make_b_k_major, interleave_l1_weights, deinterleave_l1_weights, + compute_expert_offsets, run_fused_swiglu_grouped_gemm, quantize_to_nvfp4, + assemble_scales_2d_side, assemble_scales_3d_side, +) +from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side + +torch.manual_seed(42) +device = "cuda" +hidden = 7168 +intermediate = 3072 +K_packed = hidden // 2 + +# gate=1.0, up=2.0 — clear signal for interleaving +gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) +up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 2.0 +l1_w = torch.cat([gate_w, up_w], dim=1) +l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w) + +# Interleave +l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0)) +l1_mat_b = make_b_k_major(l1_ekn) + +# SF interleave +l1_sf_ekn = l1_sf.unsqueeze(0) +l1_sf_il = interleave_l1_weights(l1_sf_ekn) +l1_sf_il_list = [l1_sf_il[0].T.contiguous()] +l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_il_list) +l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device) + +# Input: 128 tokens, all 0.1 +n_tokens = 128 +hidden_states = torch.ones(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1 +gs_a = 1.0 / 2688.0 +x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a) +expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device) +l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device) + +from cutedsl.kernel.moe.torch_scaled_grouped_mm import ceil_div, pad_and_swizzle_single +K_sf = ceil_div(K_packed, 8) +x_sf_parts = [x_sf] +l1_scale_a = assemble_scales_2d_side(x_sf_parts) + +# Run fused kernel +fused_out = run_fused_swiglu_grouped_gemm( + mat_a=x_fp4, mat_b=l1_mat_b, + scale_a=l1_scale_a, scale_b=l1_scale_b, + expert_offsets=expert_offsets, + global_scale_a=l1_gsa, global_scale_b=l1_gsb, +) + +print(f"Fused output shape: {fused_out.shape}") +out0 = fused_out[0].float() # First token, as float32 + +# BF16 reference: silu(1.0) = 0.7311, silu(1.0)*2.0 = 1.4621 +import math +silu_one = 1.0 / (1.0 + math.exp(-1.0)) # sigmoid(1.0) = 0.7311 +silu_two = 2.0 / (1.0 + math.exp(-2.0)) # sigmoid(2.0) = 1.7616 + +# Compute expected values +# With input 0.1 and gate=1.0, the gate GEMM output ≈ 7168 * 0.1 * 1.0 * gs_a * gs_b +# Let's check empirically +gate_vals = [] +up_vals = [] +for i in range(0, 64, 8): + chunk = out0[i:i+8].tolist() + is_gate = all(abs(v - out0[0].item()) < 1.0 for v in chunk) + label = "gate(silu)" if is_gate else "up(swiglu)" + if is_gate: + gate_vals.append(out0[i].item()) + else: + up_vals.append(out0[i].item()) + print(f" Cols {i:3d}-{i+7:3d}: {[round(v,2) for v in chunk]} → {label}") + +if gate_vals and up_vals: + g = gate_vals[0] + u = up_vals[0] + print(f"\nGate ≈ {g:.4f}, Up ≈ {u:.4f}") + print(f"Ratio up/gate ≈ {u/g:.4f}") + print(f"Expected silu(1.0) ≈ {silu_one:.4f}, silu(1.0)*2.0 ≈ {2*silu_one:.4f}") + print(f"Actual gate/expected_gate ≈ {g / (7168 * 0.1 * 1.0 * gs_a * l1_gs):.4f}") + +# Now check the de-interleaved SwiGLU output +l1_deil = deinterleave_l1_weights(fused_out.unsqueeze(0).contiguous())[0] +swiglu_result = l1_deil[:, intermediate:] +silu_gate = l1_deil[:, :intermediate] + +print(f"\nDe-interleaved silu(gate) amax: {silu_gate.abs().amax():.4f}") +print(f"De-interleaved SwiGLU amax: {swiglu_result.abs().amax():.4f}") +print(f"SwiGLU/silu(gate) ratio: {(swiglu_result[0,0] / silu_gate[0,0]):.4f}") + +# Verify: quantize the SwiGLU result and check it matches the Python quantize path +x_fp4, x_sf, gs = quantize_to_nvfp4(swiglu_result) +print(f"\nQuantized SwiGLU: FP4 shape={x_fp4.shape}, SF shape={x_sf.shape}, gs={gs:.8f}") +print(f"FP4 amax (uint8): {x_fp4.view(torch.uint8).amax()}") diff --git a/bench_fused.py b/bench_fused.py new file mode 100644 index 00000000..11bef7fa --- /dev/null +++ b/bench_fused.py @@ -0,0 +1,54 @@ +import sys, os, torch, time +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from cutedsl.moe_pipeline import run_nvfp4_moe, run_nvfp4_moe_fused, quantize_weight + +torch.manual_seed(42) +device = "cuda" +num_experts = 3 +hidden = 7168 +intermediate = 3072 +n_tokens = 128 + +l1_weights = [torch.randn(2*intermediate, hidden, dtype=torch.bfloat16, device=device) for _ in range(num_experts)] +l2_weights = [torch.randn(hidden, intermediate, dtype=torch.bfloat16, device=device) for _ in range(num_experts)] + +l1_fp4, l1_sf, l1_gs = [], [], [] +l2_fp4, l2_sf, l2_gs = [], [], [] +for l1_w, l2_w in zip(l1_weights, l2_weights): + fp4, sf, gs = quantize_weight(l1_w) + l1_fp4.append(fp4); l1_sf.append(sf); l1_gs.append(gs) + fp4, sf, gs = quantize_weight(l2_w) + l2_fp4.append(fp4); l2_sf.append(sf); l2_gs.append(gs) + +weights = { + "l1_fp4": l1_fp4, "l1_sf": l1_sf, "l1_gs": l1_gs, + "l2_fp4": l2_fp4, "l2_sf": l2_sf, "l2_gs": l2_gs, +} + +hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1 +expert_ids = torch.zeros(n_tokens, 1, dtype=torch.int32, device=device) +expert_weights = torch.ones(n_tokens, 1, dtype=torch.float32, device=device) +expert_indices = [0, 1, 2] + +# Warmup +_ = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices) +_ = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices) +torch.cuda.synchronize() + +# Benchmark +N = 50 +t0 = time.perf_counter() +for _ in range(N): + out1 = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices) +torch.cuda.synchronize() +t1 = time.perf_counter() + +for _ in range(N): + out2 = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices) +torch.cuda.synchronize() +t2 = time.perf_counter() + +print(f"Non-fused: {(t1-t0)/N*1000:.2f} ms/iter") +print(f"Fused: {(t2-t1)/N*1000:.2f} ms/iter") +print(f"Speedup: {(t1-t0)/(t2-t1):.2f}x") +print(f"Output match: {torch.allclose(out1.float(), out2.float(), atol=1.0)}") diff --git a/cutedsl/moe_pipeline.py b/cutedsl/moe_pipeline.py index de8a685a..8de231c0 100644 --- a/cutedsl/moe_pipeline.py +++ b/cutedsl/moe_pipeline.py @@ -204,9 +204,6 @@ def run_nvfp4_moe( # Global scales: alpha = igs * weight_gs for each expert l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device) l1_global_scale_b = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=device) - print(f" L1 global_scale_a: {l1_global_scale_a.tolist()}", flush=True) - print(f" L1 global_scale_b: {l1_global_scale_b.tolist()}", flush=True) - print(f" alpha (a*b): {(l1_global_scale_a * l1_global_scale_b).tolist()}", flush=True) # Run L1 GEMM l1_out = run_nvfp4_grouped_gemm( @@ -215,7 +212,6 @@ def run_nvfp4_moe( expert_offsets=expert_offsets, global_scale_a=l1_global_scale_a, global_scale_b=l1_global_scale_b, ) # (num_slots, 2*intermediate) BF16 - print(f" L1 GEMM output: shape={l1_out.shape}, amax={l1_out.abs().amax().item():.4f}", flush=True) # ════════════════════════════════════════════════════════════════ # SiLU(gate) * up (BF16 — nonlinear requires BF16) @@ -226,14 +222,11 @@ def run_nvfp4_moe( l1_deil = deinterleave_l1_weights(l1_out.unsqueeze(0).contiguous())[0] gate = l1_deil[:, :intermediate_size] up = l1_deil[:, intermediate_size:] - print(f" gate: shape={gate.shape}, amax={gate.abs().amax().item():.4f}", flush=True) - print(f" up: shape={up.shape}, amax={up.abs().amax().item():.4f}", flush=True) gate_silu = torch.nn.functional.silu(gate) if swiglu_limit is not None: gate_silu = gate_silu.clamp(max=swiglu_limit) up = up.clamp(min=-swiglu_limit, max=swiglu_limit) activated = gate_silu * up # (num_slots, intermediate) BF16 - print(f" After SiLU(gate)*up: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True) # ════════════════════════════════════════════════════════════════ # L2: down projection (NVFP4 × NVFP4 → BF16) @@ -374,7 +367,6 @@ def run_nvfp4_moe_fused( intermediate_size = l1_fused_out.shape[1] // 2 l1_deil = deinterleave_l1_weights(l1_fused_out.unsqueeze(0).contiguous())[0] activated = l1_deil[:, intermediate_size:] # up columns = SwiGLU result - print(f" Fused SwiGLU: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True) # === L2: down projection (same as non-fused) ===