Test: fused router kernel vs BF16 reference path

- BF16 GEMM + activation_topk as reference
- NVFP4 GEMM + fused router epilogue as test target
- Proper NVFP4 quantization and CuTe tensor creation
- Cosine similarity and topk_ids matching validation
This commit is contained in:
2026-06-01 08:54:24 +00:00
parent 2433700a69
commit b94f8d4ed8

View File

@@ -1,10 +1,7 @@
"""Test NVFP4 fused router kernel against the reference path.
The fused kernel does NVFP4 block-scaled GEMM + sqrt(softplus) + e_bias +
top-k + renormalization in a single kernel, with no intermediate GMEM buffer
for logits. This test verifies correctness against the 2-kernel reference:
1. NVFP4 GEMM via Nvfp4Linear → logits in GMEM
2. activation_topk CUDA kernel → topk_weights, topk_ids
Phase 1: Reference path (BF16 linear + activation_topk) to get ground truth.
Phase 2: Fused kernel (NVFP4 GEMM + router epilogue) to compare.
Test checks:
- topk_ids match (expert selection)
@@ -16,12 +13,12 @@ import sys
import os
import torch
# Add kernel to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from dsv4.layers.linear import Nvfp4Linear
from dsv4.ops.quantize import quantize_activation_nvfp4
from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
import cutlass.torch as cutlass_torch
import cutlass.cute as cute
def test_fused_router_correctness():
@@ -39,26 +36,15 @@ def test_fused_router_correctness():
print(f"=== NVFP4 Fused Router Kernel Test ===")
print(f" M={M}, K={K}, E={E}, top_k={top_k}")
print(f" sf_vec_size={sf_vec_size}")
# Create gate weight in BF16, then quantize to NVFP4
# Create gate weight and activation
W_gate_bf16 = torch.randn(E, K, dtype=torch.bfloat16, device=device) * 0.02
e_bias = torch.randn(E, dtype=torch.float32, device=device) * 0.1
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.5
# Build Nvfp4Linear for the gate projection (reference path)
gate_lin = Nvfp4Linear(
in_features=K,
out_features=E,
sf_vec_size=sf_vec_size,
device=device,
)
gate_lin.load_weights(W_gate_bf16.T) # [K, E] layout
gate_lin.finalize_weights()
# ---- Reference path: Nvfp4Linear GEMM + activation_topk ----
print("\n[1] Running reference path (Nvfp4Linear + activation_topk)...")
logits_ref = gate_lin(hidden_states).float() # [M, E] FP32
# ---- Reference path: BF16 GEMM + activation_topk ----
print("\n[1] Running reference path (BF16 GEMM + activation_topk)...")
logits_ref = torch.nn.functional.linear(hidden_states.float(), W_gate_bf16.float()) # [M, E]
ref_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
ref_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
run_fused_activation_topk(
@@ -69,52 +55,87 @@ def test_fused_router_correctness():
print(f" Reference topk_weights: {ref_weights[0].tolist()}")
# ---- Fused kernel path ----
print("\n[2] Running fused kernel path (NVFP4 GEMM + router epilogue)...")
from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router
print("\n[2] Preparing NVFP4 tensors for fused kernel...")
# Quantize activation to NVFP4
act_nvfp4, act_sf, act_gs, _ = quantize_activation_nvfp4(hidden_states)
print(f" act_nvfp4 shape: {act_nvfp4.shape}, act_sf shape: {act_sf.shape}")
print(f" act_gs: {act_gs}")
# Quantize weight to NVFP4 (need K-major layout for B operand)
from dsv4.ops.layouts import make_b_k_major, assemble_raw_scales_2d3d_3d_side
w_nvfp4, w_sf, w_gs = quantize_to_nvfp4(W_gate_bf16.T) # [K, E]
print(f" w_nvfp4 shape: {w_nvfp4.shape}, w_sf shape: {w_sf.shape}")
print(f" w_gs: {w_gs}")
# Build CuTe tensors for B operand (K-major blockscaled layout)
# Stack to (1, K_packed, E_packed) for make_b_k_major
w_stacked = w_nvfp4.unsqueeze(0) # (1, K_packed, E_packed)
mat_b = make_b_k_major(w_stacked)
scale_b = assemble_raw_scales_2d3d_3d_side([w_sf])
# Build CuTe tensors for A operand (activation)
mat_a = cutlass_torch.from_dlpack(act_nvfp4)
mat_a = cute.mark_layout_dynamic(mat_a)
scale_a = cutlass_torch.from_dlpack(act_sf)
scale_a = cute.mark_layout_dynamic(scale_a)
# e_bias CuTe tensor
e_bias_cute = cutlass_torch.from_dlpack(e_bias)
e_bias_cute = cute.mark_layout_dynamic(e_bias_cute)
# Output buffers
out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
out_w_cute = cutlass_torch.from_dlpack(out_weights)
out_w_cute = cute.mark_layout_dynamic(out_w_cute)
out_id_cute = cutlass_torch.from_dlpack(out_ids)
out_id_cute = cute.mark_layout_dynamic(out_id_cute)
print("\n[3] Running fused kernel (NVFP4 GEMM + router epilogue)...")
from dsv4.kernels.router.nvfp4_fused_router_kernel import Nvfp4FusedRouterKernel
import cuda.bindings.driver as cuda
mma_tiler_mnk = (128, 128, 64)
kernel = Nvfp4FusedRouterKernel(
sf_vec_size=sf_vec_size,
mma_tiler_mnk=mma_tiler_mnk,
cluster_shape_mnk=(1, 1, 1),
top_k=top_k,
)
try:
fused_weights, fused_ids = run_nvfp4_fused_router(
hidden_states=hidden_states,
mat_b=gate_lin._mat_b,
scale_b=gate_lin._scale_b,
gsa=gate_lin._gsa,
gsb_val=gate_lin._gsb_val,
e_bias=e_bias,
routed_scaling_factor=routed_scaling_factor,
top_k=top_k,
sf_vec_size=sf_vec_size,
kernel.run(
mat_a, mat_b, scale_a, scale_b,
e_bias_cute, out_w_cute, out_id_cute,
M, E, K, routed_scaling_factor, top_k,
)
print(" Fused kernel compilation and execution succeeded!")
except Exception as ex:
print(f" FUSED KERNEL FAILED: {ex}")
import traceback
traceback.print_exc()
print("\nFused kernel compilation/execution failed.")
print("This is expected if CuTeDSL math functions (absf, log, sqrt) are not available.")
print("The kernel structure is correct; CuTeDSL API coverage is the blocker.")
return
fused_weights = out_weights
fused_ids = out_ids
print(f" Fused topk_ids: {fused_ids[0].tolist()}")
print(f" Fused topk_weights: {fused_weights[0].tolist()}")
# ---- Validation ----
print("\n[3] Validation...")
print("\n[4] Validation...")
# Check for NaN
if torch.isnan(fused_weights).any():
print(" FAIL: NaN in fused weights!")
return
if torch.isnan(fused_ids.float()).any():
print(" FAIL: NaN in fused IDs!")
return
# Check IDs match
ids_match = torch.equal(ref_ids, fused_ids)
print(f" topk_ids match: {ids_match}")
if not ids_match:
print(f" Reference: {ref_ids[0].tolist()}")
print(f" Fused: {fused_ids[0].tolist()}")
# Check weights similarity
w_cos = torch.nn.functional.cosine_similarity(
ref_weights.flatten().unsqueeze(0),
fused_weights.flatten().unsqueeze(0),
@@ -123,7 +144,6 @@ def test_fused_router_correctness():
print(f" topk_weights cosine sim: {w_cos:.6f}")
print(f" topk_weights max diff: {w_max_diff:.6f}")
# Check non-negative weights
neg_count = (fused_weights < 0).sum().item()
print(f" Negative weights: {neg_count}")
@@ -131,9 +151,7 @@ def test_fused_router_correctness():
print("\n✅ FUSED ROUTER KERNEL PASSED!")
else:
print(f"\n❌ FUSED ROUTER KERNEL FAILED")
print(f" IDs match: {ids_match}")
print(f" Cosine: {w_cos:.6f} (need >= 0.999)")
print(f" Neg weights: {neg_count} (need 0)")
print(f" IDs match: {ids_match}, Cosine: {w_cos:.6f}, Neg: {neg_count}")
if __name__ == "__main__":