Removed original checkpoint loading (already verified 0.997 cosine). Test now: load NVFP4 → dequant BF16 ref → run kernel → compare. Exits with code 1 if cosine < 0.99.
338 lines
13 KiB
Python
338 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Layer 0 kernel comparison test: NVFP4 kernel vs BF16 reference.
|
|
|
|
No vLLM, no Docker, no tensor parallelism. Just raw weights + our kernel.
|
|
If cosine < 0.99, the test exits with error.
|
|
|
|
Usage:
|
|
python3 layertest.py
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import glob
|
|
import torch
|
|
from safetensors import safe_open
|
|
|
|
# ── Constants ──────────────────────────────────────────────────────────

# Checkpoint location. Set the NVFP4_MODEL_DIR environment variable to point
# the test at a different copy of the checkpoint; the default is unchanged.
NVFP4_MODEL_DIR = os.environ.get("NVFP4_MODEL_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
LAYER_IDX = 0            # decoder layer whose experts are tested
DEVICE = "cuda"
COSINE_THRESHOLD = 0.99  # minimum kernel-vs-reference cosine similarity to pass

# E2M1 FP4 lookup table, indexed by the 4-bit code. Codes 0-7 are the
# non-negative magnitudes; codes 8-15 are the same magnitudes negated, i.e.
# bit 3 of the code acts as the sign bit.
E2M1_LUT = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)
|
|
|
|
# ── Checkpoint loading ─────────────────────────────────────────────────
|
|
|
|
def find_shards(model_dir):
    """Map every tensor key in a safetensors checkpoint to the shard file holding it.

    If ``model.safetensors.index.json`` exists, its ``weight_map`` is used
    directly. Otherwise every ``*.safetensors`` file in ``model_dir`` is opened
    and its keys are listed.

    Args:
        model_dir: Directory containing the checkpoint.

    Returns:
        dict mapping tensor key -> path of the shard file (joined with model_dir).

    Raises:
        FileNotFoundError: if the directory has neither an index file nor any
            ``*.safetensors`` shard. Failing here avoids a misleading
            "missing tensor" error much later.
    """
    index_path = os.path.join(model_dir, "model.safetensors.index.json")

    if os.path.exists(index_path):
        with open(index_path, encoding="utf-8") as f:
            index = json.load(f)
        return {key: os.path.join(model_dir, shard) for key, shard in index["weight_map"].items()}

    # No index: scan shards in sorted order so the result does not depend on
    # filesystem listing order (if a key appears in several shards, the last
    # file in sorted order wins).
    shard_paths = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(
            f"No model.safetensors.index.json or *.safetensors files found in {model_dir!r}"
        )

    key_to_shard = {}
    for sf in shard_paths:
        with safe_open(sf, framework="pt") as f:
            for key in f.keys():
                key_to_shard[key] = sf
    return key_to_shard
|
|
|
|
|
|
def load_layer_tensors(model_dir, layer_idx):
    """Read every tensor that belongs to decoder layer ``layer_idx``.

    Checkpoint keys have any leading ``model.`` removed, so the returned dict is
    keyed like ``layers.<idx>.<rest>``. Keys are grouped by shard so each shard
    file is opened only once.
    """
    wanted_prefix = f"layers.{layer_idx}."

    # shard path -> [(key as stored in the checkpoint, normalized key), ...]
    by_shard = {}
    for ckpt_key, shard_path in find_shards(model_dir).items():
        short_key = ckpt_key.removeprefix("model.")
        if short_key.startswith(wanted_prefix):
            by_shard.setdefault(shard_path, []).append((ckpt_key, short_key))

    loaded = {}
    for shard_path, pairs in by_shard.items():
        with safe_open(shard_path, framework="pt") as handle:
            loaded.update(
                (short_key, handle.get_tensor(ckpt_key)) for ckpt_key, short_key in pairs
            )
    return loaded
|
|
|
|
|
|
def print_layer_keys(tensors, label, max_keys=20):
    """Print a header line, then dtype and shape for the first ``max_keys`` keys in sorted order.

    If any keys were not shown, a final line reports how many were omitted.
    """
    print(f"\n {label} — {len(tensors)} tensors")
    ordered = sorted(tensors)
    for name in ordered[:max_keys]:
        tensor = tensors[name]
        print(f" {name}: dtype={tensor.dtype} shape={tuple(tensor.shape)}")
    omitted = len(ordered) - max_keys
    if omitted > 0:
        print(f" ... ({omitted} more)")
|
|
|
|
|
|
# ── NVFP4 Dequantization ──────────────────────────────────────────────
|
|
|
|
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale, block_size=16):
    """Dequantize an NVFP4 weight (E2M1 codes + E4M3 block scales + global scale) to BF16.

    Args:
        packed_uint8: (out_features, in_features // 2) tensor holding two 4-bit
            E2M1 codes per byte. The low nibble is the even input column and the
            high nibble is the odd one.
        scale_e4m3: per-block scales with ``out_features`` rows. There is one
            column per ``block_size`` consecutive input columns; extra trailing
            (padding) columns are ignored.
        global_scale: scalar multiplier applied to every element (a Python float
            in this file's callers).
        block_size: number of input columns that share one block scale
            (16 for NVFP4).

    Returns:
        (out_features, in_features) BF16 tensor on ``packed_uint8``'s device,
        equal to ``code_value * block_scale * global_scale``.

    Raises:
        ValueError: if ``scale_e4m3`` does not have one row per output feature
            and enough columns to cover every input block.
    """
    device = packed_uint8.device
    lut = E2M1_LUT.to(device)

    out_features = packed_uint8.shape[0]
    in_features = packed_uint8.shape[1] * 2
    num_blocks = -(-in_features // block_size)  # ceil(in_features / block_size)

    if scale_e4m3.shape[0] != out_features or scale_e4m3.shape[1] < num_blocks:
        raise ValueError(
            f"scale shape {tuple(scale_e4m3.shape)} does not cover a "
            f"{out_features}x{in_features} weight with block_size={block_size}"
        )

    # Decode both nibbles, then interleave them: low -> even column, high -> odd.
    lower = lut[(packed_uint8 & 0x0F).long()]
    upper = lut[((packed_uint8 >> 4) & 0x0F).long()]
    unpacked = torch.stack((lower, upper), dim=-1).reshape(out_features, in_features)

    # Broadcast each block scale over its block_size columns, dropping padding.
    block_expanded = (
        scale_e4m3[:, :num_blocks].float().repeat_interleave(block_size, dim=1)[:, :in_features]
    )

    return (unpacked * block_expanded * global_scale).to(torch.bfloat16)
|
|
|
|
|
|
def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
    """Dequantize the gate/up/down projections of selected experts to BF16 on DEVICE.

    Returns ``{expert_index: {"gate_proj": W, "up_proj": W, "down_proj": W}}``.
    A projection whose ``.weight`` key is absent raises KeyError. The one
    exception is expert 211's down_proj: if it is missing, it is simply left
    out of that expert's dict. (This is special-cased here; it is presumably a
    known gap in this checkpoint — TODO confirm.)
    """
    result = {}
    for expert_idx in expert_indices:
        projections = {}
        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            base = f"layers.{layer_idx}.mlp.experts.{expert_idx}.{proj_name}"
            weight_key = f"{base}.weight"

            if weight_key not in nvfp4_tensors:
                tolerated = proj_name == "down_proj" and expert_idx == 211
                if not tolerated:
                    raise KeyError(f"Missing {weight_key}")
                continue

            projections[proj_name] = dequantize_nvfp4_weight(
                nvfp4_tensors[weight_key].to(DEVICE),
                nvfp4_tensors[f"{base}.weight_scale"].to(DEVICE),
                nvfp4_tensors[f"{base}.weight_scale_2"].item(),
            )
        result[expert_idx] = projections
    return result
|
|
|
|
|
|
# ── BF16 MoE Forward ───────────────────────────────────────────────────
|
|
|
|
def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
    """Reference MoE forward pass with BF16 matmuls (torch.matmul).

    For token ``t`` and routing slot ``k``, expert ``e = expert_ids[t, k]``
    computes ``(silu(x @ gate.T) * (x @ up.T)) @ down.T`` in BF16. The result is
    scaled by ``expert_weights[t, k]`` and summed over the slots.

    Experts missing from ``experts``, or present without a ``"down_proj"``
    entry, contribute zero. The NVFP4 kernel path likewise substitutes all-zero
    down weights when down_proj is missing.

    Args:
        hidden_states: (num_tokens, hidden_size) BF16 activations.
        experts: ``{expert_index: {"gate_proj", "up_proj"[, "down_proj"]: BF16 weight}}``.
        expert_ids: (num_tokens, top_k) integer expert indices.
        expert_weights: (num_tokens, top_k) routing weights.

    Returns:
        (num_tokens, hidden_size) BF16 tensor on ``hidden_states``' device.
    """
    num_tokens, hidden_size = hidden_states.shape

    # The weighted sum over top_k slots is accumulated in fp32 and rounded to
    # BF16 once, so repeated BF16 rounding does not degrade the reference.
    acc = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device=hidden_states.device)

    # One host transfer each, instead of a .item() sync per element.
    ids_per_token = expert_ids.tolist()
    weights_per_token = expert_weights.tolist()

    for t in range(num_tokens):
        x = hidden_states[t]
        for e, w in zip(ids_per_token[t], weights_per_token[t]):
            expert = experts.get(e)
            if expert is None or "down_proj" not in expert:
                continue

            gate = x @ expert["gate_proj"].T
            up = x @ expert["up_proj"].T
            activated = torch.nn.functional.silu(gate) * up
            y = activated @ expert["down_proj"].T
            acc[t] += w * y.float()

    return acc.to(torch.bfloat16)
|
|
|
|
|
|
# ── NVFP4 Kernel MoE Forward ──────────────────────────────────────────
|
|
|
|
def _load_nvfp4_proj(nvfp4_tensors, prefix):
    """Return ``(weight, weight_scale, weight_scale_2)`` for one NVFP4 projection.

    The packed weight is reinterpreted as int8 and, like the block scales, moved
    to DEVICE. The global scale is returned as a Python float. Raises KeyError
    if any of the three tensors is missing.
    """
    weight = nvfp4_tensors[f"{prefix}.weight"].view(torch.int8).to(DEVICE)
    scale = nvfp4_tensors[f"{prefix}.weight_scale"].to(DEVICE)
    global_scale = nvfp4_tensors[f"{prefix}.weight_scale_2"].item()
    return weight, scale, global_scale


def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights):
    """Run the MoE forward pass through our NVFP4 mega-MoE kernel.

    Only experts that appear in ``expert_ids`` are packed for the kernel. They
    are renumbered 0..num_experts-1 in ascending order of their checkpoint
    index, and the routing table given to the kernel uses that local numbering
    for every top-k slot.

    Args:
        hidden_states: (num_tokens, hidden_size) activations.
        nvfp4_tensors: layer tensors keyed ``layers.{layer_idx}.mlp.experts.{e}.{proj}.*``.
        layer_idx: layer whose experts are used.
        expert_ids: (num_tokens, top_k) checkpoint expert indices per token.
        expert_weights: (num_tokens, top_k) routing weights.

    Returns:
        (num_tokens, hidden_size) BF16 output on DEVICE.
    """
    from nvfp4_megamoe_kernel import (
        stage_activation,
        nvfp4_mega_moe_full,
        transform_nvfp4_weights_for_mega_moe,
        get_symm_buffer_for_nvfp4_mega_moe,
    )

    num_tokens, hidden_size = hidden_states.shape
    top_k = expert_ids.shape[1]

    unique_experts = sorted(set(expert_ids.flatten().tolist()))
    num_experts = len(unique_experts)
    expert_map = {e: i for i, e in enumerate(unique_experts)}

    l1_weights, l1_scales, l1_global_scales = [], [], []
    l2_weights, l2_scales, l2_global_scales = [], [], []

    for e in unique_experts:
        prefix = f"layers.{layer_idx}.mlp.experts.{e}"
        gate_w, gate_sf, gate_gs = _load_nvfp4_proj(nvfp4_tensors, f"{prefix}.gate_proj")
        up_w, up_sf, up_gs = _load_nvfp4_proj(nvfp4_tensors, f"{prefix}.up_proj")

        # L1 = gate rows followed by up rows, with one global scale per half.
        l1_weights.append(torch.cat([gate_w, up_w], dim=0))
        l1_scales.append(torch.cat([gate_sf, up_sf], dim=0))
        l1_global_scales.append(torch.tensor([gate_gs, up_gs], dtype=torch.float32, device=DEVICE))

        if f"{prefix}.down_proj.weight" in nvfp4_tensors:
            down_w, down_sf, down_gs = _load_nvfp4_proj(nvfp4_tensors, f"{prefix}.down_proj")
        else:
            # No down_proj in the checkpoint: use an all-zero projection so this
            # expert contributes nothing. down_proj's input dim is gate_proj's
            # output dim (gate_w.shape[0]). Shapes follow the packed layout that
            # dequantize_nvfp4_weight unpacks: 2 FP4 codes per byte and one E4M3
            # scale per 16 inputs, so weight and scale shapes stay consistent.
            intermediate = gate_w.shape[0]
            down_w = torch.zeros(hidden_size, intermediate // 2, dtype=torch.int8, device=DEVICE)
            down_sf = torch.ones(
                hidden_size, intermediate // 16, dtype=torch.float32, device=DEVICE
            ).to(torch.float8_e4m3fn)
            down_gs = 1.0

        l2_weights.append(down_w)
        l2_scales.append(down_sf)
        l2_global_scales.append(torch.tensor([down_gs], dtype=torch.float32, device=DEVICE))

    (l1_w, l1_sf, l1_global_sf), (l2_w, l2_sf, l2_global_sf) = transform_nvfp4_weights_for_mega_moe(
        (torch.stack(l1_weights), torch.stack(l1_scales)),
        (torch.stack(l2_weights), torch.stack(l2_scales)),
        l1_weight_scale_2=torch.stack(l1_global_scales),
        l2_weight_scale_2=torch.stack(l2_global_scales),
    )

    # Routing table in the kernel's local expert numbering. Each top-k slot keeps
    # its own expert, matching the per-slot routing of the BF16 reference.
    local_ids = torch.tensor(
        [expert_map[e] for e in expert_ids.flatten().tolist()],
        dtype=torch.int32,
        device=expert_ids.device,
    ).view(num_tokens, top_k)

    symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
        group=None, num_experts=num_experts, max_num_tokens=num_tokens,
        top_k=top_k, hidden_size=hidden_size,
        # NOTE(review): hard-coded for this checkpoint. Confirm whether the kernel
        # expects the per-projection or the fused gate+up intermediate width.
        intermediate_size=6144,
    )

    x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
    symm_buffer.x[:num_tokens].copy_(x_fp4)
    symm_buffer.x_sf[:num_tokens].copy_(x_sf)
    symm_buffer.input_global_scale = input_global_scale
    symm_buffer.topk_idx[:num_tokens].copy_(local_ids)
    symm_buffer.topk_weights[:num_tokens].copy_(expert_weights)
    symm_buffer.experts_start_idx = 0

    y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE)
    nvfp4_mega_moe_full(
        y,
        (l1_w, l1_sf, l1_global_sf),
        (l2_w, l2_sf, l2_global_sf),
        symm_buffer,
    )
    return y
|
|
|
|
|
|
# ── Main ───────────────────────────────────────────────────────────────
|
|
|
|
def main():
    """Compare the NVFP4 MoE kernel with a BF16 reference built from the same weights.

    Steps:
      1. Load layer ``LAYER_IDX`` of the NVFP4 checkpoint.
      2. Dequantize a few experts to BF16.
      3. Run both MoE forward passes on the same random tokens and fixed routing.
      4. Print amax/mean summaries for both outputs.

    Exits with status 1 if either output contains NaN/Inf, or if their cosine
    similarity is below ``COSINE_THRESHOLD``.
    """
    torch.manual_seed(42)

    expert_indices = [0, 1, 2]
    top_k = 2
    num_tokens = 4

    # ── Load NVFP4 checkpoint ──
    print("=" * 70)
    print(f" Loading NVFP4 checkpoint layer {LAYER_IDX}")
    print("=" * 70)

    nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX)
    print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint", max_keys=5)

    # Packed FP4 weights are (out_features, in_features // 2), the layout
    # dequantize_nvfp4_weight assumes. gate_proj's input dimension is the
    # hidden size, so derive it rather than hard-coding it.
    first_expert = expert_indices[0]
    gate_key = f"layers.{LAYER_IDX}.mlp.experts.{first_expert}.gate_proj.weight"
    hidden_size = nvfp4_tensors[gate_key].shape[1] * 2

    # Verify weight_scale dtype. Use an explicit check rather than assert so it
    # still runs under `python -O`.
    for proj in ("gate_proj", "up_proj", "down_proj"):
        key = f"layers.{LAYER_IDX}.mlp.experts.{first_expert}.{proj}.weight_scale"
        if key in nvfp4_tensors:
            dt = nvfp4_tensors[key].dtype
            if dt != torch.float8_e4m3fn:
                raise TypeError(f"{proj}.weight_scale dtype={dt}, expected float8_e4m3fn")
            print(f" {proj}.weight_scale dtype = {dt} ✓")

    # ── Dequantize → BF16 reference ──
    print("\n Dequantizing NVFP4 → BF16...")
    nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices)
    for e in expert_indices[:2]:
        for proj, w in nvfp4_experts_bf16[e].items():
            print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")

    # ── Create test input ──
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    # Every token is routed to the first top_k experts. The weight list below
    # must have exactly top_k entries.
    expert_ids = torch.tensor([expert_indices[:top_k]] * num_tokens, dtype=torch.int32, device=DEVICE)
    expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)

    # ── BF16 reference forward pass ──
    print("\n Running BF16 reference...")
    ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
    print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
    print(f" First token first 8: {[f'{v:.4f}' for v in ref_output[0, :8].tolist()]}")

    del nvfp4_experts_bf16
    torch.cuda.empty_cache()

    # ── NVFP4 kernel forward pass ──
    print("\n Running NVFP4 kernel...")
    kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
    print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
    print(f" First token first 8: {[f'{v:.4f}' for v in kernel_output[0, :8].tolist()]}")

    # ── Compare ──
    kernel_flat = kernel_output.float().flatten()
    ref_flat = ref_output.float().flatten()
    cosine = torch.nn.functional.cosine_similarity(
        kernel_flat.unsqueeze(0),
        ref_flat.unsqueeze(0),
    ).item()
    mse = (kernel_flat - ref_flat).pow(2).mean().item()
    finite = bool(torch.isfinite(kernel_flat).all()) and bool(torch.isfinite(ref_flat).all())

    print(f"\n{'=' * 70}")
    print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")
    print(f"{'=' * 70}")

    if not finite:
        print(" ❌ FAIL: non-finite values (NaN/Inf) in output")
        sys.exit(1)
    # `cosine >= threshold` is False for NaN, so a NaN cosine fails instead of passing.
    if not cosine >= COSINE_THRESHOLD:
        print(f" ❌ FAIL: cosine {cosine:.6f} < {COSINE_THRESHOLD}")
        sys.exit(1)
    print(f" ✅ PASS: cosine {cosine:.6f} >= {COSINE_THRESHOLD}")


if __name__ == "__main__":
    main()
|