fix: add shape assertions to fused router test
This commit is contained in:
@@ -50,7 +50,9 @@ def test_fused_router():
|
||||
e_bias = torch.randn(N, dtype=torch.float32, device=device)
|
||||
|
||||
# === Reference path: Nvfp4Linear + activation_topk ===
|
||||
logits_ref = gate_lin(hidden_states).float() # (M, N) FP32
|
||||
logits_ref = gate_lin(hidden_states).float() # (M, N) BF16 → FP32
|
||||
print(f" logits_ref shape: {logits_ref.shape}, e_bias shape: {e_bias.shape}")
|
||||
assert logits_ref.shape[1] == N, f"logits shape {logits_ref.shape} doesn't match N={N}"
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
|
||||
ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
|
||||
|
||||
Reference in New Issue
Block a user