Files
nvfp4-megamoe-kernel/analyze_layout.py

99 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Detailed register layout analysis for the fused SwiGLU epilogue.
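Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying
input (each token has a different scale). The fused output at each (M, N)
position tells us the value. By checking multiple positions, we can determine
which register positions map to which (M, N) addresses.
Caveat: because the weights are constant along N and each input row is
constant along K, a correct kernel yields the same value in every gate column
(and in every up column) of a given row under every candidate layout listed
below. The probe therefore detects a wrong mapping (values from different
M-rows, or from the gate/up halves, being combined) rather than identifying
which of those layouts is in use.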
With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols.
The TMA store writes in (M, N) order, so the GMEM output is in row-major order.
The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x).
For 128 epilogue threads and (128, 8) subtiles:
128 * 8 = 1024 values per subtile
1024 / 128 = 8 values per thread per subtile
Possible layouts:
a) 8 N-cols × 1 M-row per thread (contiguous along N)
b) 1 N-col × 8 M-rows per thread (contiguous along M)
c) 4 N-cols × 2 M-rows per thread
d) 2 N-cols × 4 M-rows per thread
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from cutedsl.bridge import (
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
run_fused_swiglu_grouped_gemm, assemble_scales_2d_side,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ceil_div, assemble_raw_scales_2d3d_3d_side,
)
# Target device and problem sizes for the probe.
device = "cuda"
# hidden is the GEMM K extent; intermediate is the width of each of the gate
# and up halves of the L1 weight.
hidden, intermediate = 7168, 3072
# Packed K extent (presumably two FP4 values per byte); not referenced
# elsewhere in the visible script - TODO confirm it is still needed.
K_packed = hidden // 2
# Seed torch's RNG so any random draws are reproducible.
torch.manual_seed(42)
# L1 weights: the gate half is all 1.0 and the up half all 3.0, two distinct
# constants so gate and up contributions can be told apart in the output.
gate_w = torch.full((hidden, intermediate), 1.0, dtype=torch.bfloat16, device=device)
up_w = torch.full((hidden, intermediate), 3.0, dtype=torch.bfloat16, device=device)
# Concatenate along dim 1: columns [0, intermediate) are gate, the rest are up.
l1_w = torch.cat((gate_w, up_w), dim=1)

# Quantize to NVFP4. The names suggest packed FP4 data, per-block scale
# factors, and a global scale (TODO confirm against cutedsl.bridge).
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)

# Add a leading size-1 axis (presumably the expert axis), interleave the gate
# and up halves, then convert to the K-major B layout the kernel consumes.
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
l1_mat_b = make_b_k_major(l1_ekn)

# Scale factors get the same interleave; the single expert's matrix is
# transposed and packed into the B-side scale layout.
l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0))
l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()])

# B-side global scale as a 1-element float32 device tensor.
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
# Input: n_tokens rows, each filled with one row-dependent constant
# (row i = i / n_tokens * 0.1), so the GEMM output can vary only along M.
# The rows are built directly with torch.arange. Previously random data was
# drawn and then fully overwritten (dead work), and a Python loop issued one
# small device fill per row.
n_tokens = 128
row_values = (
    torch.arange(n_tokens, dtype=torch.float64, device=device) / n_tokens * 0.1
)
hidden_states = (
    row_values.unsqueeze(1)
    .expand(n_tokens, hidden)  # stride-0 broadcast view...
    .to(torch.bfloat16)
    .contiguous()  # ...materialized as a dense (n_tokens, hidden) tensor
)
# NOTE(review): NVFP4 presumably stores each 16-element block as FP4 values
# times a low-precision (FP8) block scale. For a constant row, that scale
# carries the row's value, so neighbouring rows may quantize to the same
# value. "Unique per row" likely does not survive quantization; confirm by
# inspecting x_sf.

# Activation global scale. 2688 == 448 * 6, presumably the FP8-E4M3 max times
# the FP4-E2M1 max; confirm the convention quantize_activation_nvfp4 expects.
gs_a = 1.0 / 2688.0
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)

# NOTE(review): three offsets and three A-side global scales are passed,
# although the B operand (l1_mat_b, l1_gsb) is built from a single expert. If
# these are cumulative end offsets, expert 0 receives all n_tokens rows and
# experts 1-2 are empty; confirm against run_fused_swiglu_grouped_gemm.
expert_offsets = torch.tensor([n_tokens] * 3, dtype=torch.int32, device=device)
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
l1_scale_a = assemble_scales_2d_side([x_sf])
# Launch the fused L1 grouped GEMM + SwiGLU epilogue. The keyword arguments
# are grouped per operand: A (activations), then B (weights), then routing.
fused_out = run_fused_swiglu_grouped_gemm(
    mat_a=x_fp4,
    scale_a=l1_scale_a,
    global_scale_a=l1_gsa,
    mat_b=l1_mat_b,
    scale_b=l1_scale_b,
    global_scale_b=l1_gsb,
    expert_offsets=expert_offsets,
)
print(f"Fused output shape: {fused_out.shape}")
# Expected values, assuming the kernel is correct.
# Every weight column is identical within the gate half (1.0) and within the
# up half (3.0), and each input row is constant along K. So for row m with
# (quantized) input value v_m, the raw GEMM gives gate = c * v_m and
# up = 3 * c * v_m, for a single positive c set by the hidden size and the
# quantization scales.
# This script assumes gate occupies cols 0-7, 16-23, ... and up occupies
# cols 8-15, 24-31, ... (the layout it expects interleave_l1_weights to
# produce). Then:
#   col 0 (silu(gate))      ~= c * v_m            (silu(x) ~= x for large x > 0)
#   col 8 (silu(gate) * up) ~= 3 * c**2 * v_m**2  (quadratic in v_m, NOT 3c*v_m)
# Under that reading the Ratio column grows like 3 * c * v_m. A constant ratio
# of ~3 would instead mean col 8 holds the raw up value.
# NOTE(review): which quantity the fused epilogue writes to each column is
# not visible from this script; confirm against the kernel.
print("\nM-row | Gate (col 0) | Up (col 8) | Ratio")
for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]:
    g = fused_out[m, 0].item()
    u = fused_out[m, 8].item()
    # Guard against a near-zero gate (row 0's input value is 0).
    ratio = u / g if abs(g) > 0.01 else float('inf')
    print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}")

# Eyeball check: are the values within each 8-column group of a row uniform?
print("\nRow 0, first 16 values (2 subtiles):")
print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}")
print("Row 1, first 16 values:")
print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}")

# The same check over every row. A row is flagged when any value in either of
# its first two 8-column groups differs from that group's first value. NaNs
# are flagged too, since NaN != NaN.
groups = fused_out[:, :16].float().reshape(-1, 2, 8)
nonuniform_rows = (groups != groups[..., :1]).any(dim=2).any(dim=1)
print(
    "Rows with non-uniform 8-col groups (cols 0-15): "
    f"{int(nonuniform_rows.sum().item())}/{groups.shape[0]}"
)

# Interpreting the result:
# The weights are constant along N and each input row is constant along K, so
# a correct kernel gives the same value in all 8 columns of a group for a
# given row under EVERY candidate register layout (a)-(d) in the module
# docstring. Uniform groups therefore do not single out layout (a).
# What this probe can reveal is a wrong mapping: values that differ within a
# group mean data from different M-rows (or from the gate and up halves) were
# combined.
# NOTE(review): rows whose inputs quantize to the same value cannot be told
# apart, so mixing between such rows is invisible here.