test: streamline layertest — kernel vs BF16 ref only, exit on fail
Removed original checkpoint loading (already verified 0.997 cosine). Test now: load NVFP4 → dequant BF16 ref → run kernel → compare. Exits with code 1 if cosine < 0.99.
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Layer 0 comparison test: original checkpoint vs NVFP4 checkpoint + our kernel.
|
||||
Layer 0 kernel comparison test: NVFP4 kernel vs BF16 reference.
|
||||
|
||||
Loads layer 0 expert weights from both checkpoints, runs the same deterministic
|
||||
MoE forward pass, and compares the results. No vLLM, no Docker, no tensor
|
||||
parallelism — just raw weights + our GEMM kernel.
|
||||
No vLLM, no Docker, no tensor parallelism. Just raw weights + our kernel.
|
||||
If cosine < 0.99, the test exits with error.
|
||||
|
||||
Usage:
|
||||
python3 layertest.py
|
||||
@@ -19,12 +18,12 @@ from safetensors import safe_open
|
||||
|
||||
# ── Constants ──────────────────────────────────────────────────────────
|
||||
|
||||
ORIG_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro"
|
||||
NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
LAYER_IDX = 0
|
||||
DEVICE = "cuda"
|
||||
COSINE_THRESHOLD = 0.99
|
||||
|
||||
# E2M1 FP4 lookup table (shared by both formats)
|
||||
# E2M1 FP4 lookup table
|
||||
E2M1_LUT = torch.tensor([
|
||||
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
||||
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
|
||||
@@ -43,7 +42,6 @@ def find_shards(model_dir):
|
||||
for key, shard in index["weight_map"].items():
|
||||
key_to_shard[key] = os.path.join(model_dir, shard)
|
||||
else:
|
||||
# Single shard
|
||||
for sf in glob.glob(os.path.join(model_dir, "*.safetensors")):
|
||||
with safe_open(sf, framework="pt") as f:
|
||||
for key in f.keys():
|
||||
@@ -53,18 +51,12 @@ def find_shards(model_dir):
|
||||
|
||||
|
||||
def load_layer_tensors(model_dir, layer_idx):
|
||||
"""Load all tensors for a specific layer from the checkpoint.
|
||||
|
||||
Returns dict of {key: tensor} for all keys matching the layer.
|
||||
Keys are normalized to NOT have 'model.' prefix.
|
||||
"""
|
||||
"""Load all tensors for a specific layer. Keys normalized (no 'model.' prefix)."""
|
||||
key_to_shard = find_shards(model_dir)
|
||||
layer_prefix = f"layers.{layer_idx}."
|
||||
|
||||
# Group by shard to minimize file opens
|
||||
shard_to_keys = {}
|
||||
for key, shard in key_to_shard.items():
|
||||
# Normalize: strip 'model.' prefix if present
|
||||
norm_key = key.removeprefix("model.")
|
||||
if not norm_key.startswith(layer_prefix):
|
||||
continue
|
||||
@@ -79,83 +71,21 @@ def load_layer_tensors(model_dir, layer_idx):
|
||||
return tensors
|
||||
|
||||
|
||||
def print_layer_keys(tensors, label):
|
||||
"""Print sorted tensor keys with shapes and dtypes."""
|
||||
print(f"\n{'='*70}")
|
||||
print(f" {label} — {len(tensors)} tensors")
|
||||
print(f"{'='*70}")
|
||||
for key in sorted(tensors.keys()):
|
||||
def print_layer_keys(tensors, label, max_keys=20):
|
||||
"""Print sorted tensor keys with shapes and dtypes (first N)."""
|
||||
print(f"\n {label} — {len(tensors)} tensors")
|
||||
sorted_keys = sorted(tensors.keys())
|
||||
for key in sorted_keys[:max_keys]:
|
||||
t = tensors[key]
|
||||
print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}")
|
||||
print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}")
|
||||
if len(sorted_keys) > max_keys:
|
||||
print(f" ... ({len(sorted_keys) - max_keys} more)")
|
||||
|
||||
|
||||
# ── Dequantization: Original checkpoint (MXFP4) ───────────────────────
|
||||
|
||||
def dequantize_mxfp4_weight(packed_uint8, scale_e8m0):
|
||||
"""Dequantize MXFP4 (E2M1 + E8M0, block_size=32) to BF16.
|
||||
|
||||
Original checkpoint format:
|
||||
packed_uint8: (out_features, in_features//2) uint8
|
||||
scale_e8m0: (out_features, in_features//32) float8_e8m0fnu
|
||||
"""
|
||||
device = packed_uint8.device
|
||||
lut = E2M1_LUT.to(device)
|
||||
|
||||
lower = lut[(packed_uint8 & 0x0F).long()]
|
||||
upper = lut[((packed_uint8 >> 4) & 0x0F).long()]
|
||||
|
||||
out_features = packed_uint8.shape[0]
|
||||
in_features = packed_uint8.shape[1] * 2
|
||||
|
||||
unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device)
|
||||
unpacked[:, 0::2] = lower
|
||||
unpacked[:, 1::2] = upper
|
||||
|
||||
# E8M0 → float32: exponent-only format, represents 2^(x - bias)
|
||||
scale_f32 = scale_e8m0.float()
|
||||
scale_expanded = scale_f32.repeat_interleave(32, dim=1)[:, :in_features]
|
||||
|
||||
return (unpacked * scale_expanded).to(torch.bfloat16)
|
||||
|
||||
|
||||
def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices):
|
||||
"""Dequantize expert weights from original MXFP4 checkpoint.
|
||||
|
||||
Original checkpoint key format: layers.{L}.ffn.experts.{E}.{w1/w2/w3}.{weight/scale}
|
||||
w1 = gate_proj, w3 = up_proj, w2 = down_proj
|
||||
|
||||
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
|
||||
"""
|
||||
experts = {}
|
||||
for e in expert_indices:
|
||||
expert = {}
|
||||
for proj, shard in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")]:
|
||||
weight_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.weight"
|
||||
scale_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.scale"
|
||||
|
||||
if weight_key not in orig_tensors:
|
||||
if proj == "down_proj" and e == 211:
|
||||
continue
|
||||
raise KeyError(f"Missing {weight_key}")
|
||||
|
||||
weight = orig_tensors[weight_key].to(DEVICE)
|
||||
scale = orig_tensors[scale_key].to(DEVICE)
|
||||
expert[proj] = dequantize_mxfp4_weight(weight, scale)
|
||||
|
||||
experts[e] = expert
|
||||
return experts
|
||||
|
||||
|
||||
# ── Dequantization: NVFP4 checkpoint ──────────────────────────────────
|
||||
# ── NVFP4 Dequantization ──────────────────────────────────────────────
|
||||
|
||||
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
|
||||
"""Dequantize NVFP4 (E2M1 + E4M3 block scale + float32 global) to BF16.
|
||||
|
||||
NVFP4 checkpoint format:
|
||||
packed_uint8: (out_features, in_features//2) uint8
|
||||
scale_e4m3: (out_features, in_features//16) float8_e4m3fn
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
"""Dequantize NVFP4 (E2M1 + E4M3 + global) to BF16."""
|
||||
device = packed_uint8.device
|
||||
lut = E2M1_LUT.to(device)
|
||||
|
||||
@@ -169,20 +99,14 @@ def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
|
||||
unpacked[:, 0::2] = lower
|
||||
unpacked[:, 1::2] = upper
|
||||
|
||||
block_scale = scale_e4m3.float() # float8_e4m3fn → float32
|
||||
block_scale = scale_e4m3.float()
|
||||
block_expanded = block_scale.repeat_interleave(16, dim=1)[:, :in_features]
|
||||
|
||||
# Weight dequant = e2m1 * block_scale * global_scale
|
||||
return (unpacked * block_expanded * global_scale).to(torch.bfloat16)
|
||||
|
||||
|
||||
def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
|
||||
"""Dequantize expert weights from NVFP4 checkpoint.
|
||||
|
||||
NVFP4 checkpoint key format: layers.{L}.mlp.experts.{E}.{gate_proj/up_proj/down_proj}.{weight/weight_scale/weight_scale_2}
|
||||
|
||||
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
|
||||
"""
|
||||
"""Dequantize expert weights from NVFP4 checkpoint → BF16."""
|
||||
experts = {}
|
||||
for e in expert_indices:
|
||||
expert = {}
|
||||
@@ -205,20 +129,10 @@ def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
|
||||
return experts
|
||||
|
||||
|
||||
# ── MoE Forward Pass (BF16 reference) ─────────────────────────────────
|
||||
# ── BF16 MoE Forward ───────────────────────────────────────────────────
|
||||
|
||||
def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
|
||||
"""Run MoE forward pass in pure BF16.
|
||||
|
||||
Args:
|
||||
hidden_states: (num_tokens, hidden_size) BF16
|
||||
experts: dict {expert_id: {gate_proj, up_proj, down_proj}} BF16
|
||||
expert_ids: (num_tokens, top_k) int — which expert per token per slot
|
||||
expert_weights: (num_tokens, top_k) float32 — routing weights
|
||||
|
||||
Returns:
|
||||
output: (num_tokens, hidden_size) BF16
|
||||
"""
|
||||
"""Run MoE forward pass in pure BF16 (torch.matmul)."""
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
top_k = expert_ids.shape[1]
|
||||
output = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE)
|
||||
@@ -231,92 +145,66 @@ def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
|
||||
if e not in experts:
|
||||
continue
|
||||
|
||||
x = hidden_states[t] # (hidden_size,)
|
||||
gate = x @ experts[e]["gate_proj"].T # (intermediate//2,)
|
||||
up = x @ experts[e]["up_proj"].T # (intermediate//2,)
|
||||
activated = torch.nn.functional.silu(gate) * up # (intermediate//2,)
|
||||
x = hidden_states[t]
|
||||
gate = x @ experts[e]["gate_proj"].T
|
||||
up = x @ experts[e]["up_proj"].T
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
|
||||
if "down_proj" in experts[e]:
|
||||
y = activated @ experts[e]["down_proj"].T # (hidden_size,)
|
||||
y = activated @ experts[e]["down_proj"].T
|
||||
else:
|
||||
y = activated[:hidden_size] # shared expert, no down_proj
|
||||
y = activated[:hidden_size]
|
||||
|
||||
output[t] += w * y
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# ── MoE Forward Pass (NVFP4 kernel) ───────────────────────────────────
|
||||
# ── NVFP4 Kernel MoE Forward ──────────────────────────────────────────
|
||||
|
||||
def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights):
|
||||
"""Run MoE forward pass using our NVFP4 kernel.
|
||||
|
||||
Loads weights directly from NVFP4 checkpoint (no vLLM), transforms them
|
||||
for CUTLASS, and runs the grouped GEMM.
|
||||
"""
|
||||
"""Run MoE forward pass using our NVFP4 kernel."""
|
||||
from nvfp4_megamoe_kernel import (
|
||||
stage_activation,
|
||||
nvfp4_mega_moe_full,
|
||||
transform_nvfp4_weights_for_mega_moe,
|
||||
SymmBuffer,
|
||||
get_symm_buffer_for_nvfp4_mega_moe,
|
||||
)
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
top_k = expert_ids.shape[1]
|
||||
|
||||
# Collect the experts we need
|
||||
unique_experts = sorted(set(expert_ids.flatten().tolist()))
|
||||
num_experts = len(unique_experts)
|
||||
expert_map = {e: i for i, e in enumerate(unique_experts)}
|
||||
|
||||
# Load NVFP4 weights for these experts
|
||||
# Shapes: gate_proj.weight = (3072, 3584) uint8, weight_scale = (3072, 448) float8_e4m3fn
|
||||
intermediate_half = 3072 # intermediate_size // 2
|
||||
intermediate_half = 3072
|
||||
hidden_half = hidden_size // 2
|
||||
|
||||
l1_weights = [] # gate + up fused
|
||||
l1_scales = []
|
||||
l1_global_scales = []
|
||||
l2_weights = [] # down
|
||||
l2_scales = []
|
||||
l2_global_scales = []
|
||||
l1_weights, l1_scales, l1_global_scales = [], [], []
|
||||
l2_weights, l2_scales, l2_global_scales = [], [], []
|
||||
|
||||
for e in unique_experts:
|
||||
# L1: gate_proj + up_proj fused
|
||||
gate_w_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"
|
||||
gate_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"
|
||||
gate_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"
|
||||
up_w_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"
|
||||
up_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"
|
||||
up_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"
|
||||
gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].view(torch.int8).to(DEVICE)
|
||||
gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE)
|
||||
gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item()
|
||||
up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].view(torch.int8).to(DEVICE)
|
||||
up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE)
|
||||
up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item()
|
||||
|
||||
gate_w = nvfp4_tensors[gate_w_key].view(torch.int8).to(DEVICE)
|
||||
gate_sf = nvfp4_tensors[gate_sf_key].to(DEVICE)
|
||||
gate_gs = nvfp4_tensors[gate_gs_key].item()
|
||||
up_w = nvfp4_tensors[up_w_key].view(torch.int8).to(DEVICE)
|
||||
up_sf = nvfp4_tensors[up_sf_key].to(DEVICE)
|
||||
up_gs = nvfp4_tensors[up_gs_key].item()
|
||||
|
||||
# Fuse gate + up: stack along dim 0 → (2*3072, 3584)
|
||||
l1_w = torch.cat([gate_w, up_w], dim=0)
|
||||
l1_sf = torch.cat([gate_sf, up_sf], dim=0)
|
||||
l1_gs = torch.tensor([gate_gs, up_gs], dtype=torch.float32, device=DEVICE)
|
||||
|
||||
l1_weights.append(l1_w)
|
||||
l1_scales.append(l1_sf)
|
||||
l1_global_scales.append(l1_gs)
|
||||
|
||||
# L2: down_proj
|
||||
down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
|
||||
if down_w_key in nvfp4_tensors:
|
||||
down_w = nvfp4_tensors[down_w_key].view(torch.int8).to(DEVICE)
|
||||
down_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"
|
||||
down_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"
|
||||
down_sf = nvfp4_tensors[down_sf_key].to(DEVICE)
|
||||
down_gs = nvfp4_tensors[down_gs_key].item()
|
||||
down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
|
||||
down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
|
||||
else:
|
||||
# Expert 211 has no down_proj — use zeros
|
||||
down_w = torch.zeros(hidden_size, intermediate_half, dtype=torch.int8, device=DEVICE)
|
||||
down_sf = torch.ones(hidden_size, intermediate_half // 16, dtype=torch.float8_e4m3fn, device=DEVICE)
|
||||
down_gs = 1.0
|
||||
@@ -325,24 +213,19 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
|
||||
l2_scales.append(down_sf)
|
||||
l2_global_scales.append(torch.tensor([down_gs], dtype=torch.float32, device=DEVICE))
|
||||
|
||||
# Stack into (num_experts, ...) tensors
|
||||
l1_w = torch.stack(l1_weights) # (E, 2*3072, 3584) int8
|
||||
l1_sf = torch.stack(l1_scales) # (E, 2*3072, 448) float8_e4m3fn
|
||||
l1_gs = torch.stack(l1_global_scales) # (E, 2) float32
|
||||
l2_w = torch.stack(l2_weights) # (E, hidden, intermediate_half) int8
|
||||
l2_sf = torch.stack(l2_scales) # (E, hidden, intermediate_half//16) float8_e4m3fn
|
||||
l2_gs = torch.stack(l2_global_scales) # (E, 1) float32
|
||||
l1_w = torch.stack(l1_weights)
|
||||
l1_sf = torch.stack(l1_scales)
|
||||
l1_gs = torch.stack(l1_global_scales)
|
||||
l2_w = torch.stack(l2_weights)
|
||||
l2_sf = torch.stack(l2_scales)
|
||||
l2_gs = torch.stack(l2_global_scales)
|
||||
|
||||
# Transform weights for CUTLASS
|
||||
(l1_w, l1_sf, l1_global_sf), (l2_w, l2_sf, l2_global_sf) = \
|
||||
transform_nvfp4_weights_for_mega_moe(
|
||||
(l1_w, l1_sf),
|
||||
(l2_w, l2_sf),
|
||||
l1_weight_scale_2=l1_gs,
|
||||
l2_weight_scale_2=l2_gs,
|
||||
(l1_w, l1_sf), (l2_w, l2_sf),
|
||||
l1_weight_scale_2=l1_gs, l2_weight_scale_2=l2_gs,
|
||||
)
|
||||
|
||||
# Build slot mapping: each (token, top_k) pair → slot
|
||||
num_slots = num_tokens * top_k
|
||||
slot_expert = torch.zeros(num_slots, dtype=torch.int32, device=DEVICE)
|
||||
slot_token = torch.zeros(num_slots, dtype=torch.int64, device=DEVICE)
|
||||
@@ -351,22 +234,15 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
|
||||
for t in range(num_tokens):
|
||||
for k in range(top_k):
|
||||
slot = t * top_k + k
|
||||
e = expert_ids[t, k].item()
|
||||
slot_expert[slot] = expert_map[e]
|
||||
slot_expert[slot] = expert_map[expert_ids[t, k].item()]
|
||||
slot_token[slot] = t
|
||||
slot_weight[slot] = expert_weights[t, k].item()
|
||||
|
||||
# SymmBuffer
|
||||
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
|
||||
group=None, # no EP
|
||||
num_experts=num_experts,
|
||||
max_num_tokens=num_tokens,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=6144, # 2 * 3072
|
||||
group=None, num_experts=num_experts, max_num_tokens=num_tokens,
|
||||
top_k=top_k, hidden_size=hidden_size, intermediate_size=6144,
|
||||
)
|
||||
|
||||
# Stage activation
|
||||
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
|
||||
symm_buffer.x[:num_tokens].copy_(x_fp4)
|
||||
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
|
||||
@@ -375,7 +251,6 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
|
||||
symm_buffer.topk_weights[:num_tokens].copy_(expert_weights)
|
||||
symm_buffer.experts_start_idx = 0
|
||||
|
||||
# Run
|
||||
y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE)
|
||||
nvfp4_mega_moe_full(
|
||||
y,
|
||||
@@ -391,145 +266,71 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
|
||||
expert_indices = [0, 1, 2] # Test with 3 experts
|
||||
expert_indices = [0, 1, 2]
|
||||
top_k = 2
|
||||
num_tokens = 4
|
||||
|
||||
# ── Step 1: Load original checkpoint layer 0 ──
|
||||
print("\n" + "="*70)
|
||||
print(" STEP 1: Loading original MXFP4 checkpoint")
|
||||
print("="*70)
|
||||
|
||||
orig_tensors = load_layer_tensors(ORIG_MODEL_DIR, LAYER_IDX)
|
||||
print_layer_keys(orig_tensors, "Original checkpoint (MXFP4)")
|
||||
|
||||
# Dequantize to BF16
|
||||
print("\nDequantizing MXFP4 → BF16...")
|
||||
orig_experts_bf16 = dequantize_mxfp4_experts(orig_tensors, LAYER_IDX, expert_indices)
|
||||
for e in expert_indices:
|
||||
for proj, w in orig_experts_bf16[e].items():
|
||||
print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
|
||||
|
||||
# ── Step 2: Run BF16 reference forward pass ──
|
||||
print("\n" + "="*70)
|
||||
print(" STEP 2: BF16 reference forward pass")
|
||||
print("="*70)
|
||||
|
||||
hidden_size = 7168
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
|
||||
|
||||
# Deterministic routing: each token picks experts 0,1
|
||||
expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE)
|
||||
expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
ref_output = moe_forward_bf16(hidden_states, orig_experts_bf16, expert_ids, expert_weights)
|
||||
print(f" Reference output: shape={tuple(ref_output.shape)} amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
|
||||
print(f" First token first 10: {ref_output[0, :10].tolist()}")
|
||||
|
||||
del orig_tensors, orig_experts_bf16 # Free memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ── Step 3: Load NVFP4 checkpoint layer 0 ──
|
||||
print("\n" + "="*70)
|
||||
print(" STEP 3: Loading NVFP4 checkpoint")
|
||||
print("="*70)
|
||||
# ── Load NVFP4 checkpoint ──
|
||||
print("=" * 70)
|
||||
print(" Loading NVFP4 checkpoint layer 0")
|
||||
print("=" * 70)
|
||||
|
||||
nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX)
|
||||
print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint")
|
||||
print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint", max_keys=5)
|
||||
|
||||
# Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu)
|
||||
# Verify weight_scale dtype
|
||||
for e in expert_indices[:1]:
|
||||
for proj in ["gate_proj", "up_proj", "down_proj"]:
|
||||
key = f"layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale"
|
||||
if key in nvfp4_tensors:
|
||||
dt = nvfp4_tensors[key].dtype
|
||||
print(f" {proj}.weight_scale dtype = {dt} {'✓ E4M3' if dt == torch.float8_e4m3fn else '✗ WRONG (expected float8_e4m3fn)'}")
|
||||
assert dt == torch.float8_e4m3fn, f"{proj}.weight_scale dtype={dt}, expected float8_e4m3fn"
|
||||
print(f" {proj}.weight_scale dtype = {dt} ✓")
|
||||
|
||||
# Dequantize NVFP4 → BF16 (for BF16 reference comparison)
|
||||
print("\nDequantizing NVFP4 → BF16...")
|
||||
# ── Dequantize → BF16 reference ──
|
||||
print("\n Dequantizing NVFP4 → BF16...")
|
||||
nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices)
|
||||
for e in expert_indices:
|
||||
for e in expert_indices[:2]:
|
||||
for proj, w in nvfp4_experts_bf16[e].items():
|
||||
print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
|
||||
print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
|
||||
|
||||
# ── Step 4: Compare dequantized weights ──
|
||||
print("\n" + "="*70)
|
||||
print(" STEP 4: Weight comparison (original dequant vs NVFP4 dequant)")
|
||||
print("="*70)
|
||||
# ── Create test input ──
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
|
||||
expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE)
|
||||
expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
# Note: the original was MXFP4 (E8M0, block=32) and NVFP4 is (E4M3, block=16)
|
||||
# They were quantized independently so weights will differ — this is expected.
|
||||
# The comparison is to verify the NVFP4 dequant matches its own re-dequant.
|
||||
print(" (MXFP4 and NVFP4 were independently quantized — weight values will differ)")
|
||||
print(" (This is expected. The comparison is: NVFP4 dequant vs NVFP4 kernel)")
|
||||
# ── BF16 reference forward pass ──
|
||||
print("\n Running BF16 reference...")
|
||||
ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
|
||||
print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
|
||||
print(f" First token first 8: {[f'{v:.4f}' for v in ref_output[0, :8].tolist()]}")
|
||||
|
||||
# ── Step 5: Run NVFP4 BF16 reference (using NVFP4-dequantized weights) ──
|
||||
print("\n" + "="*70)
|
||||
print(" STEP 5: NVFP4 BF16 reference forward pass")
|
||||
print("="*70)
|
||||
del nvfp4_experts_bf16
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
nvfp4_ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
|
||||
print(f" NVFP4 BF16 ref: shape={tuple(nvfp4_ref_output.shape)} amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}")
|
||||
print(f" First token first 10: {nvfp4_ref_output[0, :10].tolist()}")
|
||||
# ── NVFP4 kernel forward pass ──
|
||||
print("\n Running NVFP4 kernel...")
|
||||
kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
|
||||
print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
|
||||
print(f" First token first 8: {[f'{v:.4f}' for v in kernel_output[0, :8].tolist()]}")
|
||||
|
||||
# Compare against original dequant
|
||||
cos_orig_vs_nvfp4bf16 = torch.nn.functional.cosine_similarity(
|
||||
# ── Compare ──
|
||||
cosine = torch.nn.functional.cosine_similarity(
|
||||
kernel_output.flatten().unsqueeze(0).float(),
|
||||
ref_output.flatten().unsqueeze(0).float(),
|
||||
nvfp4_ref_output.flatten().unsqueeze(0).float(),
|
||||
).item()
|
||||
print(f" Cosine (orig BF16 ref vs NVFP4 BF16 ref): {cos_orig_vs_nvfp4bf16:.6f}")
|
||||
mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item()
|
||||
|
||||
# ── Step 6: Run our NVFP4 kernel ──
|
||||
print("\n" + "="*70)
|
||||
print(" STEP 6: NVFP4 kernel forward pass")
|
||||
print("="*70)
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
try:
|
||||
kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
|
||||
print(f" Kernel output: shape={tuple(kernel_output.shape)} amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
|
||||
print(f" First token first 10: {kernel_output[0, :10].tolist()}")
|
||||
|
||||
# Compare kernel vs NVFP4 BF16 reference
|
||||
cos_kernel_vs_nvfp4bf16 = torch.nn.functional.cosine_similarity(
|
||||
kernel_output.flatten().unsqueeze(0).float(),
|
||||
nvfp4_ref_output.flatten().unsqueeze(0).float(),
|
||||
).item()
|
||||
mse = (kernel_output.float() - nvfp4_ref_output.float()).pow(2).mean().item()
|
||||
print(f" Cosine (kernel vs NVFP4 BF16 ref): {cos_kernel_vs_nvfp4bf16:.6f}")
|
||||
print(f" MSE (kernel vs NVFP4 BF16 ref): {mse:.6e}")
|
||||
|
||||
# Compare kernel vs original BF16 reference
|
||||
cos_kernel_vs_orig = torch.nn.functional.cosine_similarity(
|
||||
kernel_output.flatten().unsqueeze(0).float(),
|
||||
ref_output.flatten().unsqueeze(0).float(),
|
||||
).item()
|
||||
print(f" Cosine (kernel vs orig BF16 ref): {cos_kernel_vs_orig:.6f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" KERNEL FAILED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# ── Summary ──
|
||||
print("\n" + "="*70)
|
||||
print(" SUMMARY")
|
||||
print("="*70)
|
||||
print(f" Original BF16 reference: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
|
||||
print(f" NVFP4 BF16 reference: amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}")
|
||||
print(f" Cosine (orig vs NVFP4 BF16): {cos_orig_vs_nvfp4bf16:.6f}")
|
||||
if 'kernel_output' in dir():
|
||||
cos_k = torch.nn.functional.cosine_similarity(
|
||||
kernel_output.flatten().unsqueeze(0).float(),
|
||||
nvfp4_ref_output.flatten().unsqueeze(0).float(),
|
||||
).item()
|
||||
print(f" Cosine (kernel vs NVFP4 BF16): {cos_k:.6f}")
|
||||
if cos_k > 0.99:
|
||||
print(f" ✅ Kernel matches BF16 reference — bug is in vLLM integration")
|
||||
elif cos_k > 0.9:
|
||||
print(f" ⚠️ Kernel is close but not perfect — minor numerical issue")
|
||||
else:
|
||||
print(f" ❌ Kernel is far from BF16 reference — bug is in the kernel or weight pipeline")
|
||||
if cosine < COSINE_THRESHOLD:
|
||||
print(f" ❌ FAIL: cosine {cosine:.6f} < {COSINE_THRESHOLD}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f" ✅ PASS: cosine {cosine:.6f} >= {COSINE_THRESHOLD}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user