# Listing metadata: 162 lines, 7.0 KiB, Python
"""Full pipeline test: Fixed runner vs BF16 reference."""
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import sys, os, glob
|
|
|
|
# Put the parent of this script's directory first on sys.path (presumably so
# the repo-local packages imported inside main() resolve when run as a script).
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

# Checkpoint directory holding the *.safetensors shards. Override with the
# NVFP4_MODEL_PATH environment variable instead of editing this file; the
# default is the original hard-coded location.
MODEL_PATH = os.environ.get(
    "NVFP4_MODEL_PATH", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
)

LAYER_IDX = 0              # decoder layer whose routed experts are compared
NUM_EXPERTS = 48           # expert ids 0..NUM_EXPERTS-1 are loaded and routed
HIDDEN_SIZE = 7168
INTERMEDIATE_SIZE = 3072   # per-expert width; gate/up halves split at this index
NUM_TOKENS = 8             # random tokens pushed through the layer
TOP_K = 6                  # experts per token (drawn without replacement)
SWIGLU_LIMIT = 10.0        # clamp bound used by the reference activation
DEVICE = "cuda"
MAX_NUM_TOKENS = 8192      # match vLLM config
|
|
|
|
|
|
def load_layer_tensors(model_dir, layer_idx):
    """Collect the routed-expert tensors of one decoder layer from a checkpoint.

    Scans every ``*.safetensors`` shard in *model_dir* (in sorted order) and
    keeps tensors whose key contains both ``layers.{layer_idx}.`` and
    ``mlp.experts``. Shards are opened lazily with ``safe_open`` so only the
    matching tensors are read. Loading each shard in full with ``load_file``
    would pull the entire checkpoint into host memory to extract one layer.

    Args:
        model_dir: Directory containing the ``*.safetensors`` shards.
        layer_idx: Decoder layer whose expert tensors are returned.

    Returns:
        Dict mapping each matching key, with a leading ``"model."`` stripped,
        to a CPU tensor. Empty when there are no shards or no key matches.
    """
    shard_paths = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    if not shard_paths:
        return {}

    # Imported lazily (as before) so this module imports without safetensors.
    from safetensors import safe_open

    layer_tag = f"layers.{layer_idx}."
    tensors = {}
    for shard_path in shard_paths:
        with safe_open(shard_path, framework="pt", device="cpu") as shard:
            for key in shard.keys():
                if layer_tag in key and "mlp.experts" in key:
                    tensors[key.removeprefix("model.")] = shard.get_tensor(key)
    return tensors
|
|
|
|
|
|
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
    """Expand a packed NVFP4 matrix into a bfloat16 matrix.

    Each byte of *packed_uint8* holds two 4-bit E2M1 codes. The low nibble is
    the element at the even column and the high nibble the next odd column.
    Decoded values are multiplied by the block scales in *scale_e4m3* (each
    scale column covers ``K // scale_e4m3.shape[1]`` consecutive columns) and
    then by the scalar *global_scale*.

    Args:
        packed_uint8: (N, K // 2) tensor of packed codes.
        scale_e4m3: (N, K_sf) block scales; converted with ``.float()``. They
            must live on the same device as *packed_uint8*.
        global_scale: Scalar multiplied into every element.

    Returns:
        (N, K) bfloat16 tensor on *packed_uint8*'s device.
    """
    # E2M1 code -> value. Codes 8..15 are the sign-flipped copies of 0..7.
    e2m1_table = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        dtype=torch.float32,
        device=packed_uint8.device,
    )
    num_rows = packed_uint8.shape[0]
    num_cols = 2 * packed_uint8.shape[1]

    # Interleave nibbles so column 2j is the low nibble of byte j and
    # column 2j + 1 is its high nibble, then decode through the table.
    codes = torch.stack(
        (packed_uint8 & 0x0F, (packed_uint8 >> 4) & 0x0F), dim=-1
    ).reshape(num_rows, num_cols)
    decoded = e2m1_table[codes.long()]

    # Broadcast every block scale across the columns it covers.
    cols_per_scale = num_cols // scale_e4m3.shape[1]
    block_scales = scale_e4m3.float().repeat_interleave(cols_per_scale, dim=1)

    return (decoded * block_scales * global_scale).to(torch.bfloat16)
|
|
|
|
|
|
def _expert_key(expert, name):
    """Return the checkpoint key of tensor *name* for expert *expert* in LAYER_IDX."""
    return f"layers.{LAYER_IDX}.mlp.experts.{expert}.{name}"


def _dequantize_projection(tensors, weight_key):
    """Dequantize one NVFP4 projection to bfloat16 on DEVICE.

    Reads the packed weight at *weight_key* together with its companion
    ``.weight_scale`` (block scales) and ``.weight_scale_2`` (scalar) tensors.
    """
    return dequantize_nvfp4_weight(
        tensors[weight_key].to(DEVICE),
        tensors[weight_key.replace(".weight", ".weight_scale")].to(DEVICE),
        tensors[weight_key.replace(".weight", ".weight_scale_2")].item(),
    )


def _reference_moe(tensors, hidden_states, topk_weights, topk_ids, expert_indices):
    """Compute the MoE layer output from dequantized bfloat16 expert weights.

    The per-expert matmuls run in bfloat16, matching the dequantized weights.
    The routing-weighted sum over experts is accumulated in float32 so the
    reference does not pick up extra rounding from repeated bfloat16 ``+=``.
    Experts without a ``down_proj.weight`` in *tensors* are skipped, and
    experts no token routes to are never dequantized.

    Returns:
        float32 tensor of shape (num_tokens, HIDDEN_SIZE).
    """
    ref_out = torch.zeros(
        hidden_states.shape[0], HIDDEN_SIZE, dtype=torch.float32, device=DEVICE
    )
    for e in expert_indices:
        if _expert_key(e, "down_proj.weight") not in tensors:
            continue
        # Match routing slots on the expert id itself, not the loop position.
        rows, slots = (topk_ids == e).nonzero(as_tuple=True)
        if rows.numel() == 0:
            continue

        gate_w = _dequantize_projection(tensors, _expert_key(e, "gate_proj.weight"))
        up_w = _dequantize_projection(tensors, _expert_key(e, "up_proj.weight"))
        down_w = _dequantize_projection(tensors, _expert_key(e, "down_proj.weight"))

        x = hidden_states[rows]
        # NOTE(review): the gate is clamped after SiLU here; confirm this
        # matches the runner's swiglu-limit semantics.
        gate = F.silu(x @ gate_w.T).clamp(max=SWIGLU_LIMIT)
        up = (x @ up_w.T).clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT)
        expert_out = (gate * up) @ down_w.T

        weights = topk_weights[rows, slots].unsqueeze(1)
        ref_out.index_add_(0, rows, weights * expert_out.float())
    return ref_out


def _pack_runner_weights(tensors, expert_indices):
    """Build per-expert FP4 weights, block scales and global scales for the runner.

    Gate and up are concatenated along the output dim and transposed, so
    columns [0, INTERMEDIATE_SIZE) hold the gate half. Experts lacking a
    ``down_proj.weight`` get an all-zero FP4 down projection so the list
    indices stay aligned with expert ids.

    Returns:
        ``(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)`` lists, one entry per expert.
    """
    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []
    for e in expert_indices:
        gw = tensors[_expert_key(e, "gate_proj.weight")].to(DEVICE)
        uw = tensors[_expert_key(e, "up_proj.weight")].to(DEVICE)
        gsf = tensors[_expert_key(e, "gate_proj.weight_scale")].to(DEVICE)
        usf = tensors[_expert_key(e, "up_proj.weight_scale")].to(DEVICE)
        ggs = tensors[_expert_key(e, "gate_proj.weight_scale_2")].item()
        ugs = tensors[_expert_key(e, "up_proj.weight_scale_2")].item()

        fw = (torch.cat([gw, uw], dim=0)
              .view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
        fsf = torch.cat([gsf, usf], dim=0).permute(1, 0).contiguous()

        # The runner takes one global scale per fused gate/up weight. Use the
        # larger one and fold each half's ratio (<= 1) into its block scales,
        # so the rescaled e4m3 values shrink rather than overflow.
        mgs = max(ggs, ugs)
        if ggs != ugs:
            sf32 = fsf.float()
            sf32[:, :INTERMEDIATE_SIZE] *= ggs / mgs
            sf32[:, INTERMEDIATE_SIZE:] *= ugs / mgs
            fsf = sf32.to(torch.float8_e4m3fn)
        l1_fp4.append(fw)
        l1_sf.append(fsf)
        l1_gs.append(mgs)

        dk = _expert_key(e, "down_proj.weight")
        if dk in tensors:
            dw = tensors[dk].to(DEVICE)
            dsf = tensors[_expert_key(e, "down_proj.weight_scale")].to(DEVICE)
            l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
            l2_sf.append(dsf.permute(1, 0).contiguous())
            l2_gs.append(tensors[_expert_key(e, "down_proj.weight_scale_2")].item())
        else:
            # FP4 code 0 decodes to 0.0, so this expert's down projection yields
            # zeros. Built from uint8 and reinterpreted, like the real weights,
            # because torch.zeros may not support float4_e2m1fn_x2 directly.
            zeros_u8 = torch.zeros(
                INTERMEDIATE_SIZE // 2, HIDDEN_SIZE, dtype=torch.uint8, device=DEVICE
            )
            l2_fp4.append(zeros_u8.view(torch.float4_e2m1fn_x2))
            l2_sf.append(torch.ones(
                INTERMEDIATE_SIZE // 16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE
            ))
            l2_gs.append(1.0)
    return l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs


def _report(ref_out, runner_out):
    """Print global and per-token similarity metrics and a PASS/MARGINAL/FAIL verdict.

    Both tensors are cast to float32 first. Computing cosine in bfloat16
    quantizes values near 1 to steps of about 0.004, which makes the 0.98/0.90
    thresholds and the 6-decimal print unreliable.
    """
    ref = ref_out.float()
    out = runner_out.float()

    cos = F.cosine_similarity(ref.flatten().unsqueeze(0), out.flatten().unsqueeze(0)).item()
    mse = (ref - out).pow(2).mean().item()
    print(f"\nCosine: {cos:.6f} MSE: {mse:.6e}")

    per_token = F.cosine_similarity(ref, out, dim=1)
    for t, ct in enumerate(per_token.tolist()):
        print(f" Token {t}: cosine={ct:.4f}")

    if cos >= 0.98:
        print("\n✅ PASS")
    elif cos >= 0.90:
        print("\n⚠️ MARGINAL")
    else:
        print("\n❌ FAIL")


def main():
    """Compare CuTeDSLMoERunner against a dequantized-weight reference on one MoE layer.

    Loads the NVFP4 expert tensors of LAYER_IDX and routes NUM_TOKENS random
    tokens to TOP_K distinct experts each with uniform weights. It then runs
    the reference and the runner and prints similarity metrics.

    Raises:
        FileNotFoundError: no expert tensors for LAYER_IDX were found under MODEL_PATH.
        RuntimeError: the runner output shape differs from the reference shape.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)

    print("=== Full Pipeline Test (Fixed Runner) ===")
    nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX)
    if not nvfp4_tensors:
        raise FileNotFoundError(
            f"no mlp.experts tensors for layer {LAYER_IDX} found under {MODEL_PATH}"
        )
    expert_indices = list(range(NUM_EXPERTS))

    hidden_states = torch.randn(
        NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE
    ) * 2.0
    # One random permutation per token gives TOP_K distinct experts per row.
    topk_ids = torch.stack(
        [torch.randperm(NUM_EXPERTS)[:TOP_K] for _ in range(NUM_TOKENS)]
    ).to(DEVICE)
    topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K

    ref_out = _reference_moe(
        nvfp4_tensors, hidden_states, topk_weights, topk_ids, expert_indices
    )
    print(f"BF16 ref: amax={ref_out.amax().item():.4f}")

    # Project-local import, resolved through the sys.path entry added at module level.
    from vllm.nvfp4_cutedsl import CuTeDSLMoERunner

    l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs = _pack_runner_weights(
        nvfp4_tensors, expert_indices
    )
    runner = CuTeDSLMoERunner(
        num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=MAX_NUM_TOKENS,
        top_k=TOP_K, device=DEVICE,
    )
    runner.l1_fp4 = l1_fp4
    runner.l1_sf = l1_sf
    runner.l1_gs = l1_gs
    runner.l2_fp4 = l2_fp4
    runner.l2_sf = l2_sf
    runner.l2_gs = l2_gs
    runner.set_swiglu_limit(SWIGLU_LIMIT)

    with torch.no_grad():
        runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids)
        runner_out = runner.run(hidden_states, topk_weights, topk_ids)

    print(f"Runner: amax={runner_out.amax().item():.4f}")
    print(f"NaN: {torch.isnan(runner_out).any().item()}")
    print(f"Inf: {torch.isinf(runner_out).any().item()}")

    # The global cosine flattens both tensors; a layout mismatch with equal
    # numel would otherwise produce a meaningless score.
    if runner_out.shape != ref_out.shape:
        raise RuntimeError(
            f"runner output shape {tuple(runner_out.shape)} != "
            f"reference shape {tuple(ref_out.shape)}"
        )
    _report(ref_out, runner_out)
|
|
|
|
|
|
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|