Files
nvfp4-megamoe-kernel/tests/test_multilayer.py
2026-05-17 22:58:27 +00:00

159 lines
8.0 KiB
Python

"""Extended pipeline test: simulate multi-layer MoE to check for error accumulation.
Uses same config as vLLM: max_num_tokens=8192, max_chunks=8, 48 experts."""
import torch
import torch.nn.functional as F
import sys, os, glob
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
LAYER_IDX = 0 # Use layer 0 weights for all layers (just testing accumulation)
NUM_EXPERTS = 48
HIDDEN_SIZE = 7168
INTERMEDIATE_SIZE = 3072
NUM_TOKENS = 5 # "The capital of France is"
TOP_K = 6
SWIGLU_LIMIT = 10.0
DEVICE = "cuda"
NUM_LAYERS = 3 # Test error accumulation over multiple layers
def load_layer_tensors(model_dir, layer_idx):
    """Load the MoE expert tensors of one decoder layer from a safetensors checkpoint.

    Scans every ``*.safetensors`` shard in ``model_dir`` and keeps only the
    tensors whose key contains ``"layers.{layer_idx}."`` and ``"mlp.experts"``.
    Shards are opened lazily with ``safe_open``, so only the matching tensors
    are read. Loading whole shards of a large checkpoint just to filter them
    would pull every tensor into host memory.

    Args:
        model_dir: Directory holding the checkpoint shards.
        layer_idx: Decoder-layer index to extract.

    Returns:
        Dict mapping each key (with a leading ``"model."`` stripped) to a CPU
        tensor. The dict is empty if ``model_dir`` has no shards or nothing matches.
    """
    # Sorted so the result does not depend on filesystem listing order.
    shard_paths = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    tensors = {}
    if not shard_paths:
        return tensors
    # Imported lazily so importing this module does not require safetensors.
    from safetensors import safe_open

    layer_tag = f"layers.{layer_idx}."
    for shard_path in shard_paths:
        with safe_open(shard_path, framework="pt", device="cpu") as shard:
            for key in shard.keys():
                if layer_tag in key and "mlp.experts" in key:
                    tensors[key.removeprefix("model.")] = shard.get_tensor(key)
    return tensors
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
    """Dequantize a packed NVFP4 weight matrix to BF16.

    Each byte of ``packed_uint8`` holds two FP4 (E2M1) codes. The low nibble
    becomes the even output column and the high nibble the odd one. Each
    output value is ``fp4_value * block_scale * global_scale``, where every
    block scale covers ``K // K_sf`` consecutive columns of its row.

    Args:
        packed_uint8: (N, K/2) uint8 tensor of packed FP4 codes.
        scale_e4m3: (N, K_sf) per-block scales, in any dtype supporting
            ``.float()`` (e.g. float8_e4m3fn). K must be a multiple of K_sf.
        global_scale: Per-tensor scale multiplied into every value.

    Returns:
        (N, K) BF16 tensor on ``packed_uint8``'s device.

    Raises:
        ValueError: If ``scale_e4m3`` is not 2-D with N rows and a column count dividing K.
    """
    n_rows, k = packed_uint8.shape[0], packed_uint8.shape[1] * 2
    # Without this check, a mismatched scale either fails with an opaque broadcast
    # error or (for a single-row scale) silently broadcasts across all rows.
    if (
        scale_e4m3.dim() != 2
        or scale_e4m3.shape[0] != n_rows
        or scale_e4m3.shape[1] == 0
        or k % scale_e4m3.shape[1] != 0
    ):
        raise ValueError(
            f"scale shape {tuple(scale_e4m3.shape)} does not tile a ({n_rows}, {k}) NVFP4 weight"
        )
    # E2M1 code -> value; codes 8..15 are the sign-flipped copies of 0..7.
    lut = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.],
                       dtype=torch.float32, device=packed_uint8.device)
    low = lut[(packed_uint8 & 0x0F).long()]
    high = lut[((packed_uint8 >> 4) & 0x0F).long()]
    # Interleave (low, high) per byte to restore the original column order; kept in fp32.
    values = torch.stack([low, high], dim=-1).reshape(n_rows, k)
    block = k // scale_e4m3.shape[1]
    scales = scale_e4m3.float().repeat_interleave(block, dim=1)
    return (values * scales * global_scale).to(torch.bfloat16)
def _expert_key(expert, proj, suffix="weight"):
    """Return the checkpoint key of ``proj``'s ``suffix`` tensor for ``expert`` in layer LAYER_IDX."""
    return f"layers.{LAYER_IDX}.mlp.experts.{expert}.{proj}.{suffix}"


def _pack_runner_weights(nvfp4_tensors, expert_indices):
    """Repack the checkpoint's NVFP4 expert tensors into the runner's per-expert lists.

    Returns ``(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)``. L1 is the fused
    [gate; up] projection and L2 is the down projection. Packed weights and
    block scales are transposed from the checkpoint's (N, K/2) / (N, K/16)
    layout to (K/2, N) / (K/16, N). Each expert gets one Python-float global
    scale per projection.
    """
    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []
    for e in expert_indices:
        gate_w = nvfp4_tensors[_expert_key(e, "gate_proj")].to(DEVICE)
        up_w = nvfp4_tensors[_expert_key(e, "up_proj")].to(DEVICE)
        gate_sf = nvfp4_tensors[_expert_key(e, "gate_proj", "weight_scale")].to(DEVICE)
        up_sf = nvfp4_tensors[_expert_key(e, "up_proj", "weight_scale")].to(DEVICE)
        gate_gs = nvfp4_tensors[_expert_key(e, "gate_proj", "weight_scale_2")].item()
        up_gs = nvfp4_tensors[_expert_key(e, "up_proj", "weight_scale_2")].item()

        fused_w = torch.cat([gate_w, up_w], dim=0).view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        fused_sf = torch.cat([gate_sf, up_sf], dim=0).permute(1, 0).contiguous()
        # The fused L1 weight carries a single global scale. When gate and up
        # differ, fold each half's ratio to the shared max into its block
        # scales so dequantized values are preserved (up to E4M3 re-rounding).
        fused_gs = max(gate_gs, up_gs)
        if gate_gs != up_gs:
            sf32 = fused_sf.float()
            sf32[:, :INTERMEDIATE_SIZE] *= gate_gs / fused_gs
            sf32[:, INTERMEDIATE_SIZE:] *= up_gs / fused_gs
            fused_sf = sf32.to(torch.float8_e4m3fn)
        l1_fp4.append(fused_w)
        l1_sf.append(fused_sf)
        l1_gs.append(fused_gs)

        down_key = _expert_key(e, "down_proj")
        if down_key in nvfp4_tensors:
            down_w = nvfp4_tensors[down_key].to(DEVICE)
            down_sf = nvfp4_tensors[_expert_key(e, "down_proj", "weight_scale")].to(DEVICE)
            down_gs = nvfp4_tensors[_expert_key(e, "down_proj", "weight_scale_2")].item()
            l2_fp4.append(down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
            l2_sf.append(down_sf.permute(1, 0).contiguous())
            l2_gs.append(down_gs)
        else:
            # No down_proj in the checkpoint: zero FP4 weights make this expert contribute nothing.
            l2_fp4.append(torch.zeros(INTERMEDIATE_SIZE // 2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
            l2_sf.append(torch.ones(INTERMEDIATE_SIZE // 16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE))
            l2_gs.append(1.0)
    return l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs


def _dequantize_expert_proj(nvfp4_tensors, expert, proj):
    """Dequantize one NVFP4 projection (``proj``) of ``expert`` to a BF16 tensor on DEVICE."""
    return dequantize_nvfp4_weight(
        nvfp4_tensors[_expert_key(expert, proj)].to(DEVICE),
        nvfp4_tensors[_expert_key(expert, proj, "weight_scale")].to(DEVICE),
        nvfp4_tensors[_expert_key(expert, proj, "weight_scale_2")].item(),
    )


def _dequantize_reference_experts(nvfp4_tensors, expert_indices, topk_ids):
    """Dequantize the gate/up/down weights of every expert that ``topk_ids`` routes to.

    Returns ``{position: (gate, up, down)}`` with BF16 tensors on DEVICE.
    Experts are keyed by their position in ``expert_indices``, which is the
    index space of ``topk_ids`` and of the runner's per-expert lists. Experts
    without a down_proj are omitted; the runner gets zero placeholders for
    them, so they contribute nothing on either side.
    """
    routed = set(topk_ids.unique().tolist())
    ref_experts = {}
    for pos, e in enumerate(expert_indices):
        if pos not in routed or _expert_key(e, "down_proj") not in nvfp4_tensors:
            continue
        ref_experts[pos] = tuple(
            _dequantize_expert_proj(nvfp4_tensors, e, proj)
            for proj in ("gate_proj", "up_proj", "down_proj")
        )
    return ref_experts


def _reference_moe(ref_experts, x, topk_weights, topk_ids):
    """BF16 reference MoE forward for hidden states ``x`` (tokens, HIDDEN_SIZE).

    Expert matmuls run in BF16. The routing-weighted expert outputs are summed
    in FP32 and rounded to BF16 once, rather than after every addition.
    """
    out = torch.zeros(x.shape[0], x.shape[1], dtype=torch.float32, device=x.device)
    for pos, (gate_w, up_w, down_w) in ref_experts.items():
        tok, slot = (topk_ids == pos).nonzero(as_tuple=True)
        if tok.numel() == 0:
            continue
        xs = x[tok]
        gate = xs @ gate_w.T
        up = xs @ up_w.T
        gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT)
        up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT)
        expert_out = (gate_silu * up) @ down_w.T
        weights = topk_weights[tok, slot].float().unsqueeze(1)
        out.index_add_(0, tok, expert_out.float() * weights)
    return out.to(torch.bfloat16)


def _cosine(a, b):
    """Cosine similarity of two tensors flattened to vectors, computed in FP32."""
    return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()


def main():
    """Compare CuTeDSLMoERunner against a BF16 reference over NUM_LAYERS stacked MoE layers.

    Every layer reuses LAYER_IDX's expert weights and the same random routing
    (seed 42). Only the hidden state evolves: each layer's input is the
    previous layer's MoE output plus the previous layer's input (residual).
    For each layer this prints the cosine similarity of the two MoE outputs
    and of the two hidden states, so error accumulation is visible. It stops
    at the first NaN/Inf in the runner's hidden state.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    print(f"=== Multi-Layer Pipeline Test ({NUM_LAYERS} layers) ===")
    nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX)
    expert_indices = list(range(NUM_EXPERTS))

    # Start with random hidden states (like after embedding + first attention)
    hidden = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE)
    for i in range(NUM_TOKENS):
        topk_ids[i] = torch.randperm(NUM_EXPERTS)[:TOP_K]
    topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K

    # Setup runner
    from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
    # NOTE(review): these names are unused here; the import is kept in case it has side effects — confirm and drop.
    from cutedsl.bridge import assemble_scales_3d_side, make_b_k_major  # noqa: F401

    l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs = _pack_runner_weights(nvfp4_tensors, expert_indices)
    runner = CuTeDSLMoERunner(
        num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS,
        top_k=TOP_K, device=DEVICE,
    )
    runner.l1_fp4 = l1_fp4
    runner.l1_sf = l1_sf
    runner.l1_gs = l1_gs
    runner.l2_fp4 = l2_fp4
    runner.l2_sf = l2_sf
    runner.l2_gs = l2_gs
    runner.set_swiglu_limit(SWIGLU_LIMIT)

    # Weights and routing are identical in every layer, so dequantize the
    # reference experts once (only the routed ones) instead of once per layer.
    ref_experts = _dequantize_reference_experts(nvfp4_tensors, expert_indices, topk_ids)

    # Warmup
    with torch.no_grad():
        runner.compute_activation_global_scales(hidden, topk_weights, topk_ids)

    ref_hidden = hidden.clone()
    run_hidden = hidden.clone()
    for layer in range(NUM_LAYERS):
        with torch.no_grad():
            # Runner
            run_hidden_saved = run_hidden.clone()
            runner.compute_activation_global_scales(run_hidden, topk_weights, topk_ids)
            run_out = runner.run(run_hidden, topk_weights, topk_ids)
            # Residual connection: MoE output + layer input. The old code summed
            # run_hidden with its own copy, which discarded run_out entirely.
            run_hidden = run_hidden_saved + run_out.to(run_hidden_saved.dtype)

            # BF16 reference
            ref_out = _reference_moe(ref_experts, ref_hidden, topk_weights, topk_ids)
            ref_hidden = ref_hidden + ref_out  # Residual

        cos_moe = _cosine(ref_out, run_out)
        cos_hidden = _cosine(ref_hidden, run_hidden)
        has_nan = torch.isnan(run_hidden).any().item()
        has_inf = torch.isinf(run_hidden).any().item()
        moe_scale = run_out.float().abs().mean().item() / max(ref_out.float().abs().mean().item(), 1e-8)
        print(
            f"Layer {layer}: MoE_cosine={cos_moe:.6f} hidden_cosine={cos_hidden:.6f} "
            f"MoE_scale={moe_scale:.4f} ref_moe_amax={ref_out.amax().item():.4f} "
            f"run_moe_amax={run_out.amax().item():.4f} NaN={has_nan} Inf={has_inf}"
        )
        if has_nan or has_inf:
            print(f" ❌ {'NaN' if has_nan else 'Inf'} detected after layer {layer}! Stopping.")
            break
if __name__ == "__main__":
main()