Fix dequant gsa: use ws2 only, NOT input_scale * ws2
For weight dequantization, gsa should be weight_scale_2 only. input_scale is the activation global scale — it belongs on the GEMM's activation side, not the weight side. Using input_scale * ws2 gave gsa = 6e-8 (essentially zero), making dequantized weights ~0. The GEMM formula is y = (x * scale_a * gsa) @ (w * scale_b * gsb) where gsb = input_scale * ws2. But dequantize_nvfp4 is just the weight half: w_bf16 = lut[w] * block_scale * ws2.
This commit is contained in:
@@ -322,16 +322,14 @@ class Compressor:
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate_proj')
|
||||
if kv_w is not None:
|
||||
ws2_v = kv_ws2.float().item() if kv_ws2 is not None else 1.0
|
||||
isc_v = kv_isc.float().item() if kv_isc is not None else 1.0/(6.0*448.0)
|
||||
gsb = isc_v * ws2_v # global_scale_b = input_scale * weight_scale_2
|
||||
gsa = torch.tensor([gsb] * kv_w.shape[0], device=dev, dtype=torch.float32)
|
||||
# For weight dequantization: gsa = ws2 (NOT input_scale * ws2)
|
||||
# input_scale is the activation global scale, only used in GEMM's gsb computation
|
||||
gsa = torch.tensor([ws2_v] * kv_w.shape[0], device=dev, dtype=torch.float32)
|
||||
kv_bf16 = dequantize_nvfp4(kv_w.to(dev), kv_ws.to(dev), gsa) # (out, in)
|
||||
self._kv_bf16 = kv_bf16.to(dev).contiguous()
|
||||
if gate_w is not None:
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gsb = isc_v * ws2_v
|
||||
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
|
||||
gsa = torch.tensor([ws2_v] * gate_w.shape[0], device=dev, dtype=torch.float32)
|
||||
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa)
|
||||
self._gate_bf16 = gate_bf16.to(dev).contiguous()
|
||||
self.ape = w.get(f"{pfx}.position_bias")
|
||||
@@ -1326,8 +1324,8 @@ def main():
|
||||
# Checkpoint has NVFP4 gate weight — dequantize to BF16
|
||||
from dsv4.ops.quantize import dequantize_nvfp4
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
gsb = isc_v * ws2_v # global_scale_b = input_scale * weight_scale_2
|
||||
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
|
||||
# For weight dequantization: gsa = ws2 (NOT input_scale * ws2)
|
||||
gsa = torch.tensor([ws2_v] * gate_w.shape[0], device=dev, dtype=torch.float32)
|
||||
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H)
|
||||
router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user