Files
nvfp4-megamoe-kernel/tests/test_step2_subtile_v2.py
biondizzle 6c04155167 wip: Step 2 gate/up pairing — SiLU validated, runtime conditionals blocked by CuTeDSL
SiLU in registers: PASS (0.034% error, Step 1 stable)
Gate/up subtile detection: blocked by CuTeDSL type system

CuTeDSL compiles the kernel for ALL subtile iterations at once.
Runtime conditionals (if is_gate_subtile) that affect:
- Register tensor assignment → DSLRuntimeError (type structure mismatch)
- TMA store skipping → corrupted output
- Mask blending → wrong results

Path forward: use const_expr debug flag for the BF16 side output,
or process gate/up in a separate post-GEMM kernel.
2026-05-20 03:26:20 +00:00

111 lines
4.0 KiB
Python

"""Test: Validate gate/up subtile detection (Step 2).

The fused kernel writes:
  - Gate subtiles (0, 1): SiLU applied, stored to the C tensor at positions 0, 1.
  - Up subtiles (2, 3): raw values, stored to the C tensor at positions 0, 1
    (overwriting gate), because the TMA store uses gate_subtile_idx for up
    subtiles.

For now, the output is still (M, 2 * intermediate). We compare the gate half
of the output against SiLU(gate_ref) and the up half against up_ref.
"""
import torch
import sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
from cutedsl.bridge import (
quantize_weight_to_nvfp4,
quantize_activation_nvfp4,
make_b_k_major,
assemble_scales_2d_side,
assemble_scales_3d_side,
run_nvfp4_grouped_gemm,
run_fused_swiglu_grouped_gemm,
warmup_compilation,
)
def _relative_and_max_error(actual, reference):
    """Return (relative L2 error, max absolute error) of actual vs. reference.

    The difference is taken in the tensors' own dtype and only then widened
    to float32, matching how the errors were computed before this helper
    existed, so reported numbers stay comparable across runs.
    """
    diff = (actual - reference).float()
    rel_err = (diff.norm() / reference.float().norm()).item()
    max_err = diff.abs().max().item()
    return rel_err, max_err


def test_gate_up_subtile():
    """Compare the fused SwiGLU grouped GEMM against the plain grouped GEMM.

    Random BF16 activations and per-expert L1 weights are quantized with the
    bridge helpers and run through ``run_nvfp4_grouped_gemm`` (reference) and
    ``run_fused_swiglu_grouped_gemm`` (kernel under test). The first
    ``intermediate`` output columns are compared against SiLU of the
    reference, the remaining columns against the raw reference values.

    Both comparisons are always printed; the test fails afterwards if either
    relative error is not below the tolerance.

    Raises:
        AssertionError: if the gate or up half exceeds the relative tolerance
            (a NaN relative error also counts as a failure).
    """
    device = "cuda"
    num_experts = 4
    hidden = 512
    intermediate = 256
    num_tokens = 32
    rel_tol = 0.05  # relative-error threshold applied to both halves

    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    l1_w = torch.randn(
        num_experts, 2 * intermediate, hidden,
        dtype=torch.bfloat16, device=device,
    )

    # Quantize each expert's weight matrix independently.
    l1_fp4_list, l1_sf_list, l1_gs_list = [], [], []
    for e in range(num_experts):
        w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w[e].T)
        l1_fp4_list.append(w_fp4)
        l1_sf_list.append(w_sf)
        l1_gs_list.append(w_gs)
    l1_mat_b = make_b_k_major(torch.stack(l1_fp4_list))
    l1_scale_b = assemble_scales_3d_side(l1_sf_list)
    l1_gs = torch.tensor(l1_gs_list, dtype=torch.float32, device=device)

    gs_val = x.abs().max().item() / (6.0 * 448.0)
    x_fp4, x_sf = quantize_activation_nvfp4(x, gs_val)

    tokens_per_expert = [num_tokens // num_experts] * num_experts

    # Cumulative end offset of each expert's token range. The scale slices and
    # expert_offsets are both derived from this one list so they stay aligned
    # even if the per-expert counts are made non-uniform later.
    expert_ends = []
    running_total = 0
    for count in tokens_per_expert:
        running_total += count
        expert_ends.append(running_total)

    scale_a = assemble_scales_2d_side(
        [
            x_sf[end - count:end]
            for end, count in zip(expert_ends, tokens_per_expert)
        ]
    )
    expert_offsets = torch.tensor(expert_ends, dtype=torch.int32, device=device)
    global_scale_a = torch.full(
        (num_experts,), gs_val, dtype=torch.float32, device=device
    )

    warmup_compilation(num_experts, hidden // 2, (2 * intermediate) // 2, device)

    # Standard L1 GEMM: provides the gate/up reference values.
    out_bf16 = run_nvfp4_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )
    gate_ref = out_bf16[:, :intermediate]
    up_ref = out_bf16[:, intermediate:]
    silu_gate_ref = torch.nn.functional.silu(gate_ref)

    # Fused kernel under test, fed the exact same inputs.
    print("Running fused kernel...")
    out_fused = run_fused_swiglu_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )
    print(f"Fused output: shape={out_fused.shape}, amax={out_fused.abs().amax().item():.4f}")

    # Gate is read from the first half of the columns, up from the second.
    fused_gate = out_fused[:, :intermediate]
    fused_up = out_fused[:, intermediate:]

    # Gate half should carry SiLU; up half should be the raw GEMM output.
    gate_rel_err, gate_max_err = _relative_and_max_error(fused_gate, silu_gate_ref)
    up_rel_err, up_max_err = _relative_and_max_error(fused_up, up_ref)
    gate_ok = gate_rel_err < rel_tol
    up_ok = up_rel_err < rel_tol

    print("\n=== Gate Comparison (SiLU applied) ===")
    print(f"Rel error: {gate_rel_err:.6f}")
    print(f"Max abs error: {gate_max_err:.6f}")
    print("Gate PASS" if gate_ok else "Gate FAIL")
    print("\n=== Up Comparison (raw values) ===")
    print(f"Rel error: {up_rel_err:.6f}")
    print(f"Max abs error: {up_max_err:.6f}")
    print("Up PASS" if up_ok else "Up FAIL")

    # Fail only after both reports are printed so a gate failure does not hide
    # the up diagnostics. An explicit raise (rather than a bare assert) keeps
    # the check active when the script is run with `python -O`.
    failed = [name for name, ok in (("gate", gate_ok), ("up", up_ok)) if not ok]
    if failed:
        raise AssertionError(
            f"fused output mismatch in: {', '.join(failed)} "
            f"(rel error gate={gate_rel_err:.6f}, up={up_rel_err:.6f}, "
            f"tolerance={rel_tol})"
        )
if __name__ == "__main__":
    # Script entry point: run the check directly, without a test runner.
    test_gate_up_subtile()