diff --git a/tests/unit/test_production_fmha_layer.py b/tests/unit/test_production_fmha_layer.py index a7c5744a..7bcd3cf0 100644 --- a/tests/unit/test_production_fmha_layer.py +++ b/tests/unit/test_production_fmha_layer.py @@ -167,6 +167,23 @@ def main(): gate_lin.finalize_weights() router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32)) + else: + # BF16 gate weight — quantize to NVFP4 + gw = all_w.get(f"{mlp_pfx}.gate.weight") + if gw is not None: + g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous() + g_bf16 = g_bf16.bfloat16().to(dev) + g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16) + gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev) + gate_lin.fp4 = [g_fp4] + gate_lin.sf = [g_sf] + gate_lin.gs = [g_gs] + gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)] + gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) + gate_lin._use_runtime_gsa = True + gate_lin.finalize_weights() + router.load_nvfp4_gate(gate_lin) + router.load_weights(e_bias=eb.to(dev, torch.float32)) router.finalize_weights(); routers[li] = router moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,