Remove debug print statements from pipeline

This commit is contained in:
2026-05-20 04:20:46 +00:00
parent aa8563c626
commit 061d5692a9
4 changed files with 266 additions and 8 deletions

98
analyze_layout.py Normal file
View File

@@ -0,0 +1,98 @@
"""Detailed register layout analysis for the fused SwiGLU epilogue.
Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying
input (each token has a different scale). The fused output at each (M, N)
position tells us the value. By checking multiple positions, we can determine
which register positions map to which (M, N) addresses.
With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols.
The TMA store writes in (M, N) order, so the GMEM output is in row-major order.
The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x).
For 128 epilogue threads and (128, 8) subtiles:
128 * 8 = 1024 values per subtile
1024 / 128 = 8 values per thread per subtile
Possible layouts:
a) 8 N-cols × 1 M-row per thread (contiguous along N)
b) 1 N-col × 8 M-rows per thread (contiguous along M)
c) 4 N-cols × 2 M-rows per thread
d) 2 N-cols × 4 M-rows per thread
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from cutedsl.bridge import (
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
run_fused_swiglu_grouped_gemm, assemble_scales_2d_side,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ceil_div, assemble_raw_scales_2d3d_3d_side,
)
torch.manual_seed(42)
device = "cuda"
hidden = 7168
intermediate = 3072
K_packed = hidden // 2
# gate=1.0, up=3.0 — distinct from silu scaling
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 3.0
l1_w = torch.cat([gate_w, up_w], dim=1)
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
l1_mat_b = make_b_k_major(l1_ekn)
l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0))
l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()])
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
# Input: 128 tokens with VARYING scales (each row has a unique value)
n_tokens = 128
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.01
# But we want deterministic, so use a known pattern:
# Row i has value i/128 * 0.1
for i in range(n_tokens):
hidden_states[i] = (i / 128.0) * 0.1
gs_a = 1.0 / 2688.0
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
l1_scale_a = assemble_scales_2d_side([x_sf])
fused_out = run_fused_swiglu_grouped_gemm(
mat_a=x_fp4, mat_b=l1_mat_b,
scale_a=l1_scale_a, scale_b=l1_scale_b,
expert_offsets=expert_offsets,
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
)
print(f"Fused output shape: {fused_out.shape}")
# The output should be proportional to the input value.
# Row i has input ≈ i/128 * 0.1, so the GEMM output is proportional to i.
# Gate (cols 0-7, 16-23, ...): silu(gate) ≈ c * i
# Up (cols 8-15, 24-31, ...): silu(gate)*up ≈ 3c * i (since up=3.0)
# Check the first subtile (cols 0-7, should be gate)
# and second subtile (cols 8-15, should be up)
# For M-rows 0, 1, 2, ...
print("\nM-row | Gate (col 0) | Up (col 8) | Ratio")
for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]:
g = fused_out[m, 0].item()
u = fused_out[m, 8].item()
ratio = u / g if abs(g) > 0.01 else float('inf')
print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}")
# Check if values within a subtile are uniform (same value for all 8 N-cols)
print("\nRow 0, first 16 values (2 subtiles):")
print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}")
print(f"Row 1, first 16 values:")
print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}")
# If values within a subtile are uniform (all 8 N-cols have the same value),
# the register layout has 8 N-cols per thread (layout a).
# If they differ across M-rows but same N-col, it's layout b.

114
analyze_output.py Normal file
View File

@@ -0,0 +1,114 @@
"""Empirical register layout analysis for the fused SwiGLU epilogue.
With epi_tile=(128, 8), each subtile covers 8 BF16 N-columns and 128 M-rows.
The TMA store writes the BF16 output to GMEM in a deterministic order.
By running the fused kernel with gate=1.0/up=2.0 interleaved weights,
the output at odd 8-column groups should be silu(gate)*up ≈ 2*silu(1.0),
and even groups should be silu(gate) ≈ silu(1.0).
This script analyzes the GMEM output to understand:
1. Which 8-column groups are gate vs up (verify interleaving)
2. The BF16 values at specific (M, N) positions
3. Whether the subtile pairing is correct
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn.functional as F
from cutedsl.bridge import (
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
compute_expert_offsets, run_fused_swiglu_grouped_gemm, quantize_to_nvfp4,
assemble_scales_2d_side, assemble_scales_3d_side,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side
torch.manual_seed(42)
device = "cuda"
hidden = 7168
intermediate = 3072
K_packed = hidden // 2
# gate=1.0, up=2.0 — clear signal for interleaving
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 2.0
l1_w = torch.cat([gate_w, up_w], dim=1)
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
# Interleave
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
l1_mat_b = make_b_k_major(l1_ekn)
# SF interleave
l1_sf_ekn = l1_sf.unsqueeze(0)
l1_sf_il = interleave_l1_weights(l1_sf_ekn)
l1_sf_il_list = [l1_sf_il[0].T.contiguous()]
l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_il_list)
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
# Input: 128 tokens, all 0.1
n_tokens = 128
hidden_states = torch.ones(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
gs_a = 1.0 / 2688.0
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import ceil_div, pad_and_swizzle_single
K_sf = ceil_div(K_packed, 8)
x_sf_parts = [x_sf]
l1_scale_a = assemble_scales_2d_side(x_sf_parts)
# Run fused kernel
fused_out = run_fused_swiglu_grouped_gemm(
mat_a=x_fp4, mat_b=l1_mat_b,
scale_a=l1_scale_a, scale_b=l1_scale_b,
expert_offsets=expert_offsets,
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
)
print(f"Fused output shape: {fused_out.shape}")
out0 = fused_out[0].float() # First token, as float32
# BF16 reference: silu(1.0) = 0.7311, silu(1.0)*2.0 = 1.4621
import math
silu_one = 1.0 / (1.0 + math.exp(-1.0)) # sigmoid(1.0) = 0.7311
silu_two = 2.0 / (1.0 + math.exp(-2.0)) # sigmoid(2.0) = 1.7616
# Compute expected values
# With input 0.1 and gate=1.0, the gate GEMM output ≈ 7168 * 0.1 * 1.0 * gs_a * gs_b
# Let's check empirically
gate_vals = []
up_vals = []
for i in range(0, 64, 8):
chunk = out0[i:i+8].tolist()
is_gate = all(abs(v - out0[0].item()) < 1.0 for v in chunk)
label = "gate(silu)" if is_gate else "up(swiglu)"
if is_gate:
gate_vals.append(out0[i].item())
else:
up_vals.append(out0[i].item())
print(f" Cols {i:3d}-{i+7:3d}: {[round(v,2) for v in chunk]}{label}")
if gate_vals and up_vals:
g = gate_vals[0]
u = up_vals[0]
print(f"\nGate ≈ {g:.4f}, Up ≈ {u:.4f}")
print(f"Ratio up/gate ≈ {u/g:.4f}")
print(f"Expected silu(1.0) ≈ {silu_one:.4f}, silu(1.0)*2.0 ≈ {2*silu_one:.4f}")
print(f"Actual gate/expected_gate ≈ {g / (7168 * 0.1 * 1.0 * gs_a * l1_gs):.4f}")
# Now check the de-interleaved SwiGLU output
l1_deil = deinterleave_l1_weights(fused_out.unsqueeze(0).contiguous())[0]
swiglu_result = l1_deil[:, intermediate:]
silu_gate = l1_deil[:, :intermediate]
print(f"\nDe-interleaved silu(gate) amax: {silu_gate.abs().amax():.4f}")
print(f"De-interleaved SwiGLU amax: {swiglu_result.abs().amax():.4f}")
print(f"SwiGLU/silu(gate) ratio: {(swiglu_result[0,0] / silu_gate[0,0]):.4f}")
# Verify: quantize the SwiGLU result and check it matches the Python quantize path
x_fp4, x_sf, gs = quantize_to_nvfp4(swiglu_result)
print(f"\nQuantized SwiGLU: FP4 shape={x_fp4.shape}, SF shape={x_sf.shape}, gs={gs:.8f}")
print(f"FP4 amax (uint8): {x_fp4.view(torch.uint8).amax()}")

54
bench_fused.py Normal file
View File

@@ -0,0 +1,54 @@
import sys, os, torch, time
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from cutedsl.moe_pipeline import run_nvfp4_moe, run_nvfp4_moe_fused, quantize_weight
torch.manual_seed(42)
device = "cuda"
num_experts = 3
hidden = 7168
intermediate = 3072
n_tokens = 128
l1_weights = [torch.randn(2*intermediate, hidden, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
l2_weights = [torch.randn(hidden, intermediate, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
l1_fp4, l1_sf, l1_gs = [], [], []
l2_fp4, l2_sf, l2_gs = [], [], []
for l1_w, l2_w in zip(l1_weights, l2_weights):
fp4, sf, gs = quantize_weight(l1_w)
l1_fp4.append(fp4); l1_sf.append(sf); l1_gs.append(gs)
fp4, sf, gs = quantize_weight(l2_w)
l2_fp4.append(fp4); l2_sf.append(sf); l2_gs.append(gs)
weights = {
"l1_fp4": l1_fp4, "l1_sf": l1_sf, "l1_gs": l1_gs,
"l2_fp4": l2_fp4, "l2_sf": l2_sf, "l2_gs": l2_gs,
}
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
expert_ids = torch.zeros(n_tokens, 1, dtype=torch.int32, device=device)
expert_weights = torch.ones(n_tokens, 1, dtype=torch.float32, device=device)
expert_indices = [0, 1, 2]
# Warmup
_ = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
_ = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
torch.cuda.synchronize()
# Benchmark
N = 50
t0 = time.perf_counter()
for _ in range(N):
out1 = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
torch.cuda.synchronize()
t1 = time.perf_counter()
for _ in range(N):
out2 = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
torch.cuda.synchronize()
t2 = time.perf_counter()
print(f"Non-fused: {(t1-t0)/N*1000:.2f} ms/iter")
print(f"Fused: {(t2-t1)/N*1000:.2f} ms/iter")
print(f"Speedup: {(t1-t0)/(t2-t1):.2f}x")
print(f"Output match: {torch.allclose(out1.float(), out2.float(), atol=1.0)}")

View File

@@ -204,9 +204,6 @@ def run_nvfp4_moe(
# Global scales: alpha = igs * weight_gs for each expert
l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device)
l1_global_scale_b = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=device)
print(f" L1 global_scale_a: {l1_global_scale_a.tolist()}", flush=True)
print(f" L1 global_scale_b: {l1_global_scale_b.tolist()}", flush=True)
print(f" alpha (a*b): {(l1_global_scale_a * l1_global_scale_b).tolist()}", flush=True)
# Run L1 GEMM
l1_out = run_nvfp4_grouped_gemm(
@@ -215,7 +212,6 @@ def run_nvfp4_moe(
expert_offsets=expert_offsets,
global_scale_a=l1_global_scale_a, global_scale_b=l1_global_scale_b,
) # (num_slots, 2*intermediate) BF16
print(f" L1 GEMM output: shape={l1_out.shape}, amax={l1_out.abs().amax().item():.4f}", flush=True)
# ════════════════════════════════════════════════════════════════
# SiLU(gate) * up (BF16 — nonlinear requires BF16)
@@ -226,14 +222,11 @@ def run_nvfp4_moe(
l1_deil = deinterleave_l1_weights(l1_out.unsqueeze(0).contiguous())[0]
gate = l1_deil[:, :intermediate_size]
up = l1_deil[:, intermediate_size:]
print(f" gate: shape={gate.shape}, amax={gate.abs().amax().item():.4f}", flush=True)
print(f" up: shape={up.shape}, amax={up.abs().amax().item():.4f}", flush=True)
gate_silu = torch.nn.functional.silu(gate)
if swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=swiglu_limit)
up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
activated = gate_silu * up # (num_slots, intermediate) BF16
print(f" After SiLU(gate)*up: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True)
# ════════════════════════════════════════════════════════════════
# L2: down projection (NVFP4 × NVFP4 → BF16)
@@ -374,7 +367,6 @@ def run_nvfp4_moe_fused(
intermediate_size = l1_fused_out.shape[1] // 2
l1_deil = deinterleave_l1_weights(l1_fused_out.unsqueeze(0).contiguous())[0]
activated = l1_deil[:, intermediate_size:] # up columns = SwiGLU result
print(f" Fused SwiGLU: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True)
# === L2: down projection (same as non-fused) ===