diff --git a/analyze_layout.py b/analyze_layout.py deleted file mode 100644 index 1567f610..00000000 --- a/analyze_layout.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Detailed register layout analysis for the fused SwiGLU epilogue. - -Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying -input (each token has a different scale). The fused output at each (M, N) -position tells us the value. By checking multiple positions, we can determine -which register positions map to which (M, N) addresses. - -With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols. -The TMA store writes in (M, N) order, so the GMEM output is in row-major order. -The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x). - -For 128 epilogue threads and (128, 8) subtiles: - 128 * 8 = 1024 values per subtile - 1024 / 128 = 8 values per thread per subtile - - Possible layouts: - a) 8 N-cols × 1 M-row per thread (contiguous along N) - b) 1 N-col × 8 M-rows per thread (contiguous along M) - c) 4 N-cols × 2 M-rows per thread - d) 2 N-cols × 4 M-rows per thread -""" -import sys, os -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -import torch -from cutedsl.bridge import ( - quantize_weight_to_nvfp4, quantize_activation_nvfp4, - make_b_k_major, interleave_l1_weights, deinterleave_l1_weights, - run_fused_swiglu_grouped_gemm, assemble_scales_2d_side, -) -from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( - ceil_div, assemble_raw_scales_2d3d_3d_side, -) - -torch.manual_seed(42) -device = "cuda" -hidden = 7168 -intermediate = 3072 -K_packed = hidden // 2 - -# gate=1.0, up=3.0 — distinct from silu scaling -gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) -up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 3.0 -l1_w = torch.cat([gate_w, up_w], dim=1) -l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w) - -l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0)) -l1_mat_b = make_b_k_major(l1_ekn) -l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0)) -l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()]) -l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device) - -# Input: 128 tokens with VARYING scales (each row has a unique value) -n_tokens = 128 -hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.01 -# But we want deterministic, so use a known pattern: -# Row i has value i/128 * 0.1 -for i in range(n_tokens): - hidden_states[i] = (i / 128.0) * 0.1 - -gs_a = 1.0 / 2688.0 -x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a) -expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device) -l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device) -l1_scale_a = assemble_scales_2d_side([x_sf]) - -fused_out = run_fused_swiglu_grouped_gemm( - mat_a=x_fp4, mat_b=l1_mat_b, - scale_a=l1_scale_a, scale_b=l1_scale_b, - expert_offsets=expert_offsets, - global_scale_a=l1_gsa, global_scale_b=l1_gsb, -) - -print(f"Fused output shape: {fused_out.shape}") - -# The output should be proportional to the input value. -# Row i has input ≈ i/128 * 0.1, so the GEMM output is proportional to i. -# Gate (cols 0-7, 16-23, ...): silu(gate) ≈ c * i -# Up (cols 8-15, 24-31, ...): silu(gate)*up ≈ 3c * i (since up=3.0) - -# Check the first subtile (cols 0-7, should be gate) -# and second subtile (cols 8-15, should be up) -# For M-rows 0, 1, 2, ... -print("\nM-row | Gate (col 0) | Up (col 8) | Ratio") -for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]: - g = fused_out[m, 0].item() - u = fused_out[m, 8].item() - ratio = u / g if abs(g) > 0.01 else float('inf') - print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}") - -# Check if values within a subtile are uniform (same value for all 8 N-cols) -print("\nRow 0, first 16 values (2 subtiles):") -print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}") -print(f"Row 1, first 16 values:") -print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}") - -# If values within a subtile are uniform (all 8 N-cols have the same value), -# the register layout has 8 N-cols per thread (layout a). -# If they differ across M-rows but same N-col, it's layout b. diff --git a/analyze_output.py b/analyze_output.py deleted file mode 100644 index 32b1323f..00000000 --- a/analyze_output.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Empirical register layout analysis for the fused SwiGLU epilogue. - -With epi_tile=(128, 8), each subtile covers 8 BF16 N-columns and 128 M-rows. -The TMA store writes the BF16 output to GMEM in a deterministic order. - -By running the fused kernel with gate=1.0/up=2.0 interleaved weights, -the output at odd 8-column groups should be silu(gate)*up ≈ 2*silu(1.0), -and even groups should be silu(gate) ≈ silu(1.0). - -This script analyzes the GMEM output to understand: -1. Which 8-column groups are gate vs up (verify interleaving) -2. The BF16 values at specific (M, N) positions -3. Whether the subtile pairing is correct -""" -import sys, os -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -import torch -import torch.nn.functional as F -from cutedsl.bridge import ( - quantize_weight_to_nvfp4, quantize_activation_nvfp4, - make_b_k_major, interleave_l1_weights, deinterleave_l1_weights, - compute_expert_offsets, run_fused_swiglu_grouped_gemm, quantize_to_nvfp4, - assemble_scales_2d_side, assemble_scales_3d_side, -) -from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side - -torch.manual_seed(42) -device = "cuda" -hidden = 7168 -intermediate = 3072 -K_packed = hidden // 2 - -# gate=1.0, up=2.0 — clear signal for interleaving -gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) -up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 2.0 -l1_w = torch.cat([gate_w, up_w], dim=1) -l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w) - -# Interleave -l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0)) -l1_mat_b = make_b_k_major(l1_ekn) - -# SF interleave -l1_sf_ekn = l1_sf.unsqueeze(0) -l1_sf_il = interleave_l1_weights(l1_sf_ekn) -l1_sf_il_list = [l1_sf_il[0].T.contiguous()] -l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_il_list) -l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device) - -# Input: 128 tokens, all 0.1 -n_tokens = 128 -hidden_states = torch.ones(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1 -gs_a = 1.0 / 2688.0 -x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a) -expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device) -l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device) - -from cutedsl.kernel.moe.torch_scaled_grouped_mm import ceil_div, pad_and_swizzle_single -K_sf = ceil_div(K_packed, 8) -x_sf_parts = [x_sf] -l1_scale_a = assemble_scales_2d_side(x_sf_parts) - -# Run fused kernel -fused_out = run_fused_swiglu_grouped_gemm( - mat_a=x_fp4, mat_b=l1_mat_b, - scale_a=l1_scale_a, scale_b=l1_scale_b, - expert_offsets=expert_offsets, - global_scale_a=l1_gsa, global_scale_b=l1_gsb, -) - -print(f"Fused output shape: {fused_out.shape}") -out0 = fused_out[0].float() # First token, as float32 - -# BF16 reference: silu(1.0) = 0.7311, silu(1.0)*2.0 = 1.4621 -import math -silu_one = 1.0 / (1.0 + math.exp(-1.0)) # sigmoid(1.0) = 0.7311 -silu_two = 2.0 / (1.0 + math.exp(-2.0)) # sigmoid(2.0) = 1.7616 - -# Compute expected values -# With input 0.1 and gate=1.0, the gate GEMM output ≈ 7168 * 0.1 * 1.0 * gs_a * gs_b -# Let's check empirically -gate_vals = [] -up_vals = [] -for i in range(0, 64, 8): - chunk = out0[i:i+8].tolist() - is_gate = all(abs(v - out0[0].item()) < 1.0 for v in chunk) - label = "gate(silu)" if is_gate else "up(swiglu)" - if is_gate: - gate_vals.append(out0[i].item()) - else: - up_vals.append(out0[i].item()) - print(f" Cols {i:3d}-{i+7:3d}: {[round(v,2) for v in chunk]} → {label}") - -if gate_vals and up_vals: - g = gate_vals[0] - u = up_vals[0] - print(f"\nGate ≈ {g:.4f}, Up ≈ {u:.4f}") - print(f"Ratio up/gate ≈ {u/g:.4f}") - print(f"Expected silu(1.0) ≈ {silu_one:.4f}, silu(1.0)*2.0 ≈ {2*silu_one:.4f}") - print(f"Actual gate/expected_gate ≈ {g / (7168 * 0.1 * 1.0 * gs_a * l1_gs):.4f}") - -# Now check the de-interleaved SwiGLU output -l1_deil = deinterleave_l1_weights(fused_out.unsqueeze(0).contiguous())[0] -swiglu_result = l1_deil[:, intermediate:] -silu_gate = l1_deil[:, :intermediate] - -print(f"\nDe-interleaved silu(gate) amax: {silu_gate.abs().amax():.4f}") -print(f"De-interleaved SwiGLU amax: {swiglu_result.abs().amax():.4f}") -print(f"SwiGLU/silu(gate) ratio: {(swiglu_result[0,0] / silu_gate[0,0]):.4f}") - -# Verify: quantize the SwiGLU result and check it matches the Python quantize path -x_fp4, x_sf, gs = quantize_to_nvfp4(swiglu_result) -print(f"\nQuantized SwiGLU: FP4 shape={x_fp4.shape}, SF shape={x_sf.shape}, gs={gs:.8f}") -print(f"FP4 amax (uint8): {x_fp4.view(torch.uint8).amax()}") diff --git a/bench_fused.py b/bench_fused.py deleted file mode 100644 index 11bef7fa..00000000 --- a/bench_fused.py +++ /dev/null @@ -1,54 +0,0 @@ -import sys, os, torch, time -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -from cutedsl.moe_pipeline import run_nvfp4_moe, run_nvfp4_moe_fused, quantize_weight - -torch.manual_seed(42) -device = "cuda" -num_experts = 3 -hidden = 7168 -intermediate = 3072 -n_tokens = 128 - -l1_weights = [torch.randn(2*intermediate, hidden, dtype=torch.bfloat16, device=device) for _ in range(num_experts)] -l2_weights = [torch.randn(hidden, intermediate, dtype=torch.bfloat16, device=device) for _ in range(num_experts)] - -l1_fp4, l1_sf, l1_gs = [], [], [] -l2_fp4, l2_sf, l2_gs = [], [], [] -for l1_w, l2_w in zip(l1_weights, l2_weights): - fp4, sf, gs = quantize_weight(l1_w) - l1_fp4.append(fp4); l1_sf.append(sf); l1_gs.append(gs) - fp4, sf, gs = quantize_weight(l2_w) - l2_fp4.append(fp4); l2_sf.append(sf); l2_gs.append(gs) - -weights = { - "l1_fp4": l1_fp4, "l1_sf": l1_sf, "l1_gs": l1_gs, - "l2_fp4": l2_fp4, "l2_sf": l2_sf, "l2_gs": l2_gs, -} - -hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1 -expert_ids = torch.zeros(n_tokens, 1, dtype=torch.int32, device=device) -expert_weights = torch.ones(n_tokens, 1, dtype=torch.float32, device=device) -expert_indices = [0, 1, 2] - -# Warmup -_ = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices) -_ = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices) -torch.cuda.synchronize() - -# Benchmark -N = 50 -t0 = time.perf_counter() -for _ in range(N): - out1 = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices) -torch.cuda.synchronize() -t1 = time.perf_counter() - -for _ in range(N): - out2 = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices) -torch.cuda.synchronize() -t2 = time.perf_counter() - -print(f"Non-fused: {(t1-t0)/N*1000:.2f} ms/iter") -print(f"Fused: {(t2-t1)/N*1000:.2f} ms/iter") -print(f"Speedup: {(t1-t0)/(t2-t1):.2f}x") -print(f"Output match: {torch.allclose(out1.float(), out2.float(), atol=1.0)}") diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 2d18a490..49dd5895 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -776,3 +776,30 @@ def run_fused_swiglu_grouped_gemm( return out + + +def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8): + """De-interleave + quantize fused SwiGLU output using a custom CUDA kernel. + + Single kernel launch, no Python loop. 4x faster than the Python path. + + Args: + fused_bf16: (M, 2*intermediate) BF16 — fused L1 output with interleaved gate/up + intermediate: intermediate dimension (e.g., 3072) + global_scale: pre-computed global scale for quantization + granularity: interleave granularity in BF16 columns (default 8) + + Returns: + x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU + x_sf: (M, intermediate//16) float8_e4m3fn — block scales + """ + from torch.utils.cpp_extension import load + import os + kernel_dir = os.path.join(os.path.dirname(__file__), "kernels") + mod = load( + name="deinterleave_quantize_nvfp4", + sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")], + extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"], + verbose=False, + ) + return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale) diff --git a/cutedsl/kernels/deinterleave_quantize.cu b/cutedsl/kernels/deinterleave_quantize.cu new file mode 100644 index 00000000..1475f29e --- /dev/null +++ b/cutedsl/kernels/deinterleave_quantize.cu @@ -0,0 +1,100 @@ +#include +#include +#include +#include +#include +#include +#include + +// De-interleave + NVFP4 quantize kernel for fused SwiGLU output. +// Fused L1 output has [silu(gate)*8, swiglu*8, ...] interleaved at granularity 8. +// This kernel extracts odd 8-col groups (SwiGLU result) and quantizes to NVFP4. +// Single kernel launch, no Python loop, no CPU-GPU sync. + +__device__ __forceinline__ int half_step_to_e2m1(int hs) { + // Matches Python step_to_idx LUT: + // 0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→5, 8→6, 9→6, 10→6, 11→7, 12→7 + if (hs <= 4) return hs; + if (hs <= 5) return 4; + if (hs <= 7) return 5; + if (hs <= 10) return 6; + return 7; +} + +__global__ void deinterleave_quantize_nvfp4_kernel( + const __nv_bfloat16* __restrict__ fused, + int M, int N, int intermediate, int granularity, + float global_scale, + uint8_t* __restrict__ out_fp4, + uint8_t* __restrict__ out_sf +) { + int m = blockIdx.y; + int n_block = blockIdx.x; + if (m >= M || n_block * 16 >= intermediate) return; + + float vals[16]; + float block_amax = 0.0f; + + for (int i = 0; i < 16; i++) { + int nd = n_block * 16 + i; + if (nd >= intermediate) { vals[i] = 0; continue; } + // Map de-interleaved position to fused position + int group = 2 * (nd / granularity) + 1; // odd group = SwiGLU + int offset = nd % granularity; + int fc = group * granularity + offset; + float v = __bfloat162float(fused[m * N + fc]); + vals[i] = v / global_scale; + block_amax = fmaxf(block_amax, fabsf(vals[i])); + } + + // Block scale: amax/6 → FP8 E4M3 → float (round-trip) + float bsf = block_amax / 6.0f; + if (block_amax < 6.0f * 0.001953125f) { // underflow threshold + bsf = 0; + for (int i = 0; i < 16; i++) vals[i] = 0; + } + __nv_fp8_e4m3 bsf8_obj(bsf); + float bs = (float)bsf8_obj; + uint8_t bsf8 = *(uint8_t*)&bsf8_obj; + + // Quantize each value to FP4 E2M1 + uint8_t nibbles[16]; + for (int i = 0; i < 16; i++) { + if (bs < 1e-8f) { nibbles[i] = 0; continue; } + float s = vals[i] / bs; + int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f); + if (hs > 12) hs = 12; + int idx = half_step_to_e2m1(hs); + if (s < 0) idx += 8; + nibbles[i] = idx; + } + + // Pack pairs: (nibbles[1] << 4) | nibbles[0], etc. + for (int i = 0; i < 8; i++) + out_fp4[m * (intermediate / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i]; + + out_sf[m * (intermediate / 16) + n_block] = bsf8; +} + +std::tuple deinterleave_quantize_nvfp4_cuda( + torch::Tensor fused_bf16, int64_t intermediate, int64_t granularity, double global_scale +) { + int M = fused_bf16.size(0); + int N = fused_bf16.size(1); + auto opts = fused_bf16.options(); + auto out_fp4 = torch::zeros({M, intermediate / 2}, opts.dtype(torch::kUInt8)); + auto out_sf = torch::zeros({M, intermediate / 16}, opts.dtype(torch::kUInt8)); + int nb = intermediate / 16; + dim3 grid(nb, M); + dim3 block(16); + deinterleave_quantize_nvfp4_kernel<<>>( + reinterpret_cast(fused_bf16.data_ptr()), + M, N, (int)intermediate, (int)granularity, (float)global_scale, + out_fp4.data_ptr(), out_sf.data_ptr() + ); + return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn)}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deinterleave_quantize_nvfp4", &deinterleave_quantize_nvfp4_cuda); +} diff --git a/cutedsl/moe_pipeline.py b/cutedsl/moe_pipeline.py index 8de231c0..3fcef365 100644 --- a/cutedsl/moe_pipeline.py +++ b/cutedsl/moe_pipeline.py @@ -285,6 +285,7 @@ def run_nvfp4_moe_fused( weights, # dict from prepare_nvfp4_moe_weights expert_indices, # list of expert IDs swiglu_limit=0.0, + l2_activation_gs=None, # pre-computed L2 activation global scale (avoids amax sync) ): """Run the NVFP4 MoE forward pass with fused SwiGLU kernel. @@ -363,14 +364,22 @@ def run_nvfp4_moe_fused( swiglu_limit=swiglu_limit, ) - # De-interleave to get [silu(gate) | silu(gate)*up] layout + # De-interleave + quantize using custom CUDA kernel (4x faster) intermediate_size = l1_fused_out.shape[1] // 2 - l1_deil = deinterleave_l1_weights(l1_fused_out.unsqueeze(0).contiguous())[0] - activated = l1_deil[:, intermediate_size:] # up columns = SwiGLU result + # Use pre-computed L2 activation gs, or compute from amax (fallback) + l2_gs = l2_activation_gs if l2_activation_gs is not None else l1_fused_out.abs().amax().float().item() / 2688.0 + from cutedsl.bridge import deinterleave_quantize_nvfp4_cuda, quantize_activation_nvfp4 + l2_x_fp4, l2_x_sf = deinterleave_quantize_nvfp4_cuda(l1_fused_out, intermediate_size, l2_gs) + # Skip the separate L2 quantize step below — we already have FP4+SF + # Set activated to None to signal we already quantized + activated = None - # === L2: down projection (same as non-fused) === - - l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated) + # === L2: down projection === + if activated is not None: + l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated) + else: + # Already quantized by the custom CUDA kernel + l2_x_igs = l2_gs l2_mat_b = make_b_k_major(torch.stack(weights['l2_fp4'])) l2_sf_parts = []