test: fused SwiGLU kernel compilation + correctness (P0/P1 gate)
This commit is contained in:
161
tests/unit/test_fused_swiglu_kernel.py
Normal file
161
tests/unit/test_fused_swiglu_kernel.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test fused SwiGLU NVFP4 GEMM kernel compilation and correctness.
|
||||
|
||||
Validates P0/P1 from PERFORMANCE_AUDIT.md:
|
||||
- Fused SwiGLU kernel compiles via cute.compile
|
||||
- Output cosine similarity vs unfused path >= 0.9995
|
||||
- Tests both multi-expert (MoE) and single-expert (SharedExpert) modes
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
|
||||
def test_fused_swiglu_compilation():
|
||||
"""Test that the fused SwiGLU kernel compiles and runs."""
|
||||
from dsv4.ops.gemm_runner import (
|
||||
warmup_fused_swiglu_compilation,
|
||||
warmup_compilation,
|
||||
run_nvfp4_grouped_gemm,
|
||||
run_fused_swiglu_grouped_gemm,
|
||||
)
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4, SF_VEC_SIZE
|
||||
from dsv4.ops.layouts import make_b_k_major, interleave_l1_weights
|
||||
|
||||
device = "cuda:0"
|
||||
# Production MoE shapes (DeepSeek-V4 Pro L1 GEMM)
|
||||
# L1: K=7168, N=6144 (gate+up combined) → K_packed=3584, N_packed=3072
|
||||
K_packed = 3584
|
||||
N_packed = 3072
|
||||
num_experts = 4 # Small for testing, but >1 for MoE path
|
||||
swiglu_limit = 10.0
|
||||
|
||||
print(f"Testing fused SwiGLU kernel compilation...")
|
||||
print(f" K_packed={K_packed}, N_packed={N_packed}, E={num_experts}, limit={swiglu_limit}")
|
||||
|
||||
# Warmup standard GEMM first
|
||||
print(" Warming up standard GEMM...")
|
||||
warmup_compilation(num_experts, K_packed, N_packed, device)
|
||||
|
||||
# Warmup fused GEMM
|
||||
print(" Warming up fused SwiGLU GEMM...")
|
||||
try:
|
||||
warmup_fused_swiglu_compilation(
|
||||
num_experts, K_packed, N_packed, device,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
print(" ✅ Fused SwiGLU kernel compiled successfully!")
|
||||
except TypeError as e:
|
||||
print(f" ❌ Fused SwiGLU compilation FAILED with TypeError: {e}")
|
||||
print(f" This is the arg-binding bug from the previous session.")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f" ❌ Fused SwiGLU compilation FAILED: {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
# Now test correctness: run both fused and unfused, compare
|
||||
print("\n Testing fused vs unfused output correctness...")
|
||||
tokens = 6 # top-k=6
|
||||
K = K_packed * 2 # 7168
|
||||
N = N_packed * 2 # 6144
|
||||
|
||||
# Create random input
|
||||
x_bf16 = torch.randn(tokens, K, dtype=torch.bfloat16, device=device) * 0.5
|
||||
|
||||
# Create random weight (same for both paths)
|
||||
w_bf16 = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=device) * 0.1
|
||||
|
||||
# Quantize activation
|
||||
x_fp4, x_sf, x_gs = quantize_to_nvfp4(x_bf16)
|
||||
|
||||
# Quantize weight (interleaved for L1 gate+up)
|
||||
w_bf16_t = w_bf16.permute(0, 2, 1).contiguous() # (E, N, K) for make_b_k_major
|
||||
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t)
|
||||
w_fp4_il = interleave_l1_weights(w_fp4.unsqueeze(0)).squeeze(0) # interleave for SwiGLU
|
||||
mat_b = make_b_k_major(w_fp4_il)
|
||||
|
||||
# Expert offsets (all tokens go to expert 0 for simplicity)
|
||||
expert_offsets = torch.tensor([0, tokens], dtype=torch.int32, device=device)
|
||||
padded_offsets = torch.tensor([128], dtype=torch.int32, device=device) # padded to 128
|
||||
|
||||
# Pad activation to 128 rows
|
||||
x_padded = torch.zeros(128, K_packed, dtype=x_fp4.dtype, device=device)
|
||||
x_padded[:tokens] = x_fp4
|
||||
|
||||
# Assemble scales (simplified — just pad + swizzle)
|
||||
from dsv4.ops.layouts import pad_and_swizzle_single, ceil_div as cutedsl_ceil_div
|
||||
K_sf = cutedsl_ceil_div(K, 16)
|
||||
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
|
||||
scale_a_buf = torch.zeros(128, padded_cols, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)
|
||||
scale_a_buf[:tokens, :x_sf.shape[1]] = x_sf
|
||||
scale_a = pad_and_swizzle_single(scale_a_buf).reshape(128, padded_cols)
|
||||
|
||||
from dsv4.ops.layouts import assemble_scales_3d_side
|
||||
scale_b = assemble_scales_3d_side(w_sf)
|
||||
|
||||
global_scale_a = torch.full((num_experts,), x_gs, dtype=torch.float32, device=device)
|
||||
global_scale_b = torch.tensor(w_gs, dtype=torch.float32, device=device)
|
||||
|
||||
# Run UNFUSED path
|
||||
print(" Running unfused GEMM...")
|
||||
l1_unfused = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_padded, mat_b=mat_b,
|
||||
scale_a=scale_a, scale_b=scale_b,
|
||||
expert_offsets=padded_offsets,
|
||||
global_scale_a=global_scale_a, global_scale_b=global_scale_b,
|
||||
)[:tokens] # (6, 6144) BF16
|
||||
|
||||
# Manual SwiGLU on unfused output
|
||||
intermediate = N // 2 # 3072
|
||||
l1_deil = interleave_l1_weights(l1_unfused.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :intermediate]
|
||||
up = l1_deil[:, intermediate:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
gate_silu = gate_silu.clamp(max=swiglu_limit)
|
||||
up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
|
||||
activated_unfused = gate_silu * up
|
||||
|
||||
# Run FUSED path
|
||||
print(" Running fused SwiGLU GEMM...")
|
||||
try:
|
||||
l1_fused = run_fused_swiglu_grouped_gemm(
|
||||
mat_a=x_padded, mat_b=mat_b,
|
||||
scale_a=scale_a, scale_b=scale_b,
|
||||
expert_offsets=padded_offsets,
|
||||
global_scale_a=global_scale_a, global_scale_b=global_scale_b,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)[:tokens] # (6, 3072) BF16 — SwiGLU already applied
|
||||
print(" ✅ Fused SwiGLU GEMM ran successfully!")
|
||||
except Exception as e:
|
||||
print(f" ❌ Fused SwiGLU GEMM FAILED: {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
# Compare
|
||||
# The fused kernel outputs only the silu(gate)*up result (N/2 = 3072)
|
||||
# The unfused path's activated_unfused is the same computation in Python
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
l1_fused.flatten().float(), activated_unfused.flatten().float(), dim=0
|
||||
).item()
|
||||
max_diff = (l1_fused.float() - activated_unfused.float()).abs().max().item()
|
||||
print(f"\n Fused vs Unfused SwiGLU output:")
|
||||
print(f" Cosine similarity: {cos:.6f}")
|
||||
print(f" Max abs diff: {max_diff:.6f}")
|
||||
print(f" |fused|: {l1_fused.abs().max().item():.4f}")
|
||||
print(f" |unfused|: {activated_unfused.abs().max().item():.4f}")
|
||||
|
||||
if cos >= 0.9995:
|
||||
print(f" ✅ PASS: cosine >= 0.9995")
|
||||
else:
|
||||
print(f" ❌ FAIL: cosine < 0.9995")
|
||||
|
||||
# Test single-expert mode (for SharedExpert P1)
|
||||
print("\n--- Testing single-expert (SharedExpert) mode ---")
|
||||
warmup_fused_swiglu_compilation(
|
||||
1, K_packed, N_packed, device, swiglu_limit=swiglu_limit
|
||||
)
|
||||
print(" ✅ Single-expert fused SwiGLU compiled!")
|
||||
|
||||
return cos >= 0.9995
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(42)
|
||||
success = test_fused_swiglu_compilation()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user