test: warmup gs computation with safety margin sweep
This commit is contained in:
200
tests/test_warmup_gs.py
Normal file
200
tests/test_warmup_gs.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test C: Warmup-based global-scale (gs) computation — verify that the exact gs
values captured during warmup give high cosine similarity against the reference
pipeline when used with quantize_activation_nvfp4.
|
||||
|
||||
The warmup runs quantize_to_nvfp4 (dynamic gs) on representative input,
|
||||
captures the exact gs for both L1 and L2, then feeds those values to
|
||||
quantize_activation_nvfp4 (fixed gs, cudagraph-safe).
|
||||
|
||||
Usage (on B200):
|
||||
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
|
||||
python3 tests/test_warmup_gs.py
|
||||
"""
|
||||
import torch, sys, os, json
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from cutedsl.bridge import (
|
||||
quantize_to_nvfp4, quantize_activation_nvfp4,
|
||||
make_b_k_major, assemble_scales_2d_side, assemble_scales_3d_side,
|
||||
run_nvfp4_grouped_gemm, compute_expert_offsets,
|
||||
)
|
||||
from cutedsl.moe_pipeline import run_nvfp4_moe
|
||||
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
|
||||
from safetensors import safe_open
|
||||
|
||||
# Checkpoint directory containing the safetensors shards and their index file.
MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Device that loaded tensors and generated test inputs are placed on.
DEVICE = "cuda"
# Lookup table indexed by a 4-bit code: entries 0-7 are the non-negative
# values, entries 8-15 are the same magnitudes with the sign flipped.
E2M1_LUT = torch.tensor(
    [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ],
    dtype=torch.float32,
)
|
||||
|
||||
|
||||
def dequant(w, sf, gs):
    """Unpack two 4-bit codes per element and rescale the result to bfloat16.

    The low nibble of every element of ``w`` fills the even output columns and
    the high nibble fills the odd output columns; both are mapped to values
    through ``E2M1_LUT``.

    Args:
        w: Packed codes, indexed here as a 2-D tensor (presumably uint8 — the
            code only requires that ``&`` and ``>>`` work on it).
        sf: Block scales; each column is repeated 16 times along dim 1 and the
            result is truncated to the unpacked width.
        gs: Global scale multiplied into every element.

    Returns:
        A bfloat16 tensor of shape ``(w.shape[0], 2 * w.shape[1])`` on the
        same device as ``w``.
    """
    target = w.device
    lut = E2M1_LUT.to(target)
    n_rows, n_packed = w.shape[0], w.shape[1]

    values = torch.empty(n_rows, n_packed * 2, dtype=torch.float32, device=target)
    values[:, 0::2] = lut[(w & 0xF).long()]
    values[:, 1::2] = lut[((w >> 4) & 0xF).long()]

    # One scale covers 16 consecutive unpacked values along dim 1.
    block_scales = sf.float().repeat_interleave(16, dim=1)[:, :values.shape[1]]
    scaled = values * block_scales * gs
    return scaled.to(torch.bfloat16)
|
||||
|
||||
|
||||
def load_tensor(key):
    """Load a single tensor from the sharded safetensors checkpoint.

    The shard that holds ``key`` is looked up in the ``weight_map`` of
    ``model.safetensors.index.json`` inside ``MODEL_DIR``.

    Args:
        key: Tensor name as it appears in the index's ``weight_map``.

    Returns:
        The tensor moved to ``DEVICE``, or ``None`` when the key is not listed
        in the index, the shard file does not exist, or the shard does not
        contain the key.
    """
    index_path = os.path.join(MODEL_DIR, "model.safetensors.index.json")
    with open(index_path, encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    shard_name = weight_map.get(key)
    if not shard_name:
        # An unknown key must not fall through to os.path.join(MODEL_DIR, ""):
        # that path is the directory itself, which exists, and safe_open would
        # then be asked to open a directory.
        return None

    shard_path = os.path.join(MODEL_DIR, shard_name)
    if not os.path.isfile(shard_path):
        return None

    with safe_open(shard_path, framework="pt") as f:
        if key in f.keys():
            return f.get_tensor(key).to(DEVICE)
    return None
|
||||
|
||||
|
||||
def load_layer0_experts(expert_indices):
    """Load and repack the layer-0 MoE expert weights for ``expert_indices``.

    For every expert, gate_proj and up_proj are concatenated along dim 0
    (gate first) into one L1 weight, and down_proj becomes the L2 weight.
    Weights are viewed as ``torch.float4_e2m1fn_x2`` and, like their block
    scales, transposed with ``permute(1, 0)`` and made contiguous.

    When the gate and up global scales differ, the larger one is kept as the
    fused L1 global scale and each half of the block scales is multiplied by
    ``own_gs / max_gs`` before being cast to ``torch.float8_e4m3fn``.

    Args:
        expert_indices: Iterable of expert ids to load.

    Returns:
        ``(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)`` — six lists with one
        entry per requested expert, in the order given.

    Raises:
        KeyError: If a required tensor cannot be loaded from the checkpoint.
    """
    def fetch(expert, suffix):
        # Fail with the offending key instead of an opaque error on None.
        key = f"model.layers.0.mlp.experts.{expert}.{suffix}"
        tensor = load_tensor(key)
        if tensor is None:
            raise KeyError(f"checkpoint tensor not found: {key}")
        return tensor

    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []
    for e in expert_indices:
        gw = fetch(e, "gate_proj.weight")
        uw = fetch(e, "up_proj.weight")
        gsf = fetch(e, "gate_proj.weight_scale")
        usf = fetch(e, "up_proj.weight_scale")
        ggs = fetch(e, "gate_proj.weight_scale_2").item()
        ugs = fetch(e, "up_proj.weight_scale_2").item()

        fw = torch.cat([gw, uw], dim=0)
        fw4 = fw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        fs = torch.cat([gsf, usf], dim=0).permute(1, 0).contiguous()

        mgs = max(ggs, ugs)
        if ggs != ugs:
            # After the transpose, the first gsf.shape[0] columns belong to
            # gate_proj and the rest to up_proj; derive the split from the
            # data rather than hard-coding the intermediate size.
            gate_cols = gsf.shape[0]
            f32 = fs.float()
            f32[:, :gate_cols] *= (ggs / mgs)
            f32[:, gate_cols:] *= (ugs / mgs)
            fs = f32.to(torch.float8_e4m3fn)
        l1_fp4.append(fw4)
        l1_sf.append(fs)
        l1_gs.append(mgs)

        dw = fetch(e, "down_proj.weight")
        dsf = fetch(e, "down_proj.weight_scale")
        dgs = fetch(e, "down_proj.weight_scale_2").item()
        l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
        l2_sf.append(dsf.permute(1, 0).contiguous())
        l2_gs.append(dgs)
    return l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs
|
||||
|
||||
|
||||
def warmup_compute_gs(runner, hidden_states, topk_weights, topk_ids):
    """Capture the dynamic global scales for L1 and L2 from a warmup batch.

    Tokens are expanded to one slot per (token, selected expert) pair and
    ordered by expert id. The L1 scale is the one ``quantize_to_nvfp4``
    reports for those slot inputs. The L1 grouped GEMM is then run with that
    scale, ``silu(gate) * up`` is applied to its output, and the L2 scale is
    the one ``quantize_to_nvfp4`` reports for the activated tensor. The L2
    GEMM itself is not executed.

    Side effects: ``runner._expert_offsets_buf`` is zeroed and rewritten in
    place, and the runner's L1 scale buffers are handed to
    ``runner._assemble_scales_cudagraph_safe``.

    Args:
        runner: Prepared MoE runner whose buffers and L1 weights are used.
        hidden_states: Input activations, one row per token.
        topk_weights: Routing weights; accepted for signature parity, unused.
        topk_ids: Expert ids per token, one column per selected expert.

    Returns:
        ``(l1_gs, l2_gs)`` as returned by ``quantize_to_nvfp4``.
    """
    n_experts = runner.num_experts
    inter = runner.intermediate_size
    dev = hidden_states.device

    # Slot mapping (mirrors runner.run()): one slot per (token, expert) pair,
    # stably sorted so slots of the same expert are contiguous.
    slot_count = hidden_states.shape[0] * topk_ids.shape[1]
    expert_per_slot = topk_ids.reshape(-1)
    order = expert_per_slot.argsort(stable=True)
    experts_sorted = expert_per_slot[order]
    tokens_sorted = runner._token_indices[:slot_count][order]
    l1_input = hidden_states[tokens_sorted]

    # L1: take the dynamically computed scale, then quantize with it fixed.
    _, _, l1_gs = quantize_to_nvfp4(l1_input)
    l1_fp4, l1_block_sf = quantize_activation_nvfp4(l1_input, l1_gs)

    # Per-expert slot counts -> cumulative offsets in the runner's buffer.
    matches = experts_sorted.unsqueeze(1) == runner._expert_id_range.unsqueeze(0)
    counts = matches.sum(dim=0).int()
    offsets = runner._expert_offsets_buf
    offsets.zero_()
    offsets[1:n_experts + 1] = counts.cumsum(0)

    scale_a = runner._assemble_scales_cudagraph_safe(
        l1_block_sf,
        offsets[:n_experts + 1],
        runner._padded_x_sf_buf_l1,
        runner._per_expert_scale_bufs_l1,
    )
    gs_a = torch.full((n_experts,), l1_gs, dtype=torch.float32, device=dev)

    l1_out = run_nvfp4_grouped_gemm(
        mat_a=l1_fp4,
        mat_b=runner._l1_mat_b,
        scale_a=scale_a,
        scale_b=runner._l1_scale_b,
        expert_offsets=offsets[1:n_experts + 1],
        global_scale_a=gs_a,
        global_scale_b=runner._l1_gsb,
    )

    # L2: the scale comes from the real activated L1 output.
    gate_part = l1_out[:, :inter]
    up_part = l1_out[:, inter:]
    l2_input = torch.nn.functional.silu(gate_part) * up_part
    _, _, l2_gs = quantize_to_nvfp4(l2_input)

    return l1_gs, l2_gs
|
||||
|
||||
|
||||
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors flattened to vectors."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()
    ).item()


def _make_runner(num_experts, hidden_size, intermediate_size, weights):
    """Build a CuTeDSLMoERunner loaded with private copies of the weights."""
    runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=DEVICE)
    # Clone so each runner owns its tensors and cannot alter the reference set.
    runner.prepare_weights_direct(
        [w.clone() for w in weights['l1_fp4']],
        [w.clone() for w in weights['l1_sf']],
        list(weights['l1_gs']),
        [w.clone() for w in weights['l2_fp4']],
        [w.clone() for w in weights['l2_sf']],
        list(weights['l2_gs']),
    )
    return runner


def main():
    """Compare the runner with warmup-derived global scales to the pipeline.

    Test 1 uses the exact warmup scales, Test 2 sweeps a multiplicative safety
    margin over them, and Test 3 reuses the Test 1 runner on a differently
    seeded input. Each test prints its cosine similarity against
    ``run_nvfp4_moe``; nothing is asserted.
    """
    expert_indices = [0, 1, 2]
    num_experts = len(expert_indices)
    hidden_size = 7168
    intermediate_size = 3072

    print("Loading weights...")
    l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs = load_layer0_experts(expert_indices)

    torch.manual_seed(42)
    hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)

    # Pipeline reference
    weights = {'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
               'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs}
    ref = run_nvfp4_moe(hidden_states.clone(), topk_ids.clone(), topk_weights.clone(), weights, expert_indices)
    print(f"Pipeline: amax={ref.abs().max():.4f}, mean={ref.float().mean():.6f}")

    # ── Test 1: Runner with warmup gs (no safety margin) ──
    print("\n--- Test 1: Warmup gs, no safety margin ---")
    runner = _make_runner(num_experts, hidden_size, intermediate_size, weights)

    l1_gs_val, l2_gs_val = warmup_compute_gs(runner, hidden_states, topk_weights, topk_ids)
    print(f" Warmup L1 gs: {l1_gs_val:.10f}")
    print(f" Warmup L2 gs: {l2_gs_val:.10f}")

    runner._l1_activation_global_scale = l1_gs_val
    runner._l2_activation_global_scale = l2_gs_val
    result = runner.run(hidden_states.clone(), topk_weights, topk_ids)

    cos = _flat_cosine(result, ref)
    print(f" Cosine: {cos:.6f}, amax={result.abs().max():.4f}")

    # ── Test 2: Runner with warmup gs + safety margins ──
    # Header added so the sweep lines are not read as part of Test 1.
    print("\n--- Test 2: Warmup gs with safety margins ---")
    for safety in [1.0, 1.1, 1.2, 1.5, 2.0]:
        # A fresh runner per margin keeps the sweep points independent.
        runner2 = _make_runner(num_experts, hidden_size, intermediate_size, weights)
        runner2._l1_activation_global_scale = l1_gs_val * safety
        runner2._l2_activation_global_scale = l2_gs_val * safety
        result2 = runner2.run(hidden_states.clone(), topk_weights, topk_ids)
        cos2 = _flat_cosine(result2, ref)
        print(f" Safety {safety:.1f}x: cosine={cos2:.6f}, amax={result2.abs().max():.4f}")

    # ── Test 3: Different input (verify warmup gs generalizes) ──
    print("\n--- Test 3: Different input with same warmup gs ---")
    torch.manual_seed(99)
    hidden_states2 = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids2 = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights2 = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)

    ref2 = run_nvfp4_moe(hidden_states2.clone(), topk_ids2.clone(), topk_weights2.clone(), weights, expert_indices)
    # Deliberately reuses the Test 1 runner, which still holds the scales
    # captured from the first input.
    result3 = runner.run(hidden_states2.clone(), topk_weights2, topk_ids2)
    cos3 = _flat_cosine(result3, ref2)
    print(f" Different input: cosine={cos3:.6f}")
|
||||
|
||||
|
||||
# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user