"""Detailed register layout analysis for the fused SwiGLU epilogue.

Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying
input (each token has a different scale). The fused output at each (M, N)
position tells us the value. By checking multiple positions, we can determine
which register positions map to which (M, N) addresses.

With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols.
The TMA store writes in (M, N) order, so the GMEM output is in row-major order.
The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x).

For 128 epilogue threads and (128, 8) subtiles:
  128 * 8 = 1024 values per subtile
  1024 / 128 = 8 values per thread per subtile

Possible layouts:
  a) 8 N-cols × 1 M-row per thread (contiguous along N)
  b) 1 N-col × 8 M-rows per thread (contiguous along M)
  c) 4 N-cols × 2 M-rows per thread
  d) 2 N-cols × 4 M-rows per thread
"""
import os
import sys

# Make the directory containing this script importable so that the `cutedsl`
# package imports below resolve when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch

from cutedsl.bridge import (
    quantize_weight_to_nvfp4, quantize_activation_nvfp4,
    make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
    run_fused_swiglu_grouped_gemm, assemble_scales_2d_side,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
    ceil_div, assemble_raw_scales_2d3d_3d_side,
)

# Kept from the original script; the input below is fully deterministic, so
# this only matters if a helper draws random numbers internally.
torch.manual_seed(42)
device = "cuda"
hidden = 7168
intermediate = 3072
K_packed = hidden // 2

# gate=1.0, up=3.0 — distinct from silu scaling
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 3.0
# Gate columns first, up columns second: shape (hidden, 2 * intermediate).
l1_w = torch.cat([gate_w, up_w], dim=1)
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)

# unsqueeze(0) adds a leading dimension of size 1 before interleaving
# (presumably the expert dimension — TODO confirm against cutedsl.bridge).
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
l1_mat_b = make_b_k_major(l1_ekn)
l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0))
l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()])
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)

# Input: 128 tokens with VARYING scales (each row has a unique value).
# Row i has value i/128 * 0.1, constant across the hidden dimension.
# Every row is assigned below, so the buffer is allocated uninitialised
# instead of being filled with random values that would be discarded.
n_tokens = 128
hidden_states = torch.empty(n_tokens, hidden, dtype=torch.bfloat16, device=device)
for i in range(n_tokens):
    hidden_states[i] = (i / 128.0) * 0.1

gs_a = 1.0 / 2688.0
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
# NOTE(review): expert_offsets and l1_gsa carry three entries, while the
# weights above have a leading dimension of 1 and l1_gsb has a single entry.
# Confirm this is what run_fused_swiglu_grouped_gemm expects.
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
l1_scale_a = assemble_scales_2d_side([x_sf])

fused_out = run_fused_swiglu_grouped_gemm(
    mat_a=x_fp4, mat_b=l1_mat_b,
    scale_a=l1_scale_a, scale_b=l1_scale_b,
    expert_offsets=expert_offsets,
    global_scale_a=l1_gsa, global_scale_b=l1_gsb,
)

print(f"Fused output shape: {fused_out.shape}")

# The output should be proportional to the input value.
# Row i has input ≈ i/128 * 0.1, so the GEMM output is proportional to i.
# Gate (cols 0-7, 16-23, ...): silu(gate) ≈ c * i
# Up (cols 8-15, 24-31, ...): silu(gate)*up ≈ 3c * i (since up=3.0)

# Check the first subtile (cols 0-7, should be gate)
# and second subtile (cols 8-15, should be up)
# For M-rows 0, 1, 2, ...
print("\nM-row | Gate (col 0) | Up (col 8) | Ratio")
for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]:
    g = fused_out[m, 0].item()
    u = fused_out[m, 8].item()
    # Guard against dividing by a (near-)zero gate value; row 0 has a zero
    # input, so its ratio is reported as inf.
    ratio = u / g if abs(g) > 0.01 else float('inf')
    print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}")

# Check if values within a subtile are uniform (same value for all 8 N-cols)
print("\nRow 0, first 16 values (2 subtiles):")
print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}")
print("Row 1, first 16 values:")
print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}")

# If values within a subtile are uniform (all 8 N-cols have the same value),
# the register layout has 8 N-cols per thread (layout a).
# If they differ across M-rows but same N-col, it's layout b.