# File: nvfp4-megamoe-kernel/tests/test_pipeline_real_weights.py
# (file-viewer metadata: 162 lines, 7.0 KiB, Python)

"""Full pipeline test: Fixed runner vs BF16 reference."""
import torch
import torch.nn.functional as F
import sys, os, glob
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
# Directory that is scanned for ``*.safetensors`` checkpoint shards.
MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Index of the layer whose ``mlp.experts`` tensors are loaded and tested.
LAYER_IDX = 0
# Number of experts whose weights are loaded and handed to the runner.
NUM_EXPERTS = 48
# Width of each random input token vector (and of each output row).
HIDDEN_SIZE = 7168
# Per-expert intermediate width; also the column offset that separates the
# gate scales from the up scales in the fused layer-1 scale tensor.
INTERMEDIATE_SIZE = 3072
# Number of random input tokens pushed through both implementations.
NUM_TOKENS = 8
# Number of distinct experts assigned to each token.
TOP_K = 6
# Clamp bound used by the reference activation and passed to
# ``runner.set_swiglu_limit``.
SWIGLU_LIMIT = 10.0
# Device on which inputs, weights and outputs are placed.
DEVICE = "cuda"
# Passed to the runner as ``max_num_tokens``.
MAX_NUM_TOKENS = 8192 # match vLLM config
def load_layer_tensors(model_dir, layer_idx):
    """Collect the MoE expert tensors of one layer from a sharded checkpoint.

    Every ``*.safetensors`` file directly inside ``model_dir`` is scanned and
    a tensor is kept when its name contains both ``layers.{layer_idx}.`` and
    ``mlp.experts``. Only matching tensors are read from disk; the rest of
    each shard is never materialized.

    Args:
        model_dir: Directory holding the ``*.safetensors`` shards.
        layer_idx: Index of the layer whose expert tensors are wanted.

    Returns:
        Dict mapping tensor name (with a leading ``model.`` removed) to the
        CPU tensor. Empty when the directory has no shards or nothing matches.
    """
    # Sorted so that, should two shards carry the same key, the winner does
    # not depend on the filesystem's listing order.
    shard_paths = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    if not shard_paths:
        return {}

    # Imported once, and only when there is something to read.
    from safetensors import safe_open

    layer_tag = f"layers.{layer_idx}."
    tensors = {}
    for shard_path in shard_paths:
        with safe_open(shard_path, framework="pt", device="cpu") as shard:
            for name in shard.keys():
                if layer_tag in name and "mlp.experts" in name:
                    tensors[name.removeprefix("model.")] = shard.get_tensor(name)
    return tensors
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
    """Expand a packed 4-bit weight matrix into bfloat16.

    Each byte of ``packed_uint8`` holds two 4-bit codes; the low nibble comes
    first in the output, the high nibble second. Codes are mapped through a
    16-entry value table, multiplied by a per-block scale and then by
    ``global_scale``.

    Args:
        packed_uint8: 2-D tensor of shape ``(N, K / 2)`` with packed codes.
        scale_e4m3: 2-D tensor of shape ``(N, K_sf)``; every scale covers
            ``K // K_sf`` consecutive output columns.
        global_scale: Scalar multiplier applied to the whole matrix.

    Returns:
        Tensor of shape ``(N, K)`` and dtype ``torch.bfloat16``.
    """
    rows = packed_uint8.shape[0]
    cols = 2 * packed_uint8.shape[1]

    # Value table indexed by the 4-bit code: codes 0-7 are non-negative,
    # codes 8-15 are the same magnitudes negated.
    code_values = torch.tensor(
        [0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.],
        dtype=torch.float32,
        device=packed_uint8.device,
    )

    low_codes = (packed_uint8 & 0x0F).long()
    high_codes = ((packed_uint8 >> 4) & 0x0F).long()

    # Interleave (low, high) pairs so each byte expands to two adjacent columns.
    decoded = torch.stack(
        [code_values[low_codes], code_values[high_codes]], dim=-1
    ).reshape(rows, cols)

    block_width = cols // scale_e4m3.shape[1]
    block_scales = scale_e4m3.float().repeat_interleave(block_width, dim=1)

    scaled = decoded * block_scales * global_scale
    return scaled.to(torch.bfloat16)
def _expert_key(expert_idx, proj, suffix="weight"):
    """Return the checkpoint key of one expert tensor in layer ``LAYER_IDX``."""
    return f"layers.{LAYER_IDX}.mlp.experts.{expert_idx}.{proj}.{suffix}"


def _dequantize_proj(nvfp4_tensors, expert_idx, proj):
    """Dequantize one expert projection (weight, block scale, global scale)."""
    return dequantize_nvfp4_weight(
        nvfp4_tensors[_expert_key(expert_idx, proj)].to(DEVICE),
        nvfp4_tensors[_expert_key(expert_idx, proj, "weight_scale")].to(DEVICE),
        nvfp4_tensors[_expert_key(expert_idx, proj, "weight_scale_2")].item(),
    )


def _bf16_reference(nvfp4_tensors, expert_indices, hidden_states, topk_weights, topk_ids):
    """Compute the MoE output with dequantized bfloat16 weights.

    Experts whose ``down_proj.weight`` is missing from ``nvfp4_tensors`` are
    skipped, as are experts that no token is routed to.
    """
    ref_out = torch.zeros(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)

    # Copy the routing tables to the host once instead of issuing one
    # ``.item()`` device sync per (expert, token, slot) combination.
    ids_host = topk_ids.tolist()
    weights_host = topk_weights.tolist()
    routed_experts = {expert_id for row in ids_host for expert_id in row}

    for e in expert_indices:
        if e not in routed_experts:
            continue
        if _expert_key(e, "down_proj") not in nvfp4_tensors:
            continue
        gate_bf16 = _dequantize_proj(nvfp4_tensors, e, "gate_proj")
        up_bf16 = _dequantize_proj(nvfp4_tensors, e, "up_proj")
        down_bf16 = _dequantize_proj(nvfp4_tensors, e, "down_proj")

        for t, (row_ids, row_weights) in enumerate(zip(ids_host, weights_host)):
            for expert_id, w in zip(row_ids, row_weights):
                # Compare against the expert id itself, not its list position.
                if expert_id != e:
                    continue
                x = hidden_states[t]
                gate = x @ gate_bf16.T
                up = x @ up_bf16.T
                gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT)
                up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT)
                act = gate_silu * up
                ref_out[t] += w * (act @ down_bf16.T)
    return ref_out


def _build_runner_weights(nvfp4_tensors, expert_indices):
    """Build the per-expert layer-1 (gate+up) and layer-2 (down) weight lists.

    Returns:
        Tuple ``(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)`` of lists with
        one entry per expert: packed weights, block scales and global scale.
    """
    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []
    for e in expert_indices:
        gw = nvfp4_tensors[_expert_key(e, "gate_proj")].to(DEVICE)
        uw = nvfp4_tensors[_expert_key(e, "up_proj")].to(DEVICE)
        gsf = nvfp4_tensors[_expert_key(e, "gate_proj", "weight_scale")].to(DEVICE)
        usf = nvfp4_tensors[_expert_key(e, "up_proj", "weight_scale")].to(DEVICE)
        ggs = nvfp4_tensors[_expert_key(e, "gate_proj", "weight_scale_2")].item()
        ugs = nvfp4_tensors[_expert_key(e, "up_proj", "weight_scale_2")].item()

        # Fuse gate and up along dim 0, then transpose.
        fw = torch.cat([gw, uw], dim=0).view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        fsf = torch.cat([gsf, usf], dim=0).permute(1, 0).contiguous()

        # The fused tensor gets a single global scale, so when gate and up
        # disagree the ratio is folded into the block scales instead.
        # NOTE(review): the column split assumes gate contributes exactly
        # INTERMEDIATE_SIZE columns after the transpose - confirm.
        mgs = max(ggs, ugs)
        if ggs != ugs:
            sf32 = fsf.float()
            sf32[:, :INTERMEDIATE_SIZE] *= (ggs / mgs)
            sf32[:, INTERMEDIATE_SIZE:] *= (ugs / mgs)
            fsf = sf32.to(torch.float8_e4m3fn)
        l1_fp4.append(fw)
        l1_sf.append(fsf)
        l1_gs.append(mgs)

        if _expert_key(e, "down_proj") in nvfp4_tensors:
            dw = nvfp4_tensors[_expert_key(e, "down_proj")].to(DEVICE)
            dsf = nvfp4_tensors[_expert_key(e, "down_proj", "weight_scale")].to(DEVICE)
            dgs = nvfp4_tensors[_expert_key(e, "down_proj", "weight_scale_2")].item()
            l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
            l2_sf.append(dsf.permute(1, 0).contiguous())
            l2_gs.append(dgs)
        else:
            # Placeholder entry for an expert without a down projection.
            l2_fp4.append(torch.zeros(INTERMEDIATE_SIZE//2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
            l2_sf.append(torch.ones(INTERMEDIATE_SIZE//16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE))
            l2_gs.append(1.0)
    return l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs


def main():
    """Run the CuTeDSL MoE runner against a BF16 reference on real weights.

    Loads the expert tensors of layer ``LAYER_IDX`` from ``MODEL_PATH``,
    routes ``NUM_TOKENS`` random tokens to ``TOP_K`` distinct experts each
    with uniform weights, computes the output both ways and prints the
    overall and per-token cosine similarity plus the MSE.

    Returns:
        ``0`` when the overall cosine similarity is at least 0.98 (PASS),
        ``1`` otherwise (MARGINAL, FAIL, or a NaN similarity).

    Raises:
        FileNotFoundError: If no expert tensors for the layer were found.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    print("=== Full Pipeline Test (Fixed Runner) ===")

    nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX)
    if not nvfp4_tensors:
        # Fail here with a clear message rather than with a KeyError on the
        # first missing expert key further down.
        raise FileNotFoundError(
            f"no mlp.experts tensors for layer {LAYER_IDX} found under {MODEL_PATH}"
        )

    expert_indices = list(range(NUM_EXPERTS))
    hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE)
    for i in range(NUM_TOKENS):
        # randperm guarantees the TOP_K experts of one token are distinct.
        topk_ids[i] = torch.randperm(NUM_EXPERTS)[:TOP_K]
    topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K

    # BF16 reference
    ref_out = _bf16_reference(nvfp4_tensors, expert_indices, hidden_states, topk_weights, topk_ids)
    print(f"BF16 ref: amax={ref_out.amax().item():.4f}")

    # CuTeDSL runner
    from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
    # NOTE(review): these two names are not used below; the import is kept in
    # case the module is needed for its import-time effects - confirm.
    from cutedsl.bridge import assemble_scales_3d_side, make_b_k_major  # noqa: F401

    l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs = _build_runner_weights(
        nvfp4_tensors, expert_indices
    )
    runner = CuTeDSLMoERunner(
        num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=MAX_NUM_TOKENS,
        top_k=TOP_K, device=DEVICE,
    )
    runner.l1_fp4 = l1_fp4
    runner.l1_sf = l1_sf
    runner.l1_gs = l1_gs
    runner.l2_fp4 = l2_fp4
    runner.l2_sf = l2_sf
    runner.l2_gs = l2_gs
    runner.set_swiglu_limit(SWIGLU_LIMIT)
    with torch.no_grad():
        runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids)
        runner_out = runner.run(hidden_states, topk_weights, topk_ids)

    print(f"Runner: amax={runner_out.amax().item():.4f}")
    print(f"NaN: {torch.isnan(runner_out).any().item()}")
    cos = F.cosine_similarity(ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)).item()
    mse = (ref_out - runner_out).pow(2).mean().item()
    print(f"\nCosine: {cos:.6f} MSE: {mse:.6e}")
    for t in range(NUM_TOKENS):
        ct = F.cosine_similarity(ref_out[t].unsqueeze(0), runner_out[t].unsqueeze(0)).item()
        print(f" Token {t}: cosine={ct:.4f}")

    # A NaN similarity fails both comparisons and lands in the FAIL branch.
    if cos >= 0.98:
        print("\n✅ PASS")
        return 0
    if cos >= 0.90:
        print("\n⚠️ MARGINAL")
    else:
        print("\n❌ FAIL")
    return 1
if __name__ == "__main__":
    # Forward main()'s result as the process exit status so a shell or CI job
    # can tell a failing run from a passing one. A ``None`` result exits 0.
    sys.exit(main())