fix: add shape assertions to fused router test

This commit is contained in:
2026-06-01 06:51:47 +00:00
parent db07d17a62
commit 262cec262d

View File

@@ -50,7 +50,9 @@ def test_fused_router():
e_bias = torch.randn(N, dtype=torch.float32, device=device)
# === Reference path: Nvfp4Linear + activation_topk ===
logits_ref = gate_lin(hidden_states).float() # (M, N) FP32
logits_ref = gate_lin(hidden_states).float() # (M, N) BF16 → FP32
print(f" logits_ref shape: {logits_ref.shape}, e_bias shape: {e_bias.shape}")
assert logits_ref.shape[1] == N, f"logits shape {logits_ref.shape} doesn't match N={N}"
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)