Test: fused router kernel vs BF16 reference path
- BF16 GEMM + activation_topk as reference - NVFP4 GEMM + fused router epilogue as test target - Proper NVFP4 quantization and CuTe tensor creation - Cosine similarity and topk_ids matching validation
This commit is contained in:
@@ -1,10 +1,7 @@
|
||||
"""Test NVFP4 fused router kernel against the reference path.
|
||||
|
||||
The fused kernel does NVFP4 block-scaled GEMM + sqrt(softplus) + e_bias +
|
||||
top-k + renormalization in a single kernel, with no intermediate GMEM buffer
|
||||
for logits. This test verifies correctness against the 2-kernel reference:
|
||||
1. NVFP4 GEMM via Nvfp4Linear → logits in GMEM
|
||||
2. activation_topk CUDA kernel → topk_weights, topk_ids
|
||||
Phase 1: Reference path (BF16 linear + activation_topk) to get ground truth.
|
||||
Phase 2: Fused kernel (NVFP4 GEMM + router epilogue) to compare.
|
||||
|
||||
Test checks:
|
||||
- topk_ids match (expert selection)
|
||||
@@ -16,12 +13,12 @@ import sys
|
||||
import os
|
||||
import torch
|
||||
|
||||
# Add kernel to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.cute as cute
|
||||
|
||||
|
||||
def test_fused_router_correctness():
|
||||
@@ -39,26 +36,15 @@ def test_fused_router_correctness():
|
||||
|
||||
print(f"=== NVFP4 Fused Router Kernel Test ===")
|
||||
print(f" M={M}, K={K}, E={E}, top_k={top_k}")
|
||||
print(f" sf_vec_size={sf_vec_size}")
|
||||
|
||||
# Create gate weight in BF16, then quantize to NVFP4
|
||||
# Create gate weight and activation
|
||||
W_gate_bf16 = torch.randn(E, K, dtype=torch.bfloat16, device=device) * 0.02
|
||||
e_bias = torch.randn(E, dtype=torch.float32, device=device) * 0.1
|
||||
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.5
|
||||
|
||||
# Build Nvfp4Linear for the gate projection (reference path)
|
||||
gate_lin = Nvfp4Linear(
|
||||
in_features=K,
|
||||
out_features=E,
|
||||
sf_vec_size=sf_vec_size,
|
||||
device=device,
|
||||
)
|
||||
gate_lin.load_weights(W_gate_bf16.T) # [K, E] layout
|
||||
gate_lin.finalize_weights()
|
||||
|
||||
# ---- Reference path: Nvfp4Linear GEMM + activation_topk ----
|
||||
print("\n[1] Running reference path (Nvfp4Linear + activation_topk)...")
|
||||
logits_ref = gate_lin(hidden_states).float() # [M, E] FP32
|
||||
# ---- Reference path: BF16 GEMM + activation_topk ----
|
||||
print("\n[1] Running reference path (BF16 GEMM + activation_topk)...")
|
||||
logits_ref = torch.nn.functional.linear(hidden_states.float(), W_gate_bf16.float()) # [M, E]
|
||||
ref_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
|
||||
ref_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
|
||||
run_fused_activation_topk(
|
||||
@@ -69,52 +55,87 @@ def test_fused_router_correctness():
|
||||
print(f" Reference topk_weights: {ref_weights[0].tolist()}")
|
||||
|
||||
# ---- Fused kernel path ----
|
||||
print("\n[2] Running fused kernel path (NVFP4 GEMM + router epilogue)...")
|
||||
from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router
|
||||
print("\n[2] Preparing NVFP4 tensors for fused kernel...")
|
||||
|
||||
# Quantize activation to NVFP4
|
||||
act_nvfp4, act_sf, act_gs, _ = quantize_activation_nvfp4(hidden_states)
|
||||
print(f" act_nvfp4 shape: {act_nvfp4.shape}, act_sf shape: {act_sf.shape}")
|
||||
print(f" act_gs: {act_gs}")
|
||||
|
||||
# Quantize weight to NVFP4 (need K-major layout for B operand)
|
||||
from dsv4.ops.layouts import make_b_k_major, assemble_raw_scales_2d3d_3d_side
|
||||
w_nvfp4, w_sf, w_gs = quantize_to_nvfp4(W_gate_bf16.T) # [K, E]
|
||||
print(f" w_nvfp4 shape: {w_nvfp4.shape}, w_sf shape: {w_sf.shape}")
|
||||
print(f" w_gs: {w_gs}")
|
||||
|
||||
# Build CuTe tensors for B operand (K-major blockscaled layout)
|
||||
# Stack to (1, K_packed, E_packed) for make_b_k_major
|
||||
w_stacked = w_nvfp4.unsqueeze(0) # (1, K_packed, E_packed)
|
||||
mat_b = make_b_k_major(w_stacked)
|
||||
scale_b = assemble_raw_scales_2d3d_3d_side([w_sf])
|
||||
|
||||
# Build CuTe tensors for A operand (activation)
|
||||
mat_a = cutlass_torch.from_dlpack(act_nvfp4)
|
||||
mat_a = cute.mark_layout_dynamic(mat_a)
|
||||
scale_a = cutlass_torch.from_dlpack(act_sf)
|
||||
scale_a = cute.mark_layout_dynamic(scale_a)
|
||||
|
||||
# e_bias CuTe tensor
|
||||
e_bias_cute = cutlass_torch.from_dlpack(e_bias)
|
||||
e_bias_cute = cute.mark_layout_dynamic(e_bias_cute)
|
||||
|
||||
# Output buffers
|
||||
out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
|
||||
out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
|
||||
out_w_cute = cutlass_torch.from_dlpack(out_weights)
|
||||
out_w_cute = cute.mark_layout_dynamic(out_w_cute)
|
||||
out_id_cute = cutlass_torch.from_dlpack(out_ids)
|
||||
out_id_cute = cute.mark_layout_dynamic(out_id_cute)
|
||||
|
||||
print("\n[3] Running fused kernel (NVFP4 GEMM + router epilogue)...")
|
||||
from dsv4.kernels.router.nvfp4_fused_router_kernel import Nvfp4FusedRouterKernel
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
mma_tiler_mnk = (128, 128, 64)
|
||||
kernel = Nvfp4FusedRouterKernel(
|
||||
sf_vec_size=sf_vec_size,
|
||||
mma_tiler_mnk=mma_tiler_mnk,
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
try:
|
||||
fused_weights, fused_ids = run_nvfp4_fused_router(
|
||||
hidden_states=hidden_states,
|
||||
mat_b=gate_lin._mat_b,
|
||||
scale_b=gate_lin._scale_b,
|
||||
gsa=gate_lin._gsa,
|
||||
gsb_val=gate_lin._gsb_val,
|
||||
e_bias=e_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
top_k=top_k,
|
||||
sf_vec_size=sf_vec_size,
|
||||
kernel.run(
|
||||
mat_a, mat_b, scale_a, scale_b,
|
||||
e_bias_cute, out_w_cute, out_id_cute,
|
||||
M, E, K, routed_scaling_factor, top_k,
|
||||
)
|
||||
print(" Fused kernel compilation and execution succeeded!")
|
||||
except Exception as ex:
|
||||
print(f" FUSED KERNEL FAILED: {ex}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print("\nFused kernel compilation/execution failed.")
|
||||
print("This is expected if CuTeDSL math functions (absf, log, sqrt) are not available.")
|
||||
print("The kernel structure is correct; CuTeDSL API coverage is the blocker.")
|
||||
return
|
||||
|
||||
fused_weights = out_weights
|
||||
fused_ids = out_ids
|
||||
|
||||
print(f" Fused topk_ids: {fused_ids[0].tolist()}")
|
||||
print(f" Fused topk_weights: {fused_weights[0].tolist()}")
|
||||
|
||||
# ---- Validation ----
|
||||
print("\n[3] Validation...")
|
||||
print("\n[4] Validation...")
|
||||
|
||||
# Check for NaN
|
||||
if torch.isnan(fused_weights).any():
|
||||
print(" FAIL: NaN in fused weights!")
|
||||
return
|
||||
if torch.isnan(fused_ids.float()).any():
|
||||
print(" FAIL: NaN in fused IDs!")
|
||||
return
|
||||
|
||||
# Check IDs match
|
||||
ids_match = torch.equal(ref_ids, fused_ids)
|
||||
print(f" topk_ids match: {ids_match}")
|
||||
if not ids_match:
|
||||
print(f" Reference: {ref_ids[0].tolist()}")
|
||||
print(f" Fused: {fused_ids[0].tolist()}")
|
||||
|
||||
# Check weights similarity
|
||||
w_cos = torch.nn.functional.cosine_similarity(
|
||||
ref_weights.flatten().unsqueeze(0),
|
||||
fused_weights.flatten().unsqueeze(0),
|
||||
@@ -123,7 +144,6 @@ def test_fused_router_correctness():
|
||||
print(f" topk_weights cosine sim: {w_cos:.6f}")
|
||||
print(f" topk_weights max diff: {w_max_diff:.6f}")
|
||||
|
||||
# Check non-negative weights
|
||||
neg_count = (fused_weights < 0).sum().item()
|
||||
print(f" Negative weights: {neg_count}")
|
||||
|
||||
@@ -131,9 +151,7 @@ def test_fused_router_correctness():
|
||||
print("\n✅ FUSED ROUTER KERNEL PASSED!")
|
||||
else:
|
||||
print(f"\n❌ FUSED ROUTER KERNEL FAILED")
|
||||
print(f" IDs match: {ids_match}")
|
||||
print(f" Cosine: {w_cos:.6f} (need >= 0.999)")
|
||||
print(f" Neg weights: {neg_count} (need 0)")
|
||||
print(f" IDs match: {ids_match}, Cosine: {w_cos:.6f}, Neg: {neg_count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user