Kept:
- example10 (CUTLASS LLM, O rescale + final normalize)
- example9 (SSA kv_coord version)
- working_softmax_maybe.py (working softmax snapshot from before the nuke)
- test_fmha_v3_stage_c.py (identity softmax baseline, n=128 cos 0.999998)
- test_fmha_v3.py (Stage A+B baseline)
- layertest.py, cudagraph_test.py (required)
- test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests)

Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha, test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w, test_dense_router, test_interleave*, test_fused_step1, test_router, test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
28 lines
1.0 KiB
Python
28 lines
1.0 KiB
Python
# tests/unit/test_dense_router.py
|
|
import torch
|
|
from dsv4.layers.router import Router
|
|
|
|
def test_dense_router_matches_spec(N=64, H=4096, E=256, k=6):
    """Check the dense-mode Router against a direct FP32 evaluation of the spec.

    Random bf16 inputs ``X`` (N, H) and ``W`` (H, E) plus a small FP32 ``bias``
    (E,) are pushed through plain tensor math to get the expected top-k ids
    and normalized weights. The Router, loaded with the same ``W`` and
    ``bias``, must return exactly the same ids and numerically close weights.

    Args:
        N: number of rows in X.
        H: contraction dimension of ``X @ W``.
        E: number of columns in W; top-k is taken over this axis.
        k: number of entries selected per row.
    """
    # A seeded generator local to this test keeps the inputs reproducible
    # (the id check below is exact, so a failure must be replayable) without
    # disturbing the global RNG state seen by other tests.
    gen = torch.Generator(device='cuda').manual_seed(0)
    X = torch.randn(N, H, generator=gen, dtype=torch.bfloat16, device='cuda')
    W = torch.randn(H, E, generator=gen, dtype=torch.bfloat16, device='cuda')
    bias = torch.randn(E, generator=gen, dtype=torch.float32, device='cuda') * 0.01
    scaling = 2.5

    # Oracle: directly compute the spec, in one expression, in FP32.
    # This is not "a PyTorch reference implementation" — it's the math.
    logits = X.float() @ W.float()
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    # The bias only influences which entries are selected; the returned
    # weights are gathered from the unbiased activations.
    score = act + bias
    ids = score.topk(k, dim=-1).indices
    w = act.gather(-1, ids)
    # Normalize each row's selected weights to sum to `scaling`.
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test:
    router = Router(H, E, k, scaling, mode='dense')
    # In-place copy_ into a leaf tensor that requires grad raises under
    # autograd; no_grad makes the load safe and is a no-op for plain buffers.
    # NOTE(review): whether W_gate / e_bias are Parameters is not visible here.
    with torch.no_grad():
        router.W_gate.copy_(W)
        router.e_bias.copy_(bias)
    out_w, out_ids = router(X, layer_idx=5)

    assert (out_ids == ids).all()  # ids must be exact match
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)
|