532 lines
22 KiB
Python
532 lines
22 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Layer 0 comparison test: original checkpoint vs NVFP4 checkpoint + our kernel.
|
|
|
|
Loads layer 0 expert weights from both checkpoints, runs the same deterministic
|
|
MoE forward pass, and compares the results. No vLLM, no Docker, no tensor
|
|
parallelism — just raw weights + our GEMM kernel.
|
|
|
|
Usage:
|
|
python3 layertest.py
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import glob
|
|
import torch
|
|
from safetensors import safe_open
|
|
|
|
# ── Constants ──────────────────────────────────────────────────────────
|
|
|
|
# Checkpoint locations and test settings. Each can be overridden through an
# environment variable so the script can run on machines with a different
# layout; the defaults are the original hard-coded values.
ORIG_MODEL_DIR = os.environ.get(
    "LAYERTEST_ORIG_MODEL_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro"
)
NVFP4_MODEL_DIR = os.environ.get(
    "LAYERTEST_NVFP4_MODEL_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
)
LAYER_IDX = int(os.environ.get("LAYERTEST_LAYER_IDX", "0"))
DEVICE = os.environ.get("LAYERTEST_DEVICE", "cuda")

# E2M1 FP4 lookup table (shared by both formats), indexed by the 4-bit code:
# codes 8-15 are the negatives of codes 0-7, i.e. bit 3 is the sign bit.
E2M1_LUT = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)
|
|
|
|
# ── Checkpoint loading ─────────────────────────────────────────────────
|
|
|
|
def find_shards(model_dir):
    """Map every tensor key in a safetensors checkpoint to its shard file.

    Uses ``model.safetensors.index.json`` when present. Otherwise every
    ``*.safetensors`` file in *model_dir* is opened and its keys recorded
    (this covers both single-shard and un-indexed multi-shard checkpoints).

    Args:
        model_dir: Checkpoint directory.

    Returns:
        dict mapping tensor key -> shard path (joined onto *model_dir*).

    Raises:
        FileNotFoundError: if the directory has neither an index file nor any
            ``*.safetensors`` files. Returning ``{}`` here would only surface
            later as a confusing "Missing ..." key error.
    """
    index_path = os.path.join(model_dir, "model.safetensors.index.json")

    if os.path.exists(index_path):
        with open(index_path, encoding="utf-8") as f:
            index = json.load(f)
        return {
            key: os.path.join(model_dir, shard)
            for key, shard in index["weight_map"].items()
        }

    # No index file: scan every shard. Sort so the mapping is deterministic
    # (glob order is filesystem-dependent) if a key appears in several files.
    shard_paths = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(
            f"No model.safetensors.index.json or *.safetensors files in {model_dir!r}"
        )

    key_to_shard = {}
    for sf in shard_paths:
        with safe_open(sf, framework="pt") as f:
            for key in f.keys():
                key_to_shard[key] = sf
    return key_to_shard
|
|
|
|
|
|
def load_layer_tensors(model_dir, layer_idx, prefix_filter=None):
    """Read every tensor that belongs to decoder layer *layer_idx*.

    A key is selected when it starts with ``model.layers.{layer_idx}.`` and,
    if *prefix_filter* is truthy, also contains *prefix_filter* as a
    substring. Each shard file is opened once and all of its selected
    tensors are read in that single pass.

    Returns:
        dict mapping tensor key -> tensor.
    """
    wanted_prefix = f"model.layers.{layer_idx}."

    def _wanted(name):
        if not name.startswith(wanted_prefix):
            return False
        return not prefix_filter or prefix_filter in name

    # Bucket the selected keys by the shard file that stores them.
    keys_by_shard = {}
    for name, shard_path in find_shards(model_dir).items():
        if _wanted(name):
            keys_by_shard.setdefault(shard_path, []).append(name)

    loaded = {}
    for shard_path, names in keys_by_shard.items():
        with safe_open(shard_path, framework="pt") as handle:
            loaded.update((name, handle.get_tensor(name)) for name in names)
    return loaded
|
|
|
|
|
|
def print_layer_keys(tensors, label):
    """Print a banner naming *label* and the tensor count, followed by one
    line per tensor (in key order) showing its dtype and shape."""
    rule = "=" * 70
    print(f"\n{rule}")
    print(f" {label} — {len(tensors)} tensors")
    print(rule)
    for name, tensor in sorted(tensors.items()):
        print(f" {name}: dtype={tensor.dtype} shape={tuple(tensor.shape)}")
|
|
|
|
|
|
# ── Dequantization: Original checkpoint (MXFP4) ───────────────────────
|
|
|
|
def dequantize_mxfp4_weight(packed_uint8, scale_e8m0, block_size=32):
    """Dequantize MXFP4 (E2M1 values + E8M0 block scales) to BF16.

    Original checkpoint format:
        packed_uint8: (out_features, in_features//2) uint8. Each byte holds two
            E2M1 codes: low nibble -> even column, high nibble -> odd column.
        scale_e8m0: (out_features, in_features//block_size) float8_e8m0fnu.
            Raw ``uint8`` exponent bytes are also accepted; they are decoded as
            2**(x - 127), with 0xFF (the E8M0 NaN encoding) mapped to NaN.
            Without this, ``.float()`` on raw bytes would silently return the
            exponent integers themselves as scales.

    Args:
        block_size: number of consecutive input features sharing one scale
            (32 for MXFP4; kept as the default).

    Returns:
        (out_features, in_features) bfloat16 tensor.
    """
    device = packed_uint8.device
    lut = E2M1_LUT.to(device)

    lower = lut[(packed_uint8 & 0x0F).long()]
    upper = lut[((packed_uint8 >> 4) & 0x0F).long()]

    out_features = packed_uint8.shape[0]
    in_features = packed_uint8.shape[1] * 2

    unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device)
    unpacked[:, 0::2] = lower
    unpacked[:, 1::2] = upper

    if scale_e8m0.dtype == torch.uint8:
        # Raw E8M0 bytes: exponent-only, value = 2**(x - 127); 0xFF is NaN.
        raw = scale_e8m0.float()
        scale_f32 = torch.where(
            scale_e8m0 == 0xFF,
            torch.full_like(raw, float("nan")),
            torch.exp2(raw - 127.0),
        )
    else:
        # float8_e8m0fnu -> float32: exponent-only format, represents 2^(x - bias)
        scale_f32 = scale_e8m0.float()

    # Slice guards against a final partial block when in_features is not a
    # multiple of block_size.
    scale_expanded = scale_f32.repeat_interleave(block_size, dim=1)[:, :in_features]

    return (unpacked * scale_expanded).to(torch.bfloat16)
|
|
|
|
|
|
def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices):
    """Dequantize the gate/up/down projections of selected experts (MXFP4 -> BF16).

    Each projection reads ``<proj>.weight`` (packed FP4) and ``<proj>.scale``
    (E8M0 block scales), moves them to DEVICE and dequantizes them.

    Returns:
        dict {expert_id: {proj_name: bf16 tensor}}. Expert 211 may lack
        ``down_proj``; that single gap is tolerated, any other missing
        weight raises KeyError.
    """
    base = f"model.layers.{layer_idx}.mlp.experts"
    result = {}
    for expert_id in expert_indices:
        projections = {}
        for proj in ("gate_proj", "up_proj", "down_proj"):
            weight_key = f"{base}.{expert_id}.{proj}.weight"
            scale_key = f"{base}.{expert_id}.{proj}.scale"

            if weight_key not in orig_tensors:
                # Expert 211 ships without a down_proj; only that gap is allowed.
                if (proj, expert_id) == ("down_proj", 211):
                    continue
                raise KeyError(f"Missing {weight_key}")

            projections[proj] = dequantize_mxfp4_weight(
                orig_tensors[weight_key].to(DEVICE),
                orig_tensors[scale_key].to(DEVICE),
            )
        result[expert_id] = projections
    return result
|
|
|
|
|
|
# ── Dequantization: NVFP4 checkpoint ──────────────────────────────────
|
|
|
|
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
    """Dequantize NVFP4 (E2M1 + E4M3 block scale + float32 global) to BF16.

    NVFP4 checkpoint format:
        packed_uint8: (out_features, in_features//2) uint8, two E2M1 codes per
            byte (low nibble -> even column, high nibble -> odd column)
        scale_e4m3:   (out_features, in_features//16) float8_e4m3fn, one scale
            per 16 consecutive input features
        global_scale: float32 scalar applied to the whole tensor

    value = e2m1 * block_scale * global_scale, returned as bfloat16.
    """
    codes_lo = packed_uint8 & 0x0F
    codes_hi = (packed_uint8 >> 4) & 0x0F
    table = E2M1_LUT.to(packed_uint8.device)

    rows, half_cols = packed_uint8.shape
    cols = 2 * half_cols

    values = torch.empty(rows, cols, dtype=torch.float32, device=packed_uint8.device)
    values[:, 0::2] = table[codes_lo.long()]
    values[:, 1::2] = table[codes_hi.long()]

    # Broadcast each block scale over its 16 columns (trim a partial tail).
    per_element_scale = scale_e4m3.float().repeat_interleave(16, dim=1)[:, :cols]

    return (values * per_element_scale * global_scale).to(torch.bfloat16)
|
|
|
|
|
|
def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
    """Dequantize the gate/up/down projections of selected experts (NVFP4 -> BF16).

    Each projection reads ``weight`` (packed FP4), ``weight_scale`` (block
    scales, moved to DEVICE) and ``weight_scale_2`` (read as a Python scalar).

    Returns:
        dict {expert_id: {proj_name: bf16 tensor}}. A missing down_proj is
        tolerated only for expert 211; any other missing weight raises KeyError.
    """
    def _key(expert_id, proj, suffix):
        return f"model.layers.{layer_idx}.mlp.experts.{expert_id}.{proj}.{suffix}"

    dequantized = {}
    for expert_id in expert_indices:
        per_proj = {}
        for proj in ("gate_proj", "up_proj", "down_proj"):
            weight_key = _key(expert_id, proj, "weight")
            if weight_key not in nvfp4_tensors:
                if proj == "down_proj" and expert_id == 211:
                    continue
                raise KeyError(f"Missing {weight_key}")

            packed = nvfp4_tensors[weight_key].to(DEVICE)
            block_scale = nvfp4_tensors[_key(expert_id, proj, "weight_scale")].to(DEVICE)
            tensor_scale = nvfp4_tensors[_key(expert_id, proj, "weight_scale_2")].item()
            per_proj[proj] = dequantize_nvfp4_weight(packed, block_scale, tensor_scale)
        dequantized[expert_id] = per_proj
    return dequantized
|
|
|
|
|
|
# ── MoE Forward Pass (BF16 reference) ─────────────────────────────────
|
|
|
|
def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
    """Run a reference MoE forward pass with BF16 weights.

    For every routed (token, slot) pair the expert computes
    ``down_proj(silu(gate_proj(x)) * up_proj(x))`` in BF16, and the result is
    scaled by the routing weight. Contributions are accumulated in float32 and
    cast to BF16 once at the end, so the reference does not accumulate extra
    rounding from repeated BF16 in-place adds.

    Args:
        hidden_states: (num_tokens, hidden_size) BF16
        experts: dict {expert_id: {gate_proj, up_proj[, down_proj]}} BF16,
            each weight laid out as (out_features, in_features)
        expert_ids: (num_tokens, top_k) int — which expert per token per slot.
            Ids not present in *experts* are skipped (contribute zero).
        expert_weights: (num_tokens, top_k) float — routing weights

    Returns:
        output: (num_tokens, hidden_size) BF16 on hidden_states' device
    """
    num_tokens, hidden_size = hidden_states.shape
    acc = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device=hidden_states.device)

    for expert_id, weights in experts.items():
        if "down_proj" not in weights:
            # An expert without a down projection contributes zero. This
            # matches the all-zero down_proj substitute used on the kernel
            # path. The old fallback, activated[:hidden_size], has length
            # intermediate (< hidden_size) and crashed on the add.
            continue

        # All (token, slot) pairs routed to this expert.
        token_idx, slot_idx = (expert_ids == expert_id).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue

        x = hidden_states[token_idx]                      # (n, hidden_size)
        gate = x @ weights["gate_proj"].T                 # (n, intermediate)
        up = x @ weights["up_proj"].T                     # (n, intermediate)
        activated = torch.nn.functional.silu(gate) * up   # (n, intermediate)
        y = activated @ weights["down_proj"].T            # (n, hidden_size)

        route_w = expert_weights[token_idx, slot_idx].to(torch.float32).unsqueeze(1)
        # index_add_ correctly sums repeated token indices (same expert in two slots).
        acc.index_add_(0, token_idx, y.float() * route_w)

    return acc.to(torch.bfloat16)
|
|
|
|
|
|
# ── MoE Forward Pass (NVFP4 kernel) ───────────────────────────────────
|
|
|
|
def _nvfp4_expert_key(layer_idx, expert_id, proj, suffix):
    """Return the NVFP4 checkpoint key for one expert projection tensor."""
    return f"model.layers.{layer_idx}.mlp.experts.{expert_id}.{proj}.{suffix}"


def _load_nvfp4_projection(nvfp4_tensors, layer_idx, expert_id, proj):
    """Load (packed int8 weight, block scale, global scale float) for one projection."""
    weight = nvfp4_tensors[_nvfp4_expert_key(layer_idx, expert_id, proj, "weight")]
    scale = nvfp4_tensors[_nvfp4_expert_key(layer_idx, expert_id, proj, "weight_scale")]
    global_scale = nvfp4_tensors[_nvfp4_expert_key(layer_idx, expert_id, proj, "weight_scale_2")]
    return weight.view(torch.int8).to(DEVICE), scale.to(DEVICE), global_scale.item()


def _stack_nvfp4_experts(nvfp4_tensors, layer_idx, unique_experts, hidden_size):
    """Stack raw NVFP4 expert weights; row i along dim 0 is unique_experts[i].

    Returns:
        ((l1_w, l1_sf, l1_gs), (l2_w, l2_sf, l2_gs), intermediate_half), where
        L1 is gate_proj and up_proj concatenated along the output dim (gs has
        shape (E, 2)), L2 is down_proj (gs has shape (E, 1)), and
        intermediate_half is the output width of a single gate/up projection.
    """
    l1_w, l1_sf, l1_gs = [], [], []
    l2_w, l2_sf, l2_gs = [], [], []
    intermediate_half = None

    for e in unique_experts:
        gate_w, gate_sf, gate_gs = _load_nvfp4_projection(nvfp4_tensors, layer_idx, e, "gate_proj")
        up_w, up_sf, up_gs = _load_nvfp4_projection(nvfp4_tensors, layer_idx, e, "up_proj")
        # Rows are output features, so gate_proj's row count is the
        # per-projection intermediate size (3072 for this model).
        intermediate_half = gate_w.shape[0]

        # Fuse gate + up by stacking along dim 0 -> (2*intermediate_half, hidden//2)
        l1_w.append(torch.cat([gate_w, up_w], dim=0))
        l1_sf.append(torch.cat([gate_sf, up_sf], dim=0))
        l1_gs.append(torch.tensor([gate_gs, up_gs], dtype=torch.float32, device=DEVICE))

        down_key = _nvfp4_expert_key(layer_idx, e, "down_proj", "weight")
        if down_key in nvfp4_tensors:
            down_w, down_sf, down_gs = _load_nvfp4_projection(nvfp4_tensors, layer_idx, e, "down_proj")
        else:
            # Expert 211 has no down_proj: use all-zero FP4 codes so it
            # contributes nothing. Packed FP4 stores two values per byte, so
            # the width must be intermediate_half // 2 to stack with real
            # down_proj weights. The block scale has one entry per 16 inputs.
            down_w = torch.zeros(
                hidden_size, intermediate_half // 2, dtype=torch.int8, device=DEVICE
            )
            down_sf = torch.ones(
                hidden_size, intermediate_half // 16, dtype=torch.float32, device=DEVICE
            ).to(torch.float8_e4m3fn)
            down_gs = 1.0

        l2_w.append(down_w)
        l2_sf.append(down_sf)
        l2_gs.append(torch.tensor([down_gs], dtype=torch.float32, device=DEVICE))

    return (
        (torch.stack(l1_w), torch.stack(l1_sf), torch.stack(l1_gs)),
        (torch.stack(l2_w), torch.stack(l2_sf), torch.stack(l2_gs)),
        intermediate_half,
    )


def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights):
    """Run MoE forward pass using our NVFP4 kernel.

    Loads weights directly from the NVFP4 checkpoint (no vLLM), transforms them
    for CUTLASS, and runs the grouped GEMM. Only experts that appear in
    *expert_ids* are loaded. They are stacked in ascending id order, and the
    routing table given to the kernel is remapped to those local positions.

    Args:
        hidden_states: (num_tokens, hidden_size) BF16 activations.
        nvfp4_tensors: checkpoint tensors keyed by name.
        layer_idx: decoder layer whose experts are used.
        expert_ids: (num_tokens, top_k) global expert ids.
        expert_weights: (num_tokens, top_k) routing weights.

    Returns:
        (num_tokens, hidden_size) BF16 kernel output.
    """
    from nvfp4_megamoe_kernel import (
        get_symm_buffer_for_nvfp4_mega_moe,
        nvfp4_mega_moe_full,
        stage_activation,
        transform_nvfp4_weights_for_mega_moe,
    )

    num_tokens, hidden_size = hidden_states.shape
    top_k = expert_ids.shape[1]

    # Collect the experts we need; position in this list is the local index.
    unique_experts = sorted(set(expert_ids.flatten().tolist()))
    expert_map = {e: i for i, e in enumerate(unique_experts)}

    (l1_w, l1_sf, l1_gs), (l2_w, l2_sf, l2_gs), intermediate_half = _stack_nvfp4_experts(
        nvfp4_tensors, layer_idx, unique_experts, hidden_size
    )

    # Transform weights for CUTLASS
    (l1_w, l1_sf, l1_global_sf), (l2_w, l2_sf, l2_global_sf) = \
        transform_nvfp4_weights_for_mega_moe(
            (l1_w, l1_sf),
            (l2_w, l2_sf),
            l1_weight_scale_2=l1_gs,
            l2_weight_scale_2=l2_gs,
        )

    symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
        group=None,  # no EP
        num_experts=len(unique_experts),
        max_num_tokens=num_tokens,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=2 * intermediate_half,  # fused gate + up
    )

    # Stage activation
    x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
    symm_buffer.x[:num_tokens].copy_(x_fp4)
    symm_buffer.x_sf[:num_tokens].copy_(x_sf)
    symm_buffer.input_global_scale = input_global_scale

    # Weights are stacked by position in unique_experts (experts_start_idx=0),
    # so each (token, slot) id is remapped to that local index. Every slot
    # keeps its own expert; broadcasting slot 0 to all slots would send each
    # token to its first expert only.
    local_ids = torch.tensor(
        [[expert_map[e] for e in row] for row in expert_ids.tolist()],
        dtype=expert_ids.dtype,
        device=expert_ids.device,
    )
    symm_buffer.topk_idx[:num_tokens].copy_(local_ids)
    symm_buffer.topk_weights[:num_tokens].copy_(expert_weights)
    symm_buffer.experts_start_idx = 0

    # Run
    y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE)
    nvfp4_mega_moe_full(
        y,
        (l1_w, l1_sf, l1_global_sf),
        (l2_w, l2_sf, l2_global_sf),
        symm_buffer,
    )
    return y
|
|
|
|
|
|
# ── Main ───────────────────────────────────────────────────────────────
|
|
|
|
def _cosine(a, b):
    """Cosine similarity of two tensors, flattened and compared in float32."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


def main():
    """Compare original-MXFP4, NVFP4-dequantized and NVFP4-kernel MoE outputs.

    Steps:
      1. Load and dequantize the original checkpoint's experts.
      2. Run the BF16 reference on those weights.
      3. Load and dequantize the NVFP4 checkpoint's experts.
      4. Compare the two sets of dequantized weights.
      5. Run the BF16 reference on the NVFP4 weights.
      6. Run the NVFP4 kernel.
    Finally, print a summary with a verdict.
    """
    torch.manual_seed(42)

    expert_indices = [0, 1, 2]  # Test with 3 experts
    top_k = 2
    num_tokens = 4

    # ── Step 1: Load original checkpoint layer 0 ──
    print("\n" + "="*70)
    print(" STEP 1: Loading original MXFP4 checkpoint")
    print("="*70)

    orig_tensors = load_layer_tensors(ORIG_MODEL_DIR, LAYER_IDX)
    print_layer_keys(orig_tensors, "Original checkpoint (MXFP4)")

    # Dequantize to BF16
    print("\nDequantizing MXFP4 → BF16...")
    orig_experts_bf16 = dequantize_mxfp4_experts(orig_tensors, LAYER_IDX, expert_indices)
    for e in expert_indices:
        for proj, w in orig_experts_bf16[e].items():
            print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
    # Packed checkpoint tensors are not needed once dequantized. The BF16
    # weights are kept until the weight comparison in step 4.
    del orig_tensors

    # ── Step 2: Run BF16 reference forward pass ──
    print("\n" + "="*70)
    print(" STEP 2: BF16 reference forward pass")
    print("="*70)

    # gate_proj's input width is the model hidden size (7168 for this model).
    hidden_size = orig_experts_bf16[expert_indices[0]]["gate_proj"].shape[1]
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0

    # Deterministic routing that rotates through every loaded expert, so each
    # slot sees different experts across tokens. Constant [0, 1] routing left
    # expert 2 unused and cannot expose slot-mixing bugs.
    expert_ids = torch.tensor(
        [
            [expert_indices[(t + k) % len(expert_indices)] for k in range(top_k)]
            for t in range(num_tokens)
        ],
        dtype=torch.int32,
        device=DEVICE,
    )
    expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)

    ref_output = moe_forward_bf16(hidden_states, orig_experts_bf16, expert_ids, expert_weights)
    print(f" Reference output: shape={tuple(ref_output.shape)} amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
    print(f" First token first 10: {ref_output[0, :10].tolist()}")

    # ── Step 3: Load NVFP4 checkpoint layer 0 ──
    print("\n" + "="*70)
    print(" STEP 3: Loading NVFP4 checkpoint")
    print("="*70)

    nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX)
    print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint")

    # Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu)
    for e in expert_indices[:1]:
        for proj in ["gate_proj", "up_proj", "down_proj"]:
            key = f"model.layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale"
            if key in nvfp4_tensors:
                dt = nvfp4_tensors[key].dtype
                print(f" {proj}.weight_scale dtype = {dt} {'✓ E4M3' if dt == torch.float8_e4m3fn else '✗ WRONG (expected float8_e4m3fn)'}")

    # Dequantize NVFP4 → BF16 (for BF16 reference comparison)
    print("\nDequantizing NVFP4 → BF16...")
    nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices)
    for e in expert_indices:
        for proj, w in nvfp4_experts_bf16[e].items():
            print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")

    # ── Step 4: Compare dequantized weights ──
    print("\n" + "="*70)
    print(" STEP 4: Weight comparison (original dequant vs NVFP4 dequant)")
    print("="*70)

    # MXFP4 (E8M0, block=32) and NVFP4 (E4M3, block=16) were quantized
    # independently, so values differ. A cosine near 1.0 confirms both
    # checkpoints (and both dequant paths) describe the same weights.
    print(" (MXFP4 and NVFP4 were independently quantized — weight values will differ)")
    print(" (Expect cosine near 1.0 if both were quantized from the same BF16 weights)")
    for e in expert_indices:
        for proj, w_nv in nvfp4_experts_bf16[e].items():
            w_orig = orig_experts_bf16[e].get(proj)
            if w_orig is None or w_orig.shape != w_nv.shape:
                print(f" Expert {e} {proj}: not comparable (missing or shape mismatch)")
                continue
            print(f" Expert {e} {proj}: cosine(orig, NVFP4)={_cosine(w_orig, w_nv):.6f}")

    del orig_experts_bf16  # Free memory
    torch.cuda.empty_cache()

    # ── Step 5: Run NVFP4 BF16 reference (using NVFP4-dequantized weights) ──
    print("\n" + "="*70)
    print(" STEP 5: NVFP4 BF16 reference forward pass")
    print("="*70)

    nvfp4_ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
    print(f" NVFP4 BF16 ref: shape={tuple(nvfp4_ref_output.shape)} amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}")
    print(f" First token first 10: {nvfp4_ref_output[0, :10].tolist()}")

    # Compare against original dequant
    cos_orig_vs_nvfp4bf16 = _cosine(ref_output, nvfp4_ref_output)
    print(f" Cosine (orig BF16 ref vs NVFP4 BF16 ref): {cos_orig_vs_nvfp4bf16:.6f}")

    # ── Step 6: Run our NVFP4 kernel ──
    print("\n" + "="*70)
    print(" STEP 6: NVFP4 kernel forward pass")
    print("="*70)

    kernel_output = None
    try:
        kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
        print(f" Kernel output: shape={tuple(kernel_output.shape)} amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
        print(f" First token first 10: {kernel_output[0, :10].tolist()}")

        # Compare kernel vs NVFP4 BF16 reference
        cos_kernel_vs_nvfp4bf16 = _cosine(kernel_output, nvfp4_ref_output)
        mse = (kernel_output.float() - nvfp4_ref_output.float()).pow(2).mean().item()
        print(f" Cosine (kernel vs NVFP4 BF16 ref): {cos_kernel_vs_nvfp4bf16:.6f}")
        print(f" MSE (kernel vs NVFP4 BF16 ref): {mse:.6e}")

        # Compare kernel vs original BF16 reference
        cos_kernel_vs_orig = _cosine(kernel_output, ref_output)
        print(f" Cosine (kernel vs orig BF16 ref): {cos_kernel_vs_orig:.6f}")

    except Exception as e:
        # Top-level diagnostic boundary: report and continue to the summary.
        print(f" KERNEL FAILED: {e}")
        import traceback
        traceback.print_exc()

    # ── Summary ──
    print("\n" + "="*70)
    print(" SUMMARY")
    print("="*70)
    print(f" Original BF16 reference: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
    print(f" NVFP4 BF16 reference: amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}")
    print(f" Cosine (orig vs NVFP4 BF16): {cos_orig_vs_nvfp4bf16:.6f}")
    if kernel_output is not None:
        cos_k = _cosine(kernel_output, nvfp4_ref_output)
        print(f" Cosine (kernel vs NVFP4 BF16): {cos_k:.6f}")
        if cos_k > 0.99:
            print(f" ✅ Kernel matches BF16 reference — bug is in vLLM integration")
        elif cos_k > 0.9:
            print(f" ⚠️ Kernel is close but not perfect — minor numerical issue")
        else:
            print(f" ❌ Kernel is far from BF16 reference — bug is in the kernel or weight pipeline")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()
|