diff --git a/tests/unit/test_fused_router.py b/tests/unit/test_fused_router.py index 135dd67c..3250cef4 100644 --- a/tests/unit/test_fused_router.py +++ b/tests/unit/test_fused_router.py @@ -39,7 +39,7 @@ def test_fused_router(): from dsv4.layers.linear import Nvfp4Linear gate_lin = Nvfp4Linear(in_features=K, out_features=N, device=device) gate_lin.fp4 = [w_fp4.contiguous()] - gate_lin.scale_b = [w_sf.contiguous()] + gate_lin.sf = [w_sf.contiguous()] gate_lin.gsb = [ws2_val] gate_lin._activation_global_scale = None # set at runtime gate_lin._ensure_stacked = lambda *a, **kw: None