fix: use correct Nvfp4Linear field names (fp4, scale_b, gsb)
This commit is contained in:
@@ -38,9 +38,9 @@ def test_fused_router():
|
||||
# The Nvfp4Linear expects stacked/contiguous weight tensors
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
gate_lin = Nvfp4Linear(in_features=K, out_features=N, device=device)
|
||||
gate_lin._weight_fp4 = [w_fp4.contiguous()]
|
||||
gate_lin._scale_b = [w_sf.contiguous()]
|
||||
gate_lin._gsb = [ws2_val]
|
||||
gate_lin.fp4 = [w_fp4.contiguous()]
|
||||
gate_lin.scale_b = [w_sf.contiguous()]
|
||||
gate_lin.gsb = [ws2_val]
|
||||
gate_lin._activation_global_scale = None # set at runtime
|
||||
gate_lin._ensure_stacked = lambda *a, **kw: None
|
||||
gate_lin._ensure_initialized = lambda *a, **kw: None
|
||||
|
||||
Reference in New Issue
Block a user