- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
164 lines
5.9 KiB
Python
164 lines
5.9 KiB
Python
"""Standalone test: Shared expert using CuTeDSL dedicated runner.
|
|
|
|
Tests the Nvfp4SharedExpert for the shared expert path.
|
|
Compares against BF16 dequantized reference.
|
|
|
|
Usage: python3 test_shared_expert.py
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import sys, os, json
|
|
from safetensors import safe_open
|
|
|
|
# Checkpoint location, target device and the layer whose shared expert is tested.
MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEVICE = "cuda:0"
LAYER_IDX = 0

# Problem sizes and activation clamp used by both the runner and the reference.
HIDDEN_SIZE = 7168  # shared expert input dim (from checkpoint weight shapes)
INTERMEDIATE_SIZE = 3072
SWIGLU_LIMIT = 10.0
NUM_TOKENS = 4

# 16-entry lookup table indexed by a 4-bit code: entries 0-7 are the
# non-negative magnitudes, entries 8-15 are the same magnitudes negated
# (entry 8 is therefore -0.0).
E2M1_LUT = torch.tensor(
    [
        sign * magnitude
        for sign in (1.0, -1.0)
        for magnitude in (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
    ],
    dtype=torch.float32,
)

# Tensor cache used by load_tensor, keyed by checkpoint tensor name.
_cache = {}
|
|
|
|
def load_tensor(key, wm, model_dir):
    """Return the checkpoint tensor named ``key``, reading it at most once.

    Args:
        key: Tensor name; used both as the lookup key in ``wm`` and as the
            name requested from the shard file.
        wm: Mapping from tensor name to the shard file name, which is joined
            onto ``model_dir``.
        model_dir: Directory that contains the shard files.

    Returns:
        The tensor as returned by ``safe_open(..., framework="pt")``. Results
        are memoized in the module-level ``_cache`` keyed by ``key`` alone,
        so later calls with the same key return the cached object regardless
        of ``wm`` or ``model_dir``.
    """
    cached = _cache.get(key)
    if cached is not None:
        return cached

    shard_file = os.path.join(model_dir, wm[key])
    with safe_open(shard_file, framework="pt") as shard:
        tensor = shard.get_tensor(key)

    _cache[key] = tensor
    return tensor
|
|
|
|
|
|
def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale):
    """Dequantize NVFP4 weight to BF16 for reference.

    Each byte of ``packed_uint8`` holds two 4-bit codes that are looked up in
    ``E2M1_LUT``: the low nibble becomes the even output column and the high
    nibble the following odd column.

    Args:
        packed_uint8: 2-D tensor of packed codes, indexed as (rows, cols // 2).
        scale_e4m3: Per-block scales; each entry along dim 1 is applied to 16
            consecutive unpacked columns. Rows/columns beyond the unpacked
            weight shape are ignored.
        global_scale: Scalar multiplied into every element.

    Returns:
        A ``torch.bfloat16`` tensor of shape (rows, cols) on the device of
        ``packed_uint8``.
    """
    rows = packed_uint8.shape[0]
    cols = packed_uint8.shape[1] * 2
    table = E2M1_LUT.to(packed_uint8.device)

    low_nibble = table[(packed_uint8 & 0x0F).long()]
    high_nibble = table[((packed_uint8 >> 4) & 0x0F).long()]

    # Pair (low, high) per byte on a trailing axis, then flatten that axis so
    # the values land in columns 2*i and 2*i + 1 respectively.
    values = torch.stack((low_nibble, high_nibble), dim=-1).reshape(rows, cols)

    # One scale per 16 unpacked columns, cropped to the weight's shape.
    per_element_scale = scale_e4m3.float().repeat_interleave(16, dim=1)[:rows, :cols]

    return (values * per_element_scale * global_scale).to(torch.bfloat16)
|
|
|
|
|
|
def main():
    """Compare the Nvfp4SharedExpert runner against a BF16 dequantized reference.

    Loads the shared-expert gate/up/down projections of layer ``LAYER_IDX``
    from the checkpoint at ``MODEL_PATH``, hands them to ``Nvfp4SharedExpert``,
    runs one forward pass on random input and compares the result with a
    reference computed from the dequantized weights.

    Returns:
        int: 0 when the cosine similarity between runner output and reference
        is at least 0.98, otherwise 1. The value is intended to be used as the
        process exit status so that a failure is visible to automation and not
        only in the printed output.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)

    # The project package is imported from a fixed checkout location.
    sys.path.insert(0, "/root/nvfp4-megamoe-kernel")
    from dsv4.layers.shared_expert import Nvfp4SharedExpert

    with open(os.path.join(MODEL_PATH, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def P(key):
        """Load checkpoint tensor ``key`` and move it to ``DEVICE``."""
        return load_tensor(key, wm, MODEL_PATH).to(DEVICE)

    print("=== Shared Expert Test (CuTeDSL SharedExpertRunner) ===\n")

    # Load shared expert weights
    prefix = f"model.layers.{LAYER_IDX}.mlp.shared_experts"

    gate_w = P(f"{prefix}.gate_proj.weight")
    gate_sf = P(f"{prefix}.gate_proj.weight_scale")
    gate_gs = P(f"{prefix}.gate_proj.weight_scale_2").item()
    up_w = P(f"{prefix}.up_proj.weight")
    up_sf = P(f"{prefix}.up_proj.weight_scale")
    up_gs = P(f"{prefix}.up_proj.weight_scale_2").item()
    down_w = P(f"{prefix}.down_proj.weight")
    down_sf = P(f"{prefix}.down_proj.weight_scale")
    down_gs = P(f"{prefix}.down_proj.weight_scale_2").item()

    print(f"gate_proj: shape={gate_w.shape} gs={gate_gs:.8f} sf_shape={gate_sf.shape}")
    print(f"up_proj: shape={up_w.shape} gs={up_gs:.8f} sf_shape={up_sf.shape}")
    print(f"down_proj: shape={down_w.shape} gs={down_gs:.8f} sf_shape={down_sf.shape}")

    # Stack gate + up into gate_up_proj (same format as MoE L1)
    # gate/up weights are (intermediate, hidden) uint8 packed
    gate_up_w = torch.cat([gate_w, up_w], dim=0)
    gate_up_sf = torch.cat([gate_sf, up_sf], dim=0)
    mgs = max(gate_gs, up_gs)
    if gate_gs != up_gs:
        # The stacked weight gets a single global scale (the larger one), so
        # fold each half's own global scale into its block scales.
        # The split point is the actual concatenation boundary rather than a
        # hard-coded size, so it stays correct for any checkpoint shape.
        gate_rows = gate_sf.shape[0]
        sf32 = gate_up_sf.float()
        sf32[:gate_rows] *= (gate_gs / mgs)
        sf32[gate_rows:] *= (up_gs / mgs)
        gate_up_sf = sf32.to(torch.float8_e4m3fn)

    # Convert to CuTeDSL format:
    # Checkpoint weights are (out_features, in_features) uint8 packed
    # We need float4_e2m1fn_x2 with (out_features, in_features // 2) after view
    # Then permute to (in_features // 2, out_features) for K-major (K=in_features)
    l1_fp4 = [gate_up_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    l1_sf = [gate_up_sf.permute(1, 0).contiguous()]
    l2_fp4 = [down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    l2_sf = [down_sf.permute(1, 0).contiguous()]

    # Create runner
    runner = Nvfp4SharedExpert(
        hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE,
        max_num_tokens=8192,
        device=DEVICE,
        swiglu_limit=SWIGLU_LIMIT,
    )
    runner.l1_fp4 = l1_fp4
    runner.l1_sf = l1_sf
    runner.l1_gs = [mgs]
    runner.l2_fp4 = l2_fp4
    runner.l2_sf = l2_sf
    runner.l2_gs = [down_gs]
    runner.finalize_weights()

    # Warmup to compute activation global scales
    # NOTE: `dummy` is drawn before `hidden`, so both consume the seeded RNG
    # stream in a fixed order; keep this ordering for reproducible inputs.
    dummy = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0
    runner._ensure_initialized()
    runner.compute_activation_global_scales(dummy)
    print(f"Warmup gs: L1={runner._l1_activation_global_scale:.6f} "
          f"L2={runner._l2_activation_global_scale:.6f}")

    # Run CuTeDSL
    print("\n--- CuTeDSL Forward ---")
    hidden = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0

    with torch.no_grad():
        output = runner.run(hidden)
    print(f"CuTeDSL output: shape={output.shape} amax={output.amax():.4f} "
          f"NaN={torch.isnan(output).any()}")

    # BF16 reference
    print("\n--- BF16 Reference ---")
    gate_bf16 = dequant_nvfp4(gate_w, gate_sf, gate_gs)
    up_bf16 = dequant_nvfp4(up_w, up_sf, up_gs)
    down_bf16 = dequant_nvfp4(down_w, down_sf, down_gs)

    with torch.no_grad():
        gate = hidden @ gate_bf16.T
        up = hidden @ up_bf16.T
        # SiLU(gate) is clamped from above only; `up` is clamped on both sides.
        gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT)
        up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT)
        intermediate = gate_silu * up
        ref_output = intermediate @ down_bf16.T

    print(f"BF16 ref: shape={ref_output.shape} amax={ref_output.amax():.4f}")

    # Compare
    cos = F.cosine_similarity(ref_output.flatten().unsqueeze(0),
                              output.flatten().unsqueeze(0)).item()
    mse = (ref_output - output).pow(2).mean().item()
    print(f"\n=== RESULT: cosine={cos:.6f} MSE={mse:.6e} ===")

    # A NaN cosine compares False here and is therefore reported as a failure.
    passed = cos >= 0.98
    if passed:
        print("✅ PASS")
    else:
        print("❌ FAIL")
    return 0 if passed else 1


if __name__ == "__main__":
    # Propagate the verdict as the process exit status.
    sys.exit(main())
|