Fix dequant gsb: input_scale * ws2, not 1.0 * ws2

The NVFP4 dequantize formula is w = lut[w_packed] * scale * ws2,
and in the GEMM the global_scale_b = input_scale * ws2. Was incorrectly
using gsb = 1.0 * ws2 (missing input_scale). This would produce
wrongly-scaled BF16 weights from dequantize_nvfp4.
This commit is contained in:
2026-06-03 14:26:59 +00:00
parent 2dd16d5789
commit 470e65fb19

View File

@@ -322,15 +322,17 @@ class Compressor:
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate_proj')
if kv_w is not None:
ws2_v = kv_ws2.float().item() if kv_ws2 is not None else 1.0
gsb = 1.0 * ws2_v
isc_v = kv_isc.float().item() if kv_isc is not None else 1.0/(6.0*448.0)
gsb = isc_v * ws2_v # global_scale_b = input_scale * weight_scale_2
gsa = torch.tensor([gsb] * kv_w.shape[0], device=dev, dtype=torch.float32)
kv_bf16 = dequantize_nvfp4(kv_w.to(dev), kv_ws.to(dev), gsa) # (out, in)
self._kv_bf16 = kv_bf16.to(dev).contiguous()
if gate_w is not None:
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
gsb = 1.0 * ws2_v
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
gsb = isc_v * ws2_v
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (out, in)
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa)
self._gate_bf16 = gate_bf16.to(dev).contiguous()
self.ape = w.get(f"{pfx}.position_bias")
self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
@@ -420,7 +422,8 @@ class Indexer:
if wp_w is not None:
from dsv4.ops.quantize import dequantize_nvfp4
ws2_v = wp_ws2.float().item() if wp_ws2 is not None else 1.0
gsb = 1.0 * ws2_v
isc_v = wp_isc.float().item() if wp_isc is not None else 1.0/(6.0*448.0)
gsb = isc_v * ws2_v # global_scale_b = input_scale * weight_scale_2
gsa = torch.tensor([gsb] * wp_w.shape[0], device=dev, dtype=torch.float32)
wp_bf16 = dequantize_nvfp4(wp_w.to(dev), wp_ws.to(dev), gsa)
self._wp_bf16 = wp_bf16.to(dev).contiguous()
@@ -1323,7 +1326,7 @@ def main():
# Checkpoint has NVFP4 gate weight — dequantize to BF16
from dsv4.ops.quantize import dequantize_nvfp4
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
gsb = 1.0 * ws2_v # global_scale_b = gs * ws2
gsb = isc_v * ws2_v # global_scale_b = input_scale * weight_scale_2
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H)
router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T)