Files
nvfp4-megamoe-kernel/tests/archive/test_fused_step1.py
biondizzle 524f0bdfb4 Clean up: archive diagnostics and superseded tests
Kept:
- example10 (CUTLASS LLM, O rescale + final normalize)
- example9 (SSA kv_coord version)
- working_softmax_maybe.py (working softmax snapshot from before the nuke)
- test_fmha_v3_stage_c.py (identity-softmax baseline; n=128, cosine similarity 0.999998)
- test_fmha_v3.py (Stage A+B baseline)
- layertest.py, cudagraph_test.py (required)
- test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests)

Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha,
test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w,
test_dense_router, test_interleave*, test_fused_step1, test_router,
test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
2026-05-23 00:17:07 +00:00

93 lines
3.3 KiB
Python

"""Test: Validate SiLU in registers (Step 1 of fused SwiGLU).
Compiles the fused kernel with fused_swiglu=True, runs it, and compares
the BF16 output with PyTorch SiLU applied to the standard L1 GEMM output.
"""
import torch
import sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
from dsv4.ops.quantize import (
quantize_weight_to_nvfp4,
quantize_activation_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_2d_side,
assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
run_fused_swiglu_grouped_gemm,
warmup_compilation,
)
def test_silu_step1():
    """Validate the fused kernel's in-register SiLU (step 1 of fused SwiGLU).

    Steps:
    1. Quantize a random activation and per-expert L1 weights to NVFP4.
    2. Run the standard grouped L1 GEMM and apply
       ``torch.nn.functional.silu`` to its output to form the reference.
    3. Run ``run_fused_swiglu_grouped_gemm`` on the same quantized operands.
    4. Compare the fused output with the reference by relative L2 error.

    Diagnostics are printed. The outcome is also asserted, so pytest reports a
    failure instead of silently passing.

    Raises:
        AssertionError: if the fused and reference outputs differ in shape, or
            their relative L2 error is not below ``tolerance``. A NaN or
            infinite error also fails, because the comparison is False.
    """
    import itertools

    device = "cuda"
    num_experts = 4
    hidden = 512
    intermediate = 256
    num_tokens = 32
    # Relative-L2 error budget, allowing for NVFP4 quantization noise.
    tolerance = 0.1

    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    l1_w = torch.randn(
        num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16, device=device
    )

    # Quantize each expert's L1 weight separately. The quantizer receives the
    # transposed (hidden, 2 * intermediate) view of that expert's weight.
    l1_fp4_list, l1_sf_list, l1_gs_list = [], [], []
    for e in range(num_experts):
        w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w[e].T)
        l1_fp4_list.append(w_fp4)
        l1_sf_list.append(w_sf)
        l1_gs_list.append(w_gs)
    l1_mat_b = make_b_k_major(torch.stack(l1_fp4_list))
    l1_scale_b = assemble_scales_3d_side(l1_sf_list)
    l1_gs = torch.tensor(l1_gs_list, dtype=torch.float32, device=device)

    # Per-tensor activation global scale: amax / (6.0 * 448.0).
    # 6.0 is the largest FP4 E2M1 magnitude; 448.0 is the largest FP8 E4M3
    # magnitude.
    gs_val = x.abs().max().item() / (6.0 * 448.0)
    x_fp4, x_sf = quantize_activation_nvfp4(x, gs_val)

    # Tokens are split evenly across experts (num_tokens % num_experts == 0 here).
    tokens_per_expert = [num_tokens // num_experts] * num_experts
    # Inclusive cumulative token counts: ends[e] is one past expert e's last
    # token. These are also the offsets handed to the kernel.
    ends = list(itertools.accumulate(tokens_per_expert))
    starts = [0, *ends[:-1]]
    # Slices x_sf per expert by leading-dim rows. This assumes x_sf's leading
    # dim indexes tokens — TODO confirm against quantize_activation_nvfp4.
    scale_a = assemble_scales_2d_side(
        [x_sf[start:end] for start, end in zip(starts, ends)]
    )
    expert_offsets = torch.tensor(ends, dtype=torch.int32, device=device)
    global_scale_a = torch.full(
        (num_experts,), gs_val, dtype=torch.float32, device=device
    )

    # Dimensions are passed halved, presumably packed FP4 sizes (two values per
    # byte) — TODO confirm against warmup_compilation's signature.
    warmup_compilation(num_experts, hidden // 2, (2 * intermediate) // 2, device)

    # 1. Standard L1 GEMM (no SiLU), used to build the reference.
    out_bf16 = run_nvfp4_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )
    # Evaluate the reference SiLU in float32 so it adds no extra BF16 rounding.
    silu_ref = torch.nn.functional.silu(out_bf16.float())
    print(f"Standard L1 output: shape={out_bf16.shape}, amax={out_bf16.abs().amax().item():.4f}")
    print(f"PyTorch SiLU ref: amax={silu_ref.abs().amax().item():.4f}")

    # 2. Fused kernel with SiLU in registers.
    print("\nCompiling fused kernel (first time, may take a while)...")
    out_fused = run_fused_swiglu_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )
    print(f"Fused SiLU output: shape={out_fused.shape}, amax={out_fused.abs().amax().item():.4f}")

    # 3. Compare. Require identical shapes so broadcasting cannot mask a
    # mismatch, and subtract in float32 so the difference itself is not
    # rounded to BF16.
    assert out_fused.shape == silu_ref.shape, (
        f"shape mismatch: fused {tuple(out_fused.shape)} "
        f"vs reference {tuple(silu_ref.shape)}"
    )
    diff = out_fused.float() - silu_ref
    rel_err = (diff.norm() / silu_ref.norm()).item()
    max_err = diff.abs().max().item()
    print("\n=== Results ===")
    print(f"Relative error: {rel_err:.6f}")
    print(f"Max abs error: {max_err:.6f}")
    passed = rel_err < tolerance
    print("PASS" if passed else f"FAIL (tolerance: {tolerance} for NVFP4 quant noise)")
    # Fail under pytest too, not just in the printed summary.
    assert passed, f"relative error {rel_err:.6f} is not below tolerance {tolerance}"
if __name__ == "__main__":
    # Executed directly as a script rather than collected by pytest:
    # run the single validation.
    test_silu_step1()