Fix router gate BF16 quantize path for production FMHA test

This commit is contained in:
2026-06-03 02:47:47 +00:00
parent 07168357cc
commit 1f757151ef

View File

@@ -167,6 +167,23 @@ def main():
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
else:
# BF16 gate weight — quantize to NVFP4
gw = all_w.get(f"{mlp_pfx}.gate.weight")
if gw is not None:
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
g_bf16 = g_bf16.bfloat16().to(dev)
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
gate_lin.fp4 = [g_fp4]
gate_lin.sf = [g_sf]
gate_lin.gs = [g_gs]
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
gate_lin._use_runtime_gsa = True
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
router.finalize_weights(); routers[li] = router
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,