nvfp4-megamoe-kernel/tests/archive/test_dense_router.py
biondizzle 524f0bdfb4 Clean up: archive diagnostics and superseded tests
Kept:
- example10 (CUTLASS LLM, O rescale + final normalize)
- example9 (SSA kv_coord version)
- working_softmax_maybe.py (working softmax snapshot from before the nuke)
- test_fmha_v3_stage_c.py (identity softmax baseline, n=128 cos 0.999998)
- test_fmha_v3.py (Stage A+B baseline)
- layertest.py, cudagraph_test.py (required)
- test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests)

Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha,
test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w,
test_dense_router, test_interleave*, test_fused_step1, test_router,
test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
2026-05-23 00:17:07 +00:00

28 lines
1.0 KiB
Python

# tests/unit/test_dense_router.py
import torch

from dsv4.layers.router import Router


def test_dense_router_matches_spec(N=64, H=4096, E=256, k=6):
    X = torch.randn(N, H, dtype=torch.bfloat16, device='cuda')
    W = torch.randn(H, E, dtype=torch.bfloat16, device='cuda')
    bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01
    scaling = 2.5

    # Oracle: directly compute the spec, in one expression, in FP32.
    # This is not "a PyTorch reference implementation"; it's the math.
    logits = X.float() @ W.float()
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + bias  # bias influences expert selection only
    ids = score.topk(k, dim=-1).indices
    w = act.gather(-1, ids)  # weights come from the unbiased activation
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test:
    router = Router(H, E, k, scaling, mode='dense')
    # In-place copies are wrapped in no_grad in case W_gate / e_bias are
    # nn.Parameters that require grad.
    with torch.no_grad():
        router.W_gate.copy_(W)
        router.e_bias.copy_(bias)
    out_w, out_ids = router(X, layer_idx=5)

    assert (out_ids == ids).all()  # ids must be exact match
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)
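
The oracle in the test above can also be run without `dsv4` or a GPU. The following is a minimal CPU sketch of the same math, not the project's implementation: the helper name `dense_router_oracle` and the small shapes are assumptions chosen for illustration. It shows the two properties the test depends on: the bias only affects which experts are selected, and the returned weights are taken from the unbiased activation and sum to `scaling` per token.

```python
import torch
import torch.nn.functional as F


def dense_router_oracle(X, W, bias, k, scaling):
    """FP32 reference for the routing math (hypothetical helper name)."""
    logits = X.float() @ W.float()
    act = torch.sqrt(F.softplus(logits))         # sqrt(softplus(.)) activation
    ids = (act + bias).topk(k, dim=-1).indices   # bias affects selection only
    w = act.gather(-1, ids)                      # unbiased activation as weight
    w = w / w.sum(-1, keepdim=True) * scaling    # normalize, then scale
    return w, ids


# Small illustrative shapes (assumed; the test uses N=64, H=4096, E=256, k=6).
torch.manual_seed(0)
N, H, E, k, scaling = 8, 32, 16, 4, 2.5
X = torch.randn(N, H)
W = torch.randn(H, E)
bias = torch.randn(E) * 0.01

w, ids = dense_router_oracle(X, W, bias, k, scaling)
```

Because each row of `w` is normalized before scaling, `w.sum(-1)` equals `scaling` up to floating-point error, which makes a cheap sanity check on any kernel output before comparing individual weights.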