Remove debug print statements from pipeline
This commit is contained in:
98
analyze_layout.py
Normal file
98
analyze_layout.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Detailed register layout analysis for the fused SwiGLU epilogue.
|
||||
|
||||
Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying
|
||||
input (each token has a different scale). The fused output at each (M, N)
|
||||
position tells us the value. By checking multiple positions, we can determine
|
||||
which register positions map to which (M, N) addresses.
|
||||
|
||||
With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols.
|
||||
The TMA store writes in (M, N) order, so the GMEM output is in row-major order.
|
||||
The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x).
|
||||
|
||||
For 128 epilogue threads and (128, 8) subtiles:
|
||||
128 * 8 = 1024 values per subtile
|
||||
1024 / 128 = 8 values per thread per subtile
|
||||
|
||||
Possible layouts:
|
||||
a) 8 N-cols × 1 M-row per thread (contiguous along N)
|
||||
b) 1 N-col × 8 M-rows per thread (contiguous along M)
|
||||
c) 4 N-cols × 2 M-rows per thread
|
||||
d) 2 N-cols × 4 M-rows per thread
|
||||
"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import torch
|
||||
from cutedsl.bridge import (
|
||||
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
|
||||
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
|
||||
run_fused_swiglu_grouped_gemm, assemble_scales_2d_side,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
ceil_div, assemble_raw_scales_2d3d_3d_side,
|
||||
)
|
||||
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
hidden = 7168
|
||||
intermediate = 3072
|
||||
K_packed = hidden // 2
|
||||
|
||||
# gate=1.0, up=3.0 — distinct from silu scaling
|
||||
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
|
||||
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 3.0
|
||||
l1_w = torch.cat([gate_w, up_w], dim=1)
|
||||
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
|
||||
|
||||
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
|
||||
l1_mat_b = make_b_k_major(l1_ekn)
|
||||
l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0))
|
||||
l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()])
|
||||
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
|
||||
|
||||
# Input: 128 tokens with VARYING scales (each row has a unique value)
|
||||
n_tokens = 128
|
||||
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.01
|
||||
# But we want deterministic, so use a known pattern:
|
||||
# Row i has value i/128 * 0.1
|
||||
for i in range(n_tokens):
|
||||
hidden_states[i] = (i / 128.0) * 0.1
|
||||
|
||||
gs_a = 1.0 / 2688.0
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
|
||||
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
|
||||
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
|
||||
l1_scale_a = assemble_scales_2d_side([x_sf])
|
||||
|
||||
fused_out = run_fused_swiglu_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
|
||||
)
|
||||
|
||||
print(f"Fused output shape: {fused_out.shape}")
|
||||
|
||||
# The output should be proportional to the input value.
|
||||
# Row i has input ≈ i/128 * 0.1, so the GEMM output is proportional to i.
|
||||
# Gate (cols 0-7, 16-23, ...): silu(gate) ≈ c * i
|
||||
# Up (cols 8-15, 24-31, ...): silu(gate)*up ≈ 3c * i (since up=3.0)
|
||||
|
||||
# Check the first subtile (cols 0-7, should be gate)
|
||||
# and second subtile (cols 8-15, should be up)
|
||||
# For M-rows 0, 1, 2, ...
|
||||
print("\nM-row | Gate (col 0) | Up (col 8) | Ratio")
|
||||
for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]:
|
||||
g = fused_out[m, 0].item()
|
||||
u = fused_out[m, 8].item()
|
||||
ratio = u / g if abs(g) > 0.01 else float('inf')
|
||||
print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}")
|
||||
|
||||
# Check if values within a subtile are uniform (same value for all 8 N-cols)
|
||||
print("\nRow 0, first 16 values (2 subtiles):")
|
||||
print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}")
|
||||
print(f"Row 1, first 16 values:")
|
||||
print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}")
|
||||
|
||||
# If values within a subtile are uniform (all 8 N-cols have the same value),
|
||||
# the register layout has 8 N-cols per thread (layout a).
|
||||
# If they differ across M-rows but same N-col, it's layout b.
|
||||
114
analyze_output.py
Normal file
114
analyze_output.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Empirical register layout analysis for the fused SwiGLU epilogue.
|
||||
|
||||
With epi_tile=(128, 8), each subtile covers 8 BF16 N-columns and 128 M-rows.
|
||||
The TMA store writes the BF16 output to GMEM in a deterministic order.
|
||||
|
||||
By running the fused kernel with gate=1.0/up=2.0 interleaved weights,
|
||||
the output at odd 8-column groups should be silu(gate)*up ≈ 2*silu(1.0),
|
||||
and even groups should be silu(gate) ≈ silu(1.0).
|
||||
|
||||
This script analyzes the GMEM output to understand:
|
||||
1. Which 8-column groups are gate vs up (verify interleaving)
|
||||
2. The BF16 values at specific (M, N) positions
|
||||
3. Whether the subtile pairing is correct
|
||||
"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cutedsl.bridge import (
|
||||
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
|
||||
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
|
||||
compute_expert_offsets, run_fused_swiglu_grouped_gemm, quantize_to_nvfp4,
|
||||
assemble_scales_2d_side, assemble_scales_3d_side,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side
|
||||
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
hidden = 7168
|
||||
intermediate = 3072
|
||||
K_packed = hidden // 2
|
||||
|
||||
# gate=1.0, up=2.0 — clear signal for interleaving
|
||||
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
|
||||
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 2.0
|
||||
l1_w = torch.cat([gate_w, up_w], dim=1)
|
||||
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
|
||||
|
||||
# Interleave
|
||||
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
|
||||
l1_mat_b = make_b_k_major(l1_ekn)
|
||||
|
||||
# SF interleave
|
||||
l1_sf_ekn = l1_sf.unsqueeze(0)
|
||||
l1_sf_il = interleave_l1_weights(l1_sf_ekn)
|
||||
l1_sf_il_list = [l1_sf_il[0].T.contiguous()]
|
||||
l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_il_list)
|
||||
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
|
||||
|
||||
# Input: 128 tokens, all 0.1
|
||||
n_tokens = 128
|
||||
hidden_states = torch.ones(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
|
||||
gs_a = 1.0 / 2688.0
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
|
||||
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
|
||||
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
|
||||
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import ceil_div, pad_and_swizzle_single
|
||||
K_sf = ceil_div(K_packed, 8)
|
||||
x_sf_parts = [x_sf]
|
||||
l1_scale_a = assemble_scales_2d_side(x_sf_parts)
|
||||
|
||||
# Run fused kernel
|
||||
fused_out = run_fused_swiglu_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
|
||||
)
|
||||
|
||||
print(f"Fused output shape: {fused_out.shape}")
|
||||
out0 = fused_out[0].float() # First token, as float32
|
||||
|
||||
# BF16 reference: silu(1.0) = 0.7311, silu(1.0)*2.0 = 1.4621
|
||||
import math
|
||||
silu_one = 1.0 / (1.0 + math.exp(-1.0)) # sigmoid(1.0) = 0.7311
|
||||
silu_two = 2.0 / (1.0 + math.exp(-2.0)) # sigmoid(2.0) = 1.7616
|
||||
|
||||
# Compute expected values
|
||||
# With input 0.1 and gate=1.0, the gate GEMM output ≈ 7168 * 0.1 * 1.0 * gs_a * gs_b
|
||||
# Let's check empirically
|
||||
gate_vals = []
|
||||
up_vals = []
|
||||
for i in range(0, 64, 8):
|
||||
chunk = out0[i:i+8].tolist()
|
||||
is_gate = all(abs(v - out0[0].item()) < 1.0 for v in chunk)
|
||||
label = "gate(silu)" if is_gate else "up(swiglu)"
|
||||
if is_gate:
|
||||
gate_vals.append(out0[i].item())
|
||||
else:
|
||||
up_vals.append(out0[i].item())
|
||||
print(f" Cols {i:3d}-{i+7:3d}: {[round(v,2) for v in chunk]} → {label}")
|
||||
|
||||
if gate_vals and up_vals:
|
||||
g = gate_vals[0]
|
||||
u = up_vals[0]
|
||||
print(f"\nGate ≈ {g:.4f}, Up ≈ {u:.4f}")
|
||||
print(f"Ratio up/gate ≈ {u/g:.4f}")
|
||||
print(f"Expected silu(1.0) ≈ {silu_one:.4f}, silu(1.0)*2.0 ≈ {2*silu_one:.4f}")
|
||||
print(f"Actual gate/expected_gate ≈ {g / (7168 * 0.1 * 1.0 * gs_a * l1_gs):.4f}")
|
||||
|
||||
# Now check the de-interleaved SwiGLU output
|
||||
l1_deil = deinterleave_l1_weights(fused_out.unsqueeze(0).contiguous())[0]
|
||||
swiglu_result = l1_deil[:, intermediate:]
|
||||
silu_gate = l1_deil[:, :intermediate]
|
||||
|
||||
print(f"\nDe-interleaved silu(gate) amax: {silu_gate.abs().amax():.4f}")
|
||||
print(f"De-interleaved SwiGLU amax: {swiglu_result.abs().amax():.4f}")
|
||||
print(f"SwiGLU/silu(gate) ratio: {(swiglu_result[0,0] / silu_gate[0,0]):.4f}")
|
||||
|
||||
# Verify: quantize the SwiGLU result and check it matches the Python quantize path
|
||||
x_fp4, x_sf, gs = quantize_to_nvfp4(swiglu_result)
|
||||
print(f"\nQuantized SwiGLU: FP4 shape={x_fp4.shape}, SF shape={x_sf.shape}, gs={gs:.8f}")
|
||||
print(f"FP4 amax (uint8): {x_fp4.view(torch.uint8).amax()}")
|
||||
54
bench_fused.py
Normal file
54
bench_fused.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import sys, os, torch, time
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from cutedsl.moe_pipeline import run_nvfp4_moe, run_nvfp4_moe_fused, quantize_weight
|
||||
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
num_experts = 3
|
||||
hidden = 7168
|
||||
intermediate = 3072
|
||||
n_tokens = 128
|
||||
|
||||
l1_weights = [torch.randn(2*intermediate, hidden, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
|
||||
l2_weights = [torch.randn(hidden, intermediate, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
|
||||
|
||||
l1_fp4, l1_sf, l1_gs = [], [], []
|
||||
l2_fp4, l2_sf, l2_gs = [], [], []
|
||||
for l1_w, l2_w in zip(l1_weights, l2_weights):
|
||||
fp4, sf, gs = quantize_weight(l1_w)
|
||||
l1_fp4.append(fp4); l1_sf.append(sf); l1_gs.append(gs)
|
||||
fp4, sf, gs = quantize_weight(l2_w)
|
||||
l2_fp4.append(fp4); l2_sf.append(sf); l2_gs.append(gs)
|
||||
|
||||
weights = {
|
||||
"l1_fp4": l1_fp4, "l1_sf": l1_sf, "l1_gs": l1_gs,
|
||||
"l2_fp4": l2_fp4, "l2_sf": l2_sf, "l2_gs": l2_gs,
|
||||
}
|
||||
|
||||
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
|
||||
expert_ids = torch.zeros(n_tokens, 1, dtype=torch.int32, device=device)
|
||||
expert_weights = torch.ones(n_tokens, 1, dtype=torch.float32, device=device)
|
||||
expert_indices = [0, 1, 2]
|
||||
|
||||
# Warmup
|
||||
_ = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
_ = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
N = 50
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(N):
|
||||
out1 = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
|
||||
for _ in range(N):
|
||||
out2 = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
torch.cuda.synchronize()
|
||||
t2 = time.perf_counter()
|
||||
|
||||
print(f"Non-fused: {(t1-t0)/N*1000:.2f} ms/iter")
|
||||
print(f"Fused: {(t2-t1)/N*1000:.2f} ms/iter")
|
||||
print(f"Speedup: {(t1-t0)/(t2-t1):.2f}x")
|
||||
print(f"Output match: {torch.allclose(out1.float(), out2.float(), atol=1.0)}")
|
||||
@@ -204,9 +204,6 @@ def run_nvfp4_moe(
|
||||
# Global scales: alpha = igs * weight_gs for each expert
|
||||
l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device)
|
||||
l1_global_scale_b = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=device)
|
||||
print(f" L1 global_scale_a: {l1_global_scale_a.tolist()}", flush=True)
|
||||
print(f" L1 global_scale_b: {l1_global_scale_b.tolist()}", flush=True)
|
||||
print(f" alpha (a*b): {(l1_global_scale_a * l1_global_scale_b).tolist()}", flush=True)
|
||||
|
||||
# Run L1 GEMM
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
@@ -215,7 +212,6 @@ def run_nvfp4_moe(
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l1_global_scale_a, global_scale_b=l1_global_scale_b,
|
||||
) # (num_slots, 2*intermediate) BF16
|
||||
print(f" L1 GEMM output: shape={l1_out.shape}, amax={l1_out.abs().amax().item():.4f}", flush=True)
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# SiLU(gate) * up (BF16 — nonlinear requires BF16)
|
||||
@@ -226,14 +222,11 @@ def run_nvfp4_moe(
|
||||
l1_deil = deinterleave_l1_weights(l1_out.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :intermediate_size]
|
||||
up = l1_deil[:, intermediate_size:]
|
||||
print(f" gate: shape={gate.shape}, amax={gate.abs().amax().item():.4f}", flush=True)
|
||||
print(f" up: shape={up.shape}, amax={up.abs().amax().item():.4f}", flush=True)
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
if swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=swiglu_limit)
|
||||
up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
|
||||
activated = gate_silu * up # (num_slots, intermediate) BF16
|
||||
print(f" After SiLU(gate)*up: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True)
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# L2: down projection (NVFP4 × NVFP4 → BF16)
|
||||
@@ -374,7 +367,6 @@ def run_nvfp4_moe_fused(
|
||||
intermediate_size = l1_fused_out.shape[1] // 2
|
||||
l1_deil = deinterleave_l1_weights(l1_fused_out.unsqueeze(0).contiguous())[0]
|
||||
activated = l1_deil[:, intermediate_size:] # up columns = SwiGLU result
|
||||
print(f" Fused SwiGLU: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True)
|
||||
|
||||
# === L2: down projection (same as non-fused) ===
|
||||
|
||||
|
||||
Reference in New Issue
Block a user