Files
nvfp4-megamoe-kernel/tests/test_warmup_gs.py

199 lines
8.7 KiB
Python

#!/usr/bin/env python3
"""
Test C: warmup-based global-scale (gs) computation.

Verifies that the exact gs values captured by a warmup pass give good cosine
similarity against the reference MoE pipeline when they are passed to
quantize_activation_nvfp4.

The warmup runs quantize_to_nvfp4 (dynamic gs) on representative input,
captures the exact gs for both L1 and L2, then feeds those values to
quantize_activation_nvfp4 (fixed gs, CUDA-graph-safe).

Usage (on B200):
    source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
    python3 tests/test_warmup_gs.py
"""
import functools
import json
import os
import sys

import torch
from safetensors import safe_open

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cutedsl.bridge import (
    quantize_to_nvfp4, quantize_activation_nvfp4,
    make_b_k_major, assemble_scales_2d_side, assemble_scales_3d_side,
    run_nvfp4_grouped_gemm, compute_expert_offsets,
)
from cutedsl.moe_pipeline import run_nvfp4_moe
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
# Checkpoint directory; must contain model.safetensors.index.json plus shards.
MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEVICE = "cuda"

# FP4 (E2M1) code -> float value. Codes 0-7 are the non-negative magnitudes;
# codes 8-15 repeat them with the sign bit set (code 8 is -0.0).
E2M1_LUT = torch.tensor(
    [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ],
    dtype=torch.float32,
)
def dequant(w, sf, gs):
    """Dequantize packed FP4 (E2M1) codes to a bfloat16 tensor.

    Each element of ``w`` packs two 4-bit codes: the low nibble becomes output
    column ``2*j`` and the high nibble column ``2*j + 1``. Each run of 16
    output columns is multiplied by one entry of ``sf``, and everything is
    multiplied by ``gs``.

    Args:
        w: 2-D integer tensor of packed code pairs, shape (rows, cols).
        sf: block scales with at least ``ceil(2*cols / 16)`` columns; extra
            trailing columns are ignored.
        gs: global scale multiplier.

    Returns:
        bfloat16 tensor of shape (rows, 2 * cols) on ``w.device``.
    """
    device = w.device
    lut = E2M1_LUT.to(device)
    rows, packed_cols = w.shape[0], w.shape[1]
    values = torch.empty(rows, packed_cols * 2, dtype=torch.float32, device=device)
    values[:, 0::2] = lut[(w & 0xF).long()]         # low nibble -> even column
    values[:, 1::2] = lut[((w >> 4) & 0xF).long()]  # high nibble -> odd column
    # Expand one scale per 16 columns, trimming any padding columns.
    block_scales = sf.float().repeat_interleave(16, dim=1)[:, : values.shape[1]]
    return (values * block_scales * gs).to(torch.bfloat16)
def load_tensor(key):
with open(os.path.join(MODEL_DIR, "model.safetensors.index.json")) as f:
wm = json.load(f)["weight_map"]
shard = os.path.join(MODEL_DIR, wm.get(key, ""))
if not os.path.exists(shard): return None
with safe_open(shard, framework="pt") as f:
if key in f.keys(): return f.get_tensor(key).to(DEVICE)
return None
def load_layer0_experts(expert_indices):
    """Load layer-0 NVFP4 expert weights laid out for the two grouped GEMMs.

    L1 fuses gate_proj and up_proj by concatenating them along dim 0 (gate
    first). L2 is down_proj. Packed weights and block scales are transposed
    with ``permute(1, 0)`` and made contiguous.

    The fused L1 operand carries a single global scale, but gate and up may
    have different ``weight_scale_2``. The larger one is kept, and each half's
    block scales are multiplied by ``own_gs / max_gs`` to compensate; those
    rescaled scales are stored back as float8_e4m3fn.

    Args:
        expert_indices: iterable of expert ids to load from layer 0.

    Returns:
        (l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs): per-expert lists of
        packed FP4 weights, block scales and Python-float global scales.

    Raises:
        KeyError: if any required tensor is missing from the checkpoint.
    """
    def require(name):
        # load_tensor() returns None for missing tensors; fail loudly here
        # instead of crashing later on torch.cat(None) / None.item().
        tensor = load_tensor(name)
        if tensor is None:
            raise KeyError(f"tensor {name!r} not found in checkpoint {MODEL_DIR}")
        return tensor

    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []
    for e in expert_indices:
        prefix = f"model.layers.0.mlp.experts.{e}"

        # L1: fused gate + up.
        gw = require(f"{prefix}.gate_proj.weight")
        uw = require(f"{prefix}.up_proj.weight")
        gsf = require(f"{prefix}.gate_proj.weight_scale")
        usf = require(f"{prefix}.up_proj.weight_scale")
        ggs = require(f"{prefix}.gate_proj.weight_scale_2").item()
        ugs = require(f"{prefix}.up_proj.weight_scale_2").item()
        fw = torch.cat([gw, uw], dim=0)
        fw4 = fw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        fs = torch.cat([gsf, usf], dim=0).permute(1, 0).contiguous()
        mgs = max(ggs, ugs)
        if ggs != ugs:
            # After the transpose, the first gsf.shape[0] columns belong to
            # gate and the rest to up. Derive the split from the tensor rather
            # than hard-coding one intermediate size (previously 3072).
            n_gate = gsf.shape[0]
            f32 = fs.float()
            f32[:, :n_gate] *= (ggs / mgs)
            f32[:, n_gate:] *= (ugs / mgs)
            fs = f32.to(torch.float8_e4m3fn)
        l1_fp4.append(fw4)
        l1_sf.append(fs)
        l1_gs.append(mgs)

        # L2: down projection.
        dw = require(f"{prefix}.down_proj.weight")
        dsf = require(f"{prefix}.down_proj.weight_scale")
        dgs = require(f"{prefix}.down_proj.weight_scale_2").item()
        l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
        l2_sf.append(dsf.permute(1, 0).contiguous())
        l2_gs.append(dgs)
    return l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs
def warmup_compute_gs(runner, hidden_states, topk_weights, topk_ids):
    """Capture the exact L1 and L2 activation global scales (gs) by warmup.

    Runs the L1 half of a forward pass (not the L2 GEMM):
      1. Gathers one activation row per routed (token, k) slot, ordered by
         expert id, and gets the L1 gs from dynamic-gs ``quantize_to_nvfp4``.
      2. Re-quantizes those rows with the fixed-gs
         ``quantize_activation_nvfp4`` using that gs and runs the L1 grouped
         GEMM.
      3. Applies ``silu(gate) * up`` to the L1 output and gets the L2 gs from
         ``quantize_to_nvfp4`` on the result.

    Args:
        runner: a prepared ``CuTeDSLMoERunner``. Its ``_expert_offsets_buf``
            is zeroed and overwritten. ``_padded_x_sf_buf_l1`` and
            ``_per_expert_scale_bufs_l1`` are passed to
            ``_assemble_scales_cudagraph_safe`` (presumably filled in place;
            confirm).
        hidden_states: token activations; ``shape[0]`` is the token count.
        topk_weights: not used by this function. NOTE(review): presumably kept
            for signature parity with ``runner.run()``. Routing weights are
            not applied before measuring the L2 gs.
        topk_ids: expert id per (token, k) slot; ``shape[1]`` is top_k.

    Returns:
        (l1_gs, l2_gs): the third element returned by ``quantize_to_nvfp4``
        for the L1 input and the activated L1 output. The exact type (float
        vs. 0-dim tensor) is not visible here; confirm.
    """
    device = hidden_states.device
    num_tokens = hidden_states.shape[0]
    top_k = topk_ids.shape[1]
    # Build slot mapping (same as runner.run())
    runner._ensure_stacked()
    flat_ids = topk_ids.reshape(-1)
    num_slots = num_tokens * top_k
    # NOTE(review): presumably maps each flattened (token, k) slot to its
    # source token row; confirm against CuTeDSLMoERunner.
    token_indices = runner._token_indices[:num_slots]
    # Stable sort groups slots by expert id while preserving slot order
    # within each expert.
    sort_idx = flat_ids.argsort(stable=True)
    sorted_ids = flat_ids[sort_idx]
    sorted_token_ids = token_indices[sort_idx]
    # One activation row per routed slot, in expert-sorted order.
    slot_hidden = hidden_states[sorted_token_ids]
    # L1: dynamic gs
    _, _, l1_gs = quantize_to_nvfp4(slot_hidden)
    # Re-quantize with the fixed-gs quantizer using the captured L1 gs, then
    # run the L1 GEMM to get the L1 output.
    x_fp4, x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)
    expert_id_range = runner._expert_id_range
    # Slots per expert: broadcast-compare every sorted id against each entry
    # of runner._expert_id_range, then count matches per expert.
    tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
    expert_offsets = runner._expert_offsets_buf
    # Prefix-sum offsets [0, n_0, n_0 + n_1, ...] written in place into the
    # runner's buffer; the remaining entries stay zero.
    expert_offsets.zero_()
    expert_offsets[1:runner.num_experts + 1] = tokens_per_expert.cumsum(0)
    l1_scale_a = runner._assemble_scales_cudagraph_safe(
        x_sf, expert_offsets[:runner.num_experts + 1],
        runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1
    )
    # The same activation gs is broadcast to every expert.
    l1_gsa = torch.full((runner.num_experts,), l1_gs, dtype=torch.float32, device=device)
    l1_out = run_nvfp4_grouped_gemm(
        mat_a=x_fp4, mat_b=runner._l1_mat_b,
        scale_a=l1_scale_a, scale_b=runner._l1_scale_b,
        expert_offsets=expert_offsets[1:runner.num_experts + 1],
        global_scale_a=l1_gsa, global_scale_b=runner._l1_gsb,
    )
    # L2: compute gs from actual L1 output
    # Gate is the first intermediate_size columns of the L1 output; up is the rest.
    gate = l1_out[:, :runner.intermediate_size]
    up = l1_out[:, runner.intermediate_size:]
    activated = torch.nn.functional.silu(gate) * up
    _, _, l2_gs = quantize_to_nvfp4(activated)
    return l1_gs, l2_gs
def _cosine_similarity(a, b):
    """Cosine similarity of two tensors after flattening both to float32 1-D."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()
    ).item()


def main():
    """Compare CuTeDSLMoERunner outputs against the reference NVFP4 pipeline.

    Test 1: runner gs from its built-in warmup, with no safety margin.
    Test 2: the same warmup gs scaled by several safety factors.
    Test 3: the Test-1 runner (warmup gs unchanged) on a different input.
    Each test prints the cosine similarity against ``run_nvfp4_moe``.
    """
    expert_indices = [0, 1, 2]
    num_experts = len(expert_indices)
    hidden_size = 7168
    intermediate_size = 3072

    print("Loading weights...")
    l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs = load_layer0_experts(expert_indices)

    def make_runner():
        # Fresh runner with private copies of the weights, so one runner can
        # never alias or mutate another's weights.
        r = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=DEVICE)
        r.prepare_weights_direct(
            [w.clone() for w in l1_fp4], [w.clone() for w in l1_sf], list(l1_gs),
            [w.clone() for w in l2_fp4], [w.clone() for w in l2_sf], list(l2_gs),
        )
        return r

    torch.manual_seed(42)
    hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)

    # Pipeline reference
    weights = {'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
               'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs}
    ref = run_nvfp4_moe(hidden_states.clone(), topk_ids.clone(), topk_weights.clone(), weights, expert_indices)
    print(f"Pipeline: amax={ref.abs().max():.4f}, mean={ref.float().mean():.6f}")

    # ── Test 1: Runner with warmup gs (no safety margin) ──
    print("\n--- Test 1: Warmup gs, no safety margin ---")
    runner = make_runner()
    # Use the runner's built-in warmup method
    runner.compute_activation_global_scales(hidden_states.clone(), topk_weights, topk_ids)
    # Capture the warmup gs so Test 2 can reuse them. Previously
    # l1_gs_val / l2_gs_val were used below without ever being defined
    # (NameError). NOTE(review): assumes compute_activation_global_scales()
    # stores its result in the same attributes Test 2 overrides; confirm.
    l1_gs_val = runner._l1_activation_global_scale
    l2_gs_val = runner._l2_activation_global_scale
    result = runner.run(hidden_states.clone(), topk_weights, topk_ids)
    cos = _cosine_similarity(result, ref)
    print(f" Cosine: {cos:.6f}, amax={result.abs().max():.4f}")

    # ── Test 2: Runner with warmup gs + safety margins ──
    print("\n--- Test 2: Warmup gs with safety margins ---")
    for safety in [1.0, 1.1, 1.2, 1.5, 2.0]:
        runner2 = make_runner()
        runner2._l1_activation_global_scale = l1_gs_val * safety
        runner2._l2_activation_global_scale = l2_gs_val * safety
        result2 = runner2.run(hidden_states.clone(), topk_weights, topk_ids)
        cos2 = _cosine_similarity(result2, ref)
        print(f" Safety {safety:.1f}x: cosine={cos2:.6f}, amax={result2.abs().max():.4f}")

    # ── Test 3: Different input (verify warmup gs generalizes) ──
    print("\n--- Test 3: Different input with same warmup gs ---")
    torch.manual_seed(99)
    hidden_states2 = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids2 = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights2 = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)
    ref2 = run_nvfp4_moe(hidden_states2.clone(), topk_ids2.clone(), topk_weights2.clone(), weights, expert_indices)
    result3 = runner.run(hidden_states2.clone(), topk_weights2, topk_ids2)
    cos3 = _cosine_similarity(result3, ref2)
    print(f" Different input: cosine={cos3:.6f}")
if __name__ == "__main__":
    # Run the experiments only when executed as a script, not on import.
    main()