Fix router gate W_gate shape: must be (H, E) not (E, H)
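dense_router_dispatch expects W_gate as (hidden, experts) = (H, E) and applies W_gate.T internally. dequant_nvfp4 returns the weight as (out, in) = (E, H), so it must be transposed before being assigned to router.W_gate. The BF16 checkpoint path gets the same transpose.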
This commit is contained in:
@@ -1316,8 +1316,8 @@ def main():
|
||||
gate_bf16 = dequant_nvfp4(gate_w.to(dev), gate_ws.to(dev),
|
||||
gate_ws2.to(dev) if gate_ws2 is not None else None,
|
||||
gate_isc.to(dev) if gate_isc is not None else None)
|
||||
# W_gate shape: (E, H) for F.linear(x, W_gate)
|
||||
router.W_gate = gate_bf16
|
||||
# W_gate shape: (H, E) for F.linear(x, W_gate.T)
|
||||
router.W_gate = gate_bf16.T.contiguous()
|
||||
router.gate_lin = None # force BF16 dispatch path
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: BF16 router gate (dequantized from NVFP4)", flush=True)
|
||||
@@ -1326,7 +1326,7 @@ def main():
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
router.W_gate = g_bf16.bfloat16().to(dev)
|
||||
router.W_gate = g_bf16.T.contiguous()
|
||||
router.gate_lin = None
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: BF16 router gate (checkpoint BF16)", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user