Files
nvfp4-megamoe-kernel/tests/test_shared_expert.py
biondizzle e8b289e30d WIP: CuTeDSL shared expert kernel
Dedicated runner (shared_expert_pipeline.py) and test (test_shared_expert.py).
Tried reusing MoE runner with 1 expert — fails because MoE runner assumes
hidden_size != HC_DIM for scatter. Need dedicated runner with correct
scale assembly. Will continue tomorrow.
2026-05-18 20:02:19 +00:00

159 lines
5.7 KiB
Python

"""Standalone test: Shared expert using CuTeDSL MoE runner with 1 expert.
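
The shared expert is treated as a degenerate MoE layer: one expert, no
routing, top_k=1. The existing CuTeDSLMoERunner is therefore reused with
num_experts=1, and its output is compared against a BF16 reference computed
from the dequantized NVFP4 weights.
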
Usage: python3 test_shared_expert.py
"""
import torch
import torch.nn.functional as F
import sys, os, json
from safetensors import safe_open
# Checkpoint directory and the device every tensor is moved to.
MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEVICE = "cuda:0"

# Index of the decoder layer whose shared expert is exercised.
LAYER_IDX = 0

# Problem sizes. HC_DIM is the width handed to the runner as hidden_size and
# the width of the random activations fed to both the kernel and the reference.
HIDDEN_SIZE = 7168
HC_MULT = 4
HC_DIM = HIDDEN_SIZE * HC_MULT
INTERMEDIATE_SIZE = 3072

# Clamp bound passed to the runner and mirrored in the BF16 reference.
SWIGLU_LIMIT = 10.0

# Number of random token rows pushed through the expert.
NUM_TOKENS = 4

# Decode table for 4-bit codes: entry i is the float value of code i.
# Codes 8..15 are the negated counterparts of codes 0..7.
E2M1_LUT = torch.tensor(
    [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ],
    dtype=torch.float32,
)

# Tensors already read from disk, keyed by tensor name (see load_tensor).
_cache = {}
def load_tensor(key, wm, model_dir):
    """Return the tensor named ``key``, reading it from disk at most once.

    Args:
        key: Tensor name to look up.
        wm: Mapping from tensor name to the shard file name that holds it.
        model_dir: Directory the shard file names are relative to.

    Returns:
        The tensor as returned by the shard's ``get_tensor``. Repeated calls
        with the same ``key`` return the object stored in the module-level
        ``_cache`` (the cache is keyed by ``key`` only, not by ``model_dir``).

    Raises:
        KeyError: If ``key`` is not cached and is missing from ``wm``.
    """
    # EAFP: a hit returns immediately, a miss falls through to the disk read.
    try:
        return _cache[key]
    except KeyError:
        pass

    shard_file = os.path.join(model_dir, wm[key])
    with safe_open(shard_file, framework="pt") as shard:
        tensor = shard.get_tensor(key)
    _cache[key] = tensor
    return tensor
def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale):
    """Decode packed 4-bit codes into a bfloat16 matrix.

    Every byte of ``packed_uint8`` holds two codes: the low nibble becomes the
    even-indexed output column and the high nibble the odd-indexed one. Codes
    are decoded through ``E2M1_LUT``, then multiplied by a per-block scale
    (each scale column is repeated over 16 consecutive output columns) and by
    ``global_scale``.

    Args:
        packed_uint8: 2-D integer tensor of shape (rows, cols // 2).
        scale_e4m3: 2-D tensor of block scales; after expansion it is cropped
            to (rows, cols), so it may be larger than strictly required.
        global_scale: Scalar multiplier applied to every element.

    Returns:
        A (rows, cols) ``torch.bfloat16`` tensor on the device of
        ``packed_uint8``.
    """
    rows = packed_uint8.shape[0]
    cols = 2 * packed_uint8.shape[1]

    table = E2M1_LUT.to(packed_uint8.device)
    low_vals = table[(packed_uint8 & 0x0F).long()]
    high_vals = table[((packed_uint8 >> 4) & 0x0F).long()]

    # Pair (low, high) on a new trailing axis, then flatten that axis into the
    # column axis: columns come out as low0, high0, low1, high1, ...
    values = torch.stack((low_vals, high_vals), dim=-1).reshape(rows, cols)

    scales = scale_e4m3.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * scales * global_scale).to(torch.bfloat16)
def main():
    """Run the shared expert through CuTeDSLMoERunner and compare with BF16.

    The function:
      1. loads the gate/up/down projection weights, block scales and global
         scales of layer ``LAYER_IDX``'s shared expert from ``MODEL_PATH``;
      2. stacks gate and up into one projection, rebasing the block scales
         onto a common global scale when the two global scales differ;
      3. hands the tensors to a one-expert ``CuTeDSLMoERunner`` and runs it on
         random activations with every token routed to expert 0, weight 1;
      4. rebuilds the same computation from dequantized BF16 weights;
      5. prints cosine similarity and MSE between the two outputs and a
         PASS/FAIL line (PASS when cosine >= 0.98).

    Requires a CUDA device, the checkpoint at ``MODEL_PATH`` and the kernel
    checkout at ``/root/nvfp4-megamoe-kernel``.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)

    # The runner is imported from a source checkout, so put it on sys.path
    # before importing.
    sys.path.insert(0, "/root/nvfp4-megamoe-kernel")
    from vllm.nvfp4_cutedsl import CuTeDSLMoERunner

    with open(os.path.join(MODEL_PATH, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def P(key):
        """Load checkpoint tensor ``key`` and move it to ``DEVICE``."""
        return load_tensor(key, wm, MODEL_PATH).to(DEVICE)

    print("=== Shared Expert Test (CuTeDSL MoE runner, 1 expert) ===\n")

    # Load shared expert weights
    prefix = f"model.layers.{LAYER_IDX}.mlp.shared_experts"
    gate_w = P(f"{prefix}.gate_proj.weight")
    gate_sf = P(f"{prefix}.gate_proj.weight_scale")
    gate_gs = P(f"{prefix}.gate_proj.weight_scale_2").item()
    up_w = P(f"{prefix}.up_proj.weight")
    up_sf = P(f"{prefix}.up_proj.weight_scale")
    up_gs = P(f"{prefix}.up_proj.weight_scale_2").item()
    down_w = P(f"{prefix}.down_proj.weight")
    down_sf = P(f"{prefix}.down_proj.weight_scale")
    down_gs = P(f"{prefix}.down_proj.weight_scale_2").item()
    print(f"gate_proj: shape={gate_w.shape} gs={gate_gs:.8f}")
    print(f"up_proj: shape={up_w.shape} gs={up_gs:.8f}")
    print(f"down_proj: shape={down_w.shape} gs={down_gs:.8f}")

    # Stack gate + up into gate_up_proj (same format as MoE L1)
    gate_up_w = torch.cat([gate_w, up_w], dim=0)
    gate_up_sf = torch.cat([gate_sf, up_sf], dim=0)
    mgs = max(gate_gs, up_gs)
    if gate_gs != up_gs:
        # The fused projection carries a single global scale (mgs), so fold
        # each half's own global scale into its block scales. Gate and up are
        # stacked along dim 0, hence the split is by ROW at gate_sf.shape[0];
        # slicing columns would rescale every row with the gate ratio and
        # leave the up ratio unapplied.
        gate_rows = gate_sf.shape[0]
        sf32 = gate_up_sf.float()
        sf32[:gate_rows] *= gate_gs / mgs
        sf32[gate_rows:] *= up_gs / mgs
        gate_up_sf = sf32.to(torch.float8_e4m3fn)

    # Convert to CuTeDSL format: reinterpret the packed bytes as fp4 pairs and
    # swap the two axes of weights and scales.
    l1_fp4 = gate_up_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    l1_sf = gate_up_sf.permute(1, 0).contiguous()
    l2_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    l2_sf = down_sf.permute(1, 0).contiguous()

    # Create MoE runner with 1 expert
    runner = CuTeDSLMoERunner(
        num_experts=1,
        hidden_size=HC_DIM,
        intermediate_size=INTERMEDIATE_SIZE,
        max_num_tokens=8192,
        top_k=1,
        device=DEVICE,
    )
    runner.l1_fp4 = [l1_fp4]
    runner.l1_sf = [l1_sf]
    runner.l1_gs = [mgs]
    runner.l2_fp4 = [l2_fp4]
    runner.l2_sf = [l2_sf]
    runner.l2_gs = [down_gs]
    runner.set_swiglu_limit(SWIGLU_LIMIT)

    # Warmup: let the runner derive its activation global scales from a dummy
    # batch before the measured forward pass.
    dummy = torch.randn(NUM_TOKENS, HC_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0
    dummy_topk_ids = torch.zeros(NUM_TOKENS, 1, dtype=torch.int64, device=DEVICE)
    dummy_topk_weights = torch.ones(NUM_TOKENS, 1, dtype=torch.float32, device=DEVICE)
    runner.compute_activation_global_scales(dummy, dummy_topk_weights, dummy_topk_ids)
    print(f"Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}")

    # Run CuTeDSL: every token goes to expert 0 with routing weight 1.
    print("\n--- CuTeDSL Forward ---")
    hidden = torch.randn(NUM_TOKENS, HC_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.zeros(NUM_TOKENS, 1, dtype=torch.int64, device=DEVICE)
    topk_weights = torch.ones(NUM_TOKENS, 1, dtype=torch.float32, device=DEVICE)
    with torch.no_grad():
        output = runner.run(hidden, topk_weights, topk_ids)
    print(f"CuTeDSL output: amax={output.amax():.4f} NaN={torch.isnan(output).any()}")

    # BF16 reference
    # NOTE(review): the matmuls below need the dequantized gate/up weights to
    # have HC_DIM input features; confirm against the shapes printed above.
    print("\n--- BF16 Reference ---")
    gate_bf16 = dequant_nvfp4(gate_w, gate_sf, gate_gs)
    up_bf16 = dequant_nvfp4(up_w, up_sf, up_gs)
    down_bf16 = dequant_nvfp4(down_w, down_sf, down_gs)
    with torch.no_grad():
        gate = hidden @ gate_bf16.T
        up = hidden @ up_bf16.T
        gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT)
        up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT)
        intermediate = gate_silu * up
        ref_output = intermediate @ down_bf16.T
    print(f"BF16 ref: amax={ref_output.amax():.4f}")

    # Compare in float32: bf16 resolves values near 1.0 only in steps of
    # about 0.004, too coarse for a 0.98 threshold and a 6-decimal printout.
    ref_f32 = ref_output.float()
    out_f32 = output.float()
    cos = F.cosine_similarity(
        ref_f32.flatten().unsqueeze(0), out_f32.flatten().unsqueeze(0)
    ).item()
    mse = (ref_f32 - out_f32).pow(2).mean().item()
    print(f"\n=== RESULT: cosine={cos:.6f} MSE={mse:.6e} ===")
    if cos >= 0.98:
        print("✅ PASS")
    else:
        print("❌ FAIL")


if __name__ == "__main__":
    main()