From 262cec262d78947d3d8fb2f7d8030a14eeff1fda Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 06:51:47 +0000 Subject: [PATCH] fix: add shape assertions to fused router test --- tests/unit/test_fused_router.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_fused_router.py b/tests/unit/test_fused_router.py index 5b5906a7..7862e351 100644 --- a/tests/unit/test_fused_router.py +++ b/tests/unit/test_fused_router.py @@ -50,7 +50,9 @@ def test_fused_router(): e_bias = torch.randn(N, dtype=torch.float32, device=device) # === Reference path: Nvfp4Linear + activation_topk === - logits_ref = gate_lin(hidden_states).float() # (M, N) FP32 + logits_ref = gate_lin(hidden_states).float() # (M, N) BF16 → FP32 + print(f" logits_ref shape: {logits_ref.shape}, e_bias shape: {e_bias.shape}") + assert logits_ref.shape[1] == N, f"logits shape {logits_ref.shape} doesn't match N={N}" from dsv4.kernels.router._activation_topk import run_fused_activation_topk ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device) ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)