test: streamline layertest — kernel vs BF16 ref only, exit on fail

Removed original checkpoint loading (already verified 0.997 cosine).
Test now: load NVFP4 → dequant BF16 ref → run kernel → compare.
Exits with code 1 if cosine < 0.99.
This commit is contained in:
2026-05-16 02:29:41 +00:00
parent de9b50cbe7
commit c4a262bd54

View File

@@ -1,10 +1,9 @@
#!/usr/bin/env python3
"""
Layer 0 comparison test: original checkpoint vs NVFP4 checkpoint + our kernel.
Layer 0 kernel comparison test: NVFP4 kernel vs BF16 reference.
Loads layer 0 expert weights from both checkpoints, runs the same deterministic
MoE forward pass, and compares the results. No vLLM, no Docker, no tensor
parallelism — just raw weights + our GEMM kernel.
No vLLM, no Docker, no tensor parallelism. Just raw weights + our kernel.
If cosine < 0.99, the test exits with error.
Usage:
python3 layertest.py
@@ -19,12 +18,12 @@ from safetensors import safe_open
# ── Constants ──────────────────────────────────────────────────────────
ORIG_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro"
NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
LAYER_IDX = 0
DEVICE = "cuda"
COSINE_THRESHOLD = 0.99
# E2M1 FP4 lookup table (shared by both formats)
# E2M1 FP4 lookup table
E2M1_LUT = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
@@ -43,7 +42,6 @@ def find_shards(model_dir):
for key, shard in index["weight_map"].items():
key_to_shard[key] = os.path.join(model_dir, shard)
else:
# Single shard
for sf in glob.glob(os.path.join(model_dir, "*.safetensors")):
with safe_open(sf, framework="pt") as f:
for key in f.keys():
@@ -53,18 +51,12 @@ def find_shards(model_dir):
def load_layer_tensors(model_dir, layer_idx):
"""Load all tensors for a specific layer from the checkpoint.
Returns dict of {key: tensor} for all keys matching the layer.
Keys are normalized to NOT have 'model.' prefix.
"""
"""Load all tensors for a specific layer. Keys normalized (no 'model.' prefix)."""
key_to_shard = find_shards(model_dir)
layer_prefix = f"layers.{layer_idx}."
# Group by shard to minimize file opens
shard_to_keys = {}
for key, shard in key_to_shard.items():
# Normalize: strip 'model.' prefix if present
norm_key = key.removeprefix("model.")
if not norm_key.startswith(layer_prefix):
continue
@@ -79,83 +71,21 @@ def load_layer_tensors(model_dir, layer_idx):
return tensors
def print_layer_keys(tensors, label):
"""Print sorted tensor keys with shapes and dtypes."""
print(f"\n{'='*70}")
print(f" {label}{len(tensors)} tensors")
print(f"{'='*70}")
for key in sorted(tensors.keys()):
def print_layer_keys(tensors, label, max_keys=20):
"""Print sorted tensor keys with shapes and dtypes (first N)."""
print(f"\n {label}{len(tensors)} tensors")
sorted_keys = sorted(tensors.keys())
for key in sorted_keys[:max_keys]:
t = tensors[key]
print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}")
print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}")
if len(sorted_keys) > max_keys:
print(f" ... ({len(sorted_keys) - max_keys} more)")
# ── Dequantization: Original checkpoint (MXFP4) ───────────────────────
def dequantize_mxfp4_weight(packed_uint8, scale_e8m0):
"""Dequantize MXFP4 (E2M1 + E8M0, block_size=32) to BF16.
Original checkpoint format:
packed_uint8: (out_features, in_features//2) uint8
scale_e8m0: (out_features, in_features//32) float8_e8m0fnu
"""
device = packed_uint8.device
lut = E2M1_LUT.to(device)
lower = lut[(packed_uint8 & 0x0F).long()]
upper = lut[((packed_uint8 >> 4) & 0x0F).long()]
out_features = packed_uint8.shape[0]
in_features = packed_uint8.shape[1] * 2
unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device)
unpacked[:, 0::2] = lower
unpacked[:, 1::2] = upper
# E8M0 → float32: exponent-only format, represents 2^(x - bias)
scale_f32 = scale_e8m0.float()
scale_expanded = scale_f32.repeat_interleave(32, dim=1)[:, :in_features]
return (unpacked * scale_expanded).to(torch.bfloat16)
def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices):
"""Dequantize expert weights from original MXFP4 checkpoint.
Original checkpoint key format: layers.{L}.ffn.experts.{E}.{w1/w2/w3}.{weight/scale}
w1 = gate_proj, w3 = up_proj, w2 = down_proj
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
"""
experts = {}
for e in expert_indices:
expert = {}
for proj, shard in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")]:
weight_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.weight"
scale_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.scale"
if weight_key not in orig_tensors:
if proj == "down_proj" and e == 211:
continue
raise KeyError(f"Missing {weight_key}")
weight = orig_tensors[weight_key].to(DEVICE)
scale = orig_tensors[scale_key].to(DEVICE)
expert[proj] = dequantize_mxfp4_weight(weight, scale)
experts[e] = expert
return experts
# ── Dequantization: NVFP4 checkpoint ──────────────────────────────────
# ── NVFP4 Dequantization ──────────────────────────────────────────────
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
"""Dequantize NVFP4 (E2M1 + E4M3 block scale + float32 global) to BF16.
NVFP4 checkpoint format:
packed_uint8: (out_features, in_features//2) uint8
scale_e4m3: (out_features, in_features//16) float8_e4m3fn
global_scale: float32 scalar
"""
"""Dequantize NVFP4 (E2M1 + E4M3 + global) to BF16."""
device = packed_uint8.device
lut = E2M1_LUT.to(device)
@@ -169,20 +99,14 @@ def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
unpacked[:, 0::2] = lower
unpacked[:, 1::2] = upper
block_scale = scale_e4m3.float() # float8_e4m3fn → float32
block_scale = scale_e4m3.float()
block_expanded = block_scale.repeat_interleave(16, dim=1)[:, :in_features]
# Weight dequant = e2m1 * block_scale * global_scale
return (unpacked * block_expanded * global_scale).to(torch.bfloat16)
def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
"""Dequantize expert weights from NVFP4 checkpoint.
NVFP4 checkpoint key format: layers.{L}.mlp.experts.{E}.{gate_proj/up_proj/down_proj}.{weight/weight_scale/weight_scale_2}
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
"""
"""Dequantize expert weights from NVFP4 checkpoint → BF16."""
experts = {}
for e in expert_indices:
expert = {}
@@ -205,20 +129,10 @@ def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
return experts
# ── MoE Forward Pass (BF16 reference) ─────────────────────────────────
# ── BF16 MoE Forward ───────────────────────────────────────────────────
def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
"""Run MoE forward pass in pure BF16.
Args:
hidden_states: (num_tokens, hidden_size) BF16
experts: dict {expert_id: {gate_proj, up_proj, down_proj}} BF16
expert_ids: (num_tokens, top_k) int — which expert per token per slot
expert_weights: (num_tokens, top_k) float32 — routing weights
Returns:
output: (num_tokens, hidden_size) BF16
"""
"""Run MoE forward pass in pure BF16 (torch.matmul)."""
num_tokens, hidden_size = hidden_states.shape
top_k = expert_ids.shape[1]
output = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE)
@@ -231,92 +145,66 @@ def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
if e not in experts:
continue
x = hidden_states[t] # (hidden_size,)
gate = x @ experts[e]["gate_proj"].T # (intermediate//2,)
up = x @ experts[e]["up_proj"].T # (intermediate//2,)
activated = torch.nn.functional.silu(gate) * up # (intermediate//2,)
x = hidden_states[t]
gate = x @ experts[e]["gate_proj"].T
up = x @ experts[e]["up_proj"].T
activated = torch.nn.functional.silu(gate) * up
if "down_proj" in experts[e]:
y = activated @ experts[e]["down_proj"].T # (hidden_size,)
y = activated @ experts[e]["down_proj"].T
else:
y = activated[:hidden_size] # shared expert, no down_proj
y = activated[:hidden_size]
output[t] += w * y
return output
# ── MoE Forward Pass (NVFP4 kernel) ───────────────────────────────────
# ── NVFP4 Kernel MoE Forward ──────────────────────────────────────────
def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights):
"""Run MoE forward pass using our NVFP4 kernel.
Loads weights directly from NVFP4 checkpoint (no vLLM), transforms them
for CUTLASS, and runs the grouped GEMM.
"""
"""Run MoE forward pass using our NVFP4 kernel."""
from nvfp4_megamoe_kernel import (
stage_activation,
nvfp4_mega_moe_full,
transform_nvfp4_weights_for_mega_moe,
SymmBuffer,
get_symm_buffer_for_nvfp4_mega_moe,
)
num_tokens, hidden_size = hidden_states.shape
top_k = expert_ids.shape[1]
# Collect the experts we need
unique_experts = sorted(set(expert_ids.flatten().tolist()))
num_experts = len(unique_experts)
expert_map = {e: i for i, e in enumerate(unique_experts)}
# Load NVFP4 weights for these experts
# Shapes: gate_proj.weight = (3072, 3584) uint8, weight_scale = (3072, 448) float8_e4m3fn
intermediate_half = 3072 # intermediate_size // 2
intermediate_half = 3072
hidden_half = hidden_size // 2
l1_weights = [] # gate + up fused
l1_scales = []
l1_global_scales = []
l2_weights = [] # down
l2_scales = []
l2_global_scales = []
l1_weights, l1_scales, l1_global_scales = [], [], []
l2_weights, l2_scales, l2_global_scales = [], [], []
for e in unique_experts:
# L1: gate_proj + up_proj fused
gate_w_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"
gate_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"
gate_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"
up_w_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"
up_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"
up_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"
gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].view(torch.int8).to(DEVICE)
gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE)
gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item()
up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].view(torch.int8).to(DEVICE)
up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE)
up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item()
gate_w = nvfp4_tensors[gate_w_key].view(torch.int8).to(DEVICE)
gate_sf = nvfp4_tensors[gate_sf_key].to(DEVICE)
gate_gs = nvfp4_tensors[gate_gs_key].item()
up_w = nvfp4_tensors[up_w_key].view(torch.int8).to(DEVICE)
up_sf = nvfp4_tensors[up_sf_key].to(DEVICE)
up_gs = nvfp4_tensors[up_gs_key].item()
# Fuse gate + up: stack along dim 0 → (2*3072, 3584)
l1_w = torch.cat([gate_w, up_w], dim=0)
l1_sf = torch.cat([gate_sf, up_sf], dim=0)
l1_gs = torch.tensor([gate_gs, up_gs], dtype=torch.float32, device=DEVICE)
l1_weights.append(l1_w)
l1_scales.append(l1_sf)
l1_global_scales.append(l1_gs)
# L2: down_proj
down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
if down_w_key in nvfp4_tensors:
down_w = nvfp4_tensors[down_w_key].view(torch.int8).to(DEVICE)
down_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"
down_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"
down_sf = nvfp4_tensors[down_sf_key].to(DEVICE)
down_gs = nvfp4_tensors[down_gs_key].item()
down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
else:
# Expert 211 has no down_proj — use zeros
down_w = torch.zeros(hidden_size, intermediate_half, dtype=torch.int8, device=DEVICE)
down_sf = torch.ones(hidden_size, intermediate_half // 16, dtype=torch.float8_e4m3fn, device=DEVICE)
down_gs = 1.0
@@ -325,24 +213,19 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
l2_scales.append(down_sf)
l2_global_scales.append(torch.tensor([down_gs], dtype=torch.float32, device=DEVICE))
# Stack into (num_experts, ...) tensors
l1_w = torch.stack(l1_weights) # (E, 2*3072, 3584) int8
l1_sf = torch.stack(l1_scales) # (E, 2*3072, 448) float8_e4m3fn
l1_gs = torch.stack(l1_global_scales) # (E, 2) float32
l2_w = torch.stack(l2_weights) # (E, hidden, intermediate_half) int8
l2_sf = torch.stack(l2_scales) # (E, hidden, intermediate_half//16) float8_e4m3fn
l2_gs = torch.stack(l2_global_scales) # (E, 1) float32
l1_w = torch.stack(l1_weights)
l1_sf = torch.stack(l1_scales)
l1_gs = torch.stack(l1_global_scales)
l2_w = torch.stack(l2_weights)
l2_sf = torch.stack(l2_scales)
l2_gs = torch.stack(l2_global_scales)
# Transform weights for CUTLASS
(l1_w, l1_sf, l1_global_sf), (l2_w, l2_sf, l2_global_sf) = \
transform_nvfp4_weights_for_mega_moe(
(l1_w, l1_sf),
(l2_w, l2_sf),
l1_weight_scale_2=l1_gs,
l2_weight_scale_2=l2_gs,
(l1_w, l1_sf), (l2_w, l2_sf),
l1_weight_scale_2=l1_gs, l2_weight_scale_2=l2_gs,
)
# Build slot mapping: each (token, top_k) pair → slot
num_slots = num_tokens * top_k
slot_expert = torch.zeros(num_slots, dtype=torch.int32, device=DEVICE)
slot_token = torch.zeros(num_slots, dtype=torch.int64, device=DEVICE)
@@ -351,22 +234,15 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
for t in range(num_tokens):
for k in range(top_k):
slot = t * top_k + k
e = expert_ids[t, k].item()
slot_expert[slot] = expert_map[e]
slot_expert[slot] = expert_map[expert_ids[t, k].item()]
slot_token[slot] = t
slot_weight[slot] = expert_weights[t, k].item()
# SymmBuffer
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
group=None, # no EP
num_experts=num_experts,
max_num_tokens=num_tokens,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=6144, # 2 * 3072
group=None, num_experts=num_experts, max_num_tokens=num_tokens,
top_k=top_k, hidden_size=hidden_size, intermediate_size=6144,
)
# Stage activation
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
symm_buffer.x[:num_tokens].copy_(x_fp4)
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
@@ -375,7 +251,6 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
symm_buffer.topk_weights[:num_tokens].copy_(expert_weights)
symm_buffer.experts_start_idx = 0
# Run
y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE)
nvfp4_mega_moe_full(
y,
@@ -391,145 +266,71 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
def main():
torch.manual_seed(42)
expert_indices = [0, 1, 2] # Test with 3 experts
expert_indices = [0, 1, 2]
top_k = 2
num_tokens = 4
# ── Step 1: Load original checkpoint layer 0 ──
print("\n" + "="*70)
print(" STEP 1: Loading original MXFP4 checkpoint")
print("="*70)
orig_tensors = load_layer_tensors(ORIG_MODEL_DIR, LAYER_IDX)
print_layer_keys(orig_tensors, "Original checkpoint (MXFP4)")
# Dequantize to BF16
print("\nDequantizing MXFP4 → BF16...")
orig_experts_bf16 = dequantize_mxfp4_experts(orig_tensors, LAYER_IDX, expert_indices)
for e in expert_indices:
for proj, w in orig_experts_bf16[e].items():
print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
# ── Step 2: Run BF16 reference forward pass ──
print("\n" + "="*70)
print(" STEP 2: BF16 reference forward pass")
print("="*70)
hidden_size = 7168
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
# Deterministic routing: each token picks experts 0,1
expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE)
expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)
ref_output = moe_forward_bf16(hidden_states, orig_experts_bf16, expert_ids, expert_weights)
print(f" Reference output: shape={tuple(ref_output.shape)} amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
print(f" First token first 10: {ref_output[0, :10].tolist()}")
del orig_tensors, orig_experts_bf16 # Free memory
torch.cuda.empty_cache()
# ── Step 3: Load NVFP4 checkpoint layer 0 ──
print("\n" + "="*70)
print(" STEP 3: Loading NVFP4 checkpoint")
print("="*70)
# ── Load NVFP4 checkpoint ──
print("=" * 70)
print(" Loading NVFP4 checkpoint layer 0")
print("=" * 70)
nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX)
print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint")
print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint", max_keys=5)
# Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu)
# Verify weight_scale dtype
for e in expert_indices[:1]:
for proj in ["gate_proj", "up_proj", "down_proj"]:
key = f"layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale"
if key in nvfp4_tensors:
dt = nvfp4_tensors[key].dtype
print(f" {proj}.weight_scale dtype = {dt} {'✓ E4M3' if dt == torch.float8_e4m3fn else '✗ WRONG (expected float8_e4m3fn)'}")
assert dt == torch.float8_e4m3fn, f"{proj}.weight_scale dtype={dt}, expected float8_e4m3fn"
print(f" {proj}.weight_scale dtype = {dt}")
# Dequantize NVFP4 → BF16 (for BF16 reference comparison)
print("\nDequantizing NVFP4 → BF16...")
# ── Dequantize → BF16 reference ──
print("\n Dequantizing NVFP4 → BF16...")
nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices)
for e in expert_indices:
for e in expert_indices[:2]:
for proj, w in nvfp4_experts_bf16[e].items():
print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
# ── Step 4: Compare dequantized weights ──
print("\n" + "="*70)
print(" STEP 4: Weight comparison (original dequant vs NVFP4 dequant)")
print("="*70)
# ── Create test input ──
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE)
expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)
# Note: the original was MXFP4 (E8M0, block=32) and NVFP4 is (E4M3, block=16)
# They were quantized independently so weights will differ — this is expected.
# The comparison is to verify the NVFP4 dequant matches its own re-dequant.
print(" (MXFP4 and NVFP4 were independently quantized — weight values will differ)")
print(" (This is expected. The comparison is: NVFP4 dequant vs NVFP4 kernel)")
# ── BF16 reference forward pass ──
print("\n Running BF16 reference...")
ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
print(f" First token first 8: {[f'{v:.4f}' for v in ref_output[0, :8].tolist()]}")
# ── Step 5: Run NVFP4 BF16 reference (using NVFP4-dequantized weights) ──
print("\n" + "="*70)
print(" STEP 5: NVFP4 BF16 reference forward pass")
print("="*70)
del nvfp4_experts_bf16
torch.cuda.empty_cache()
nvfp4_ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
print(f" NVFP4 BF16 ref: shape={tuple(nvfp4_ref_output.shape)} amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}")
print(f" First token first 10: {nvfp4_ref_output[0, :10].tolist()}")
# ── NVFP4 kernel forward pass ──
print("\n Running NVFP4 kernel...")
kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
print(f" First token first 8: {[f'{v:.4f}' for v in kernel_output[0, :8].tolist()]}")
# Compare against original dequant
cos_orig_vs_nvfp4bf16 = torch.nn.functional.cosine_similarity(
# ── Compare ──
cosine = torch.nn.functional.cosine_similarity(
kernel_output.flatten().unsqueeze(0).float(),
ref_output.flatten().unsqueeze(0).float(),
nvfp4_ref_output.flatten().unsqueeze(0).float(),
).item()
print(f" Cosine (orig BF16 ref vs NVFP4 BF16 ref): {cos_orig_vs_nvfp4bf16:.6f}")
mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item()
# ── Step 6: Run our NVFP4 kernel ──
print("\n" + "="*70)
print(" STEP 6: NVFP4 kernel forward pass")
print("="*70)
print(f"\n{'=' * 70}")
print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")
print(f"{'=' * 70}")
try:
kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
print(f" Kernel output: shape={tuple(kernel_output.shape)} amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
print(f" First token first 10: {kernel_output[0, :10].tolist()}")
# Compare kernel vs NVFP4 BF16 reference
cos_kernel_vs_nvfp4bf16 = torch.nn.functional.cosine_similarity(
kernel_output.flatten().unsqueeze(0).float(),
nvfp4_ref_output.flatten().unsqueeze(0).float(),
).item()
mse = (kernel_output.float() - nvfp4_ref_output.float()).pow(2).mean().item()
print(f" Cosine (kernel vs NVFP4 BF16 ref): {cos_kernel_vs_nvfp4bf16:.6f}")
print(f" MSE (kernel vs NVFP4 BF16 ref): {mse:.6e}")
# Compare kernel vs original BF16 reference
cos_kernel_vs_orig = torch.nn.functional.cosine_similarity(
kernel_output.flatten().unsqueeze(0).float(),
ref_output.flatten().unsqueeze(0).float(),
).item()
print(f" Cosine (kernel vs orig BF16 ref): {cos_kernel_vs_orig:.6f}")
except Exception as e:
print(f" KERNEL FAILED: {e}")
import traceback
traceback.print_exc()
# ── Summary ──
print("\n" + "="*70)
print(" SUMMARY")
print("="*70)
print(f" Original BF16 reference: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
print(f" NVFP4 BF16 reference: amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}")
print(f" Cosine (orig vs NVFP4 BF16): {cos_orig_vs_nvfp4bf16:.6f}")
if 'kernel_output' in dir():
cos_k = torch.nn.functional.cosine_similarity(
kernel_output.flatten().unsqueeze(0).float(),
nvfp4_ref_output.flatten().unsqueeze(0).float(),
).item()
print(f" Cosine (kernel vs NVFP4 BF16): {cos_k:.6f}")
if cos_k > 0.99:
print(f" ✅ Kernel matches BF16 reference — bug is in vLLM integration")
elif cos_k > 0.9:
print(f" ⚠️ Kernel is close but not perfect — minor numerical issue")
else:
print(f" ❌ Kernel is far from BF16 reference — bug is in the kernel or weight pipeline")
if cosine < COSINE_THRESHOLD:
print(f" ❌ FAIL: cosine {cosine:.6f} < {COSINE_THRESHOLD}")
sys.exit(1)
else:
print(f" ✅ PASS: cosine {cosine:.6f} >= {COSINE_THRESHOLD}")
if __name__ == "__main__":