Rewrite pipeline test: compare runner vs reference with real weights, step-by-step
This commit is contained in:
@@ -1,26 +1,17 @@
|
||||
"""Test #2: End-to-end single-layer test with real model weights.
|
||||
"""Pipeline Test: Compare CuTeDSL runner vs reference with real model weights.
|
||||
|
||||
Loads layer 0 from the DeepSeek-V4-Pro-NVFP4 checkpoint, runs one MoE layer
|
||||
through our CuTeDSL runner, and compares against the reference moe_pipeline
|
||||
(which uses the same NVFP4 weights but with dynamic gs).
|
||||
|
||||
This catches issues that the small layertest (3 experts, 8 tokens) misses:
|
||||
- Scale assembly with 48 experts × 8 chunks
|
||||
- Uneven expert assignment
|
||||
- Real activation magnitudes
|
||||
- swiglu_limit clamping
|
||||
- Variable padded expert offsets at scale
|
||||
Loads layer 0 from DeepSeek-V4-Pro-NVFP4, runs both the reference
|
||||
moe_pipeline and our CuTeDSLMoERunner, compares output step by step.
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import math
|
||||
|
||||
# Add paths
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../cutedsl')
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/..')
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../vllm')
|
||||
|
||||
from cutedsl.moe_pipeline import moe_pipeline
|
||||
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
|
||||
|
||||
# ============================================================
|
||||
@@ -28,7 +19,7 @@ from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
|
||||
# ============================================================
|
||||
MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
LAYER_IDX = 0
|
||||
NUM_EXPERTS = 48 # local experts per EP rank
|
||||
NUM_EXPERTS = 48
|
||||
HIDDEN_SIZE = 7168
|
||||
INTERMEDIATE_SIZE = 18432
|
||||
NUM_TOKENS = 64
|
||||
@@ -40,10 +31,6 @@ def load_expert_weights(layer_idx, num_experts):
|
||||
"""Load NVFP4 weights for one layer from the checkpoint."""
|
||||
from safetensors import safe_open
|
||||
|
||||
# Find the layer shard file
|
||||
shard_dir = os.path.join(MODEL_PATH, f"model-0000{layer_idx+1:02d}-of-00010.safetensors")
|
||||
# Try to find the right shard
|
||||
import glob
|
||||
shards = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors")))
|
||||
|
||||
l1_fp4 = []
|
||||
@@ -56,19 +43,14 @@ def load_expert_weights(layer_idx, num_experts):
|
||||
for shard_path in shards:
|
||||
with safe_open(shard_path, framework="pt", device="cpu") as f:
|
||||
for e in range(num_experts):
|
||||
global_e = e # For rank 0, local = global
|
||||
w13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight"
|
||||
sf13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight_scale"
|
||||
gs13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight_scale_2"
|
||||
w2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight"
|
||||
sf2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight_scale"
|
||||
gs2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight_scale_2"
|
||||
|
||||
# L1 (gate+up)
|
||||
w13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight"
|
||||
sf13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight_scale"
|
||||
gs13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight_scale_2"
|
||||
|
||||
# L2 (down)
|
||||
w2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight"
|
||||
sf2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight_scale"
|
||||
gs2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight_scale_2"
|
||||
|
||||
if w13_key in f.keys():
|
||||
if w13_key in f.keys() and len(l1_fp4) <= e:
|
||||
l1_fp4.append(f.get_tensor(w13_key).to(DEVICE))
|
||||
l1_sf.append(f.get_tensor(sf13_key).to(DEVICE))
|
||||
l1_gs.append(f.get_tensor(gs13_key).to(DEVICE))
|
||||
@@ -76,11 +58,11 @@ def load_expert_weights(layer_idx, num_experts):
|
||||
l2_sf.append(f.get_tensor(sf2_key).to(DEVICE))
|
||||
l2_gs.append(f.get_tensor(gs2_key).to(DEVICE))
|
||||
|
||||
if len(l1_fp4) == num_experts:
|
||||
if len(l1_fp4) >= num_experts:
|
||||
break
|
||||
|
||||
if len(l1_fp4) != num_experts:
|
||||
raise RuntimeError(f"Only loaded {len(l1_fp4)}/{num_experts} experts from checkpoint")
|
||||
raise RuntimeError(f"Only loaded {len(l1_fp4)}/{num_experts} experts")
|
||||
|
||||
return {
|
||||
'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
|
||||
@@ -88,28 +70,133 @@ def load_expert_weights(layer_idx, num_experts):
|
||||
}
|
||||
|
||||
|
||||
def run_reference(hidden_states, topk_weights, topk_ids, weights, swiglu_limit=None):
|
||||
"""Reference MoE: per-expert processing with dynamic gs (quantize_to_nvfp4)."""
|
||||
from cutedsl.quantize import quantize_to_nvfp4
|
||||
from cutedsl.gemm import run_nvfp4_grouped_gemm
|
||||
|
||||
num_tokens = hidden_states.shape[0]
|
||||
top_k = topk_ids.shape[1]
|
||||
num_experts = len(weights['l1_fp4'])
|
||||
|
||||
# Sort tokens by expert
|
||||
flat_ids = topk_ids.reshape(-1).cpu().numpy()
|
||||
flat_weights = topk_weights.reshape(-1)
|
||||
token_indices = torch.arange(num_tokens).unsqueeze(1).expand(-1, top_k).reshape(-1)
|
||||
|
||||
sort_idx = torch.argsort(topk_ids.reshape(-1), stable=True)
|
||||
sorted_ids = topk_ids.reshape(-1)[sort_idx]
|
||||
sorted_weights = topk_weights.reshape(-1)[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx]
|
||||
|
||||
# Compute expert offsets
|
||||
expert_id_range = torch.arange(num_experts)
|
||||
tokens_per_expert = (sorted_ids.unsqueeze(1).cpu() == expert_id_range.unsqueeze(0)).sum(dim=0)
|
||||
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
|
||||
for e in range(num_experts):
|
||||
expert_offsets[e + 1] = expert_offsets[e] + tokens_per_expert[e].item()
|
||||
|
||||
num_slots = num_tokens * top_k
|
||||
slot_hidden = hidden_states[sorted_token_ids]
|
||||
|
||||
# Stack weights for GEMM
|
||||
l1_mat_b, l1_scale_b, l1_gsb = _stack_weights(weights['l1_fp4'], weights['l1_sf'], weights['l1_gs'])
|
||||
l2_mat_b, l2_scale_b, l2_gsb = _stack_weights(weights['l2_fp4'], weights['l2_sf'], weights['l2_gs'])
|
||||
|
||||
# L1 with dynamic gs
|
||||
x_fp4, x_sf, gs_val = quantize_to_nvfp4(slot_hidden)
|
||||
l1_gsa = torch.full((num_experts,), gs_val, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
l1_scale_a = _assemble_scales_ref(x_sf, expert_offsets, num_experts)
|
||||
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=l1_scale_b,
|
||||
expert_offsets=expert_offsets[1:],
|
||||
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
|
||||
)
|
||||
|
||||
# SiLU(gate) * up
|
||||
gate = l1_out[:, :INTERMEDIATE_SIZE]
|
||||
up = l1_out[:, INTERMEDIATE_SIZE:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
if swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=swiglu_limit)
|
||||
up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
|
||||
activated = gate_silu * up
|
||||
|
||||
# L2 with dynamic gs
|
||||
l2_x_fp4, l2_x_sf, l2_gs_val = quantize_to_nvfp4(activated)
|
||||
l2_gsa = torch.full((num_experts,), l2_gs_val, dtype=torch.float32, device=DEVICE)
|
||||
l2_scale_a = _assemble_scales_ref(l2_x_sf, expert_offsets, num_experts)
|
||||
|
||||
l2_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=l2_x_fp4, mat_b=l2_mat_b,
|
||||
scale_a=l2_scale_a, scale_b=l2_scale_b,
|
||||
expert_offsets=expert_offsets[1:],
|
||||
global_scale_a=l2_gsa, global_scale_b=l2_gsb,
|
||||
)
|
||||
|
||||
# Scatter-add
|
||||
y = torch.zeros(num_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)
|
||||
weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype)
|
||||
y.scatter_add_(0, sorted_token_ids.unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted_out)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def _stack_weights(fp4_list, sf_list, gs_list):
|
||||
"""Stack expert weights into GEMM format."""
|
||||
mat_b = torch.stack(fp4_list)
|
||||
scale_b = torch.stack(sf_list)
|
||||
gsb = torch.stack(gs_list)
|
||||
return mat_b, scale_b, gsb
|
||||
|
||||
|
||||
def _assemble_scales_ref(x_sf, expert_offsets, num_experts):
|
||||
"""Reference scale assembly using assemble_scales_3d_side from bridge."""
|
||||
from cutedsl.bridge import assemble_scales_3d_side
|
||||
return assemble_scales_3d_side(x_sf, expert_offsets, num_experts)
|
||||
|
||||
|
||||
def main():
|
||||
torch.cuda.set_device(0)
|
||||
torch.manual_seed(42)
|
||||
|
||||
print(f"=== Pipeline Test: Layer {LAYER_IDX}, {NUM_EXPERTS} experts, {NUM_TOKENS} tokens, top_k={TOP_K} ===")
|
||||
print(f" swiglu_limit={SWIGLU_LIMIT}")
|
||||
|
||||
# Load real weights
|
||||
print("Loading weights from checkpoint...")
|
||||
print("\nLoading weights from checkpoint...")
|
||||
weights = load_expert_weights(LAYER_IDX, NUM_EXPERTS)
|
||||
print(f"Loaded {NUM_EXPERTS} experts")
|
||||
for e in range(min(3, NUM_EXPERTS)):
|
||||
print(f" Expert {e}: l1_fp4={weights['l1_fp4'][e].shape} l1_gs={weights['l1_gs'][e].item():.6f} "
|
||||
f"l2_fp4={weights['l2_fp4'][e].shape} l2_gs={weights['l2_gs'][e].item():.6f}")
|
||||
|
||||
# Create runner
|
||||
# Create input
|
||||
hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)
|
||||
|
||||
# Realistic top-k: uneven distribution
|
||||
topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE)
|
||||
for i in range(NUM_TOKENS):
|
||||
experts = torch.randperm(NUM_EXPERTS)[:TOP_K]
|
||||
topk_ids[i] = experts
|
||||
topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K
|
||||
|
||||
# ---- Reference ----
|
||||
print("\n--- Reference (dynamic gs, per-expert scale assembly) ---")
|
||||
with torch.no_grad():
|
||||
ref_out = run_reference(hidden_states, topk_weights, topk_ids, weights, swiglu_limit=SWIGLU_LIMIT)
|
||||
print(f"Reference: amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}")
|
||||
print(f" NaN: {torch.isnan(ref_out).any().item()} Inf: {torch.isinf(ref_out).any().item()}")
|
||||
|
||||
# ---- Runner ----
|
||||
print("\n--- CuTeDSL Runner (warmup gs, full-buffer swizzle) ---")
|
||||
runner = CuTeDSLMoERunner(
|
||||
num_experts=NUM_EXPERTS,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
intermediate_size=INTERMEDIATE_SIZE,
|
||||
max_num_tokens=NUM_TOKENS,
|
||||
top_k=TOP_K,
|
||||
device=DEVICE,
|
||||
num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE,
|
||||
intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS,
|
||||
top_k=TOP_K, device=DEVICE,
|
||||
)
|
||||
|
||||
# Set weights
|
||||
runner.l1_fp4 = weights['l1_fp4']
|
||||
runner.l1_sf = weights['l1_sf']
|
||||
runner.l1_gs = weights['l1_gs']
|
||||
@@ -118,73 +205,40 @@ def main():
|
||||
runner.l2_gs = weights['l2_gs']
|
||||
runner.set_swiglu_limit(SWIGLU_LIMIT)
|
||||
|
||||
# Create input
|
||||
hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)
|
||||
|
||||
# Create top-k assignments (realistic: uneven distribution)
|
||||
topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE)
|
||||
for i in range(NUM_TOKENS):
|
||||
# Each token picks TOP_K random experts
|
||||
experts = torch.randperm(NUM_EXPERTS)[:TOP_K]
|
||||
topk_ids[i] = experts
|
||||
topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K
|
||||
|
||||
# ---- Stage 1: Reference pipeline (dynamic gs) ----
|
||||
print("\n--- Reference pipeline (dynamic gs) ---")
|
||||
with torch.no_grad():
|
||||
ref_out = moe_pipeline(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
l1_fp4=weights['l1_fp4'],
|
||||
l1_sf=weights['l1_sf'],
|
||||
l1_gs=weights['l1_gs'],
|
||||
l2_fp4=weights['l2_fp4'],
|
||||
l2_sf=weights['l2_sf'],
|
||||
l2_gs=weights['l2_gs'],
|
||||
num_experts=NUM_EXPERTS,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
intermediate_size=INTERMEDIATE_SIZE,
|
||||
swiglu_limit=SWIGLU_LIMIT,
|
||||
)
|
||||
print(f"Reference: shape={ref_out.shape} amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}")
|
||||
print(f" NaN: {torch.isnan(ref_out).any().item()} Inf: {torch.isinf(ref_out).any().item()}")
|
||||
|
||||
# ---- Stage 2: Runner with warmup gs ----
|
||||
print("\n--- Runner (warmup gs) ---")
|
||||
with torch.no_grad():
|
||||
# Compute warmup gs
|
||||
runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids)
|
||||
print(f"Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}")
|
||||
|
||||
# Run
|
||||
runner_out = runner.run(hidden_states, topk_weights, topk_ids)
|
||||
print(f"Runner: shape={runner_out.shape} amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}")
|
||||
print(f"Runner: amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}")
|
||||
print(f" NaN: {torch.isnan(runner_out).any().item()} Inf: {torch.isinf(runner_out).any().item()}")
|
||||
|
||||
# ---- Comparison ----
|
||||
print("\n--- Comparison ---")
|
||||
# Overall cosine
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)
|
||||
).item()
|
||||
mse = (ref_out - runner_out).pow(2).mean().item()
|
||||
print(f"Cosine: {cos:.6f} MSE: {mse:.4f}")
|
||||
|
||||
if cos < 0.90:
|
||||
print("\n⚠️ LOW COSINE — investigating per-token differences...")
|
||||
for i in range(min(NUM_TOKENS, 8)):
|
||||
cos_i = torch.nn.functional.cosine_similarity(
|
||||
ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0)
|
||||
).item()
|
||||
print(f" Token {i}: cosine={cos_i:.4f} ref_max={ref_out[i].amax().item():.4f} run_max={runner_out[i].amax().item():.4f}")
|
||||
# Per-token
|
||||
low_cos_tokens = 0
|
||||
for i in range(NUM_TOKENS):
|
||||
cos_i = torch.nn.functional.cosine_similarity(
|
||||
ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0)
|
||||
).item()
|
||||
if cos_i < 0.95:
|
||||
low_cos_tokens += 1
|
||||
if low_cos_tokens <= 5:
|
||||
print(f" Token {i}: cosine={cos_i:.4f} ref_max={ref_out[i].amax().item():.4f} run_max={runner_out[i].amax().item():.4f}")
|
||||
if low_cos_tokens > 5:
|
||||
print(f" ... {low_cos_tokens - 5} more tokens with cosine < 0.95")
|
||||
|
||||
if cos >= 0.98:
|
||||
print(f"\n✅ PASS: cosine {cos:.6f} >= 0.98")
|
||||
elif cos >= 0.90:
|
||||
print(f"\n⚠️ MARGINAL: cosine {cos:.6f} — close but degraded")
|
||||
print(f"\n⚠️ MARGINAL: cosine {cos:.6f}")
|
||||
else:
|
||||
print(f"\n❌ FAIL: cosine {cos:.6f} < 0.90 — significant quality loss")
|
||||
print(f"\n❌ FAIL: cosine {cos:.6f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user