diag: add weight_scale uint8 histogram to verify E8M0 vs E4M3 format

This commit is contained in:
2026-05-11 19:55:41 +00:00
parent 50a945bde4
commit 0fd2d4f078

View File

@@ -1950,6 +1950,18 @@ class DeepseekV4Model(nn.Module):
# scale_fmt=ue8m0: weight_scale bytes are E8M0 format (power-of-2 only).
# A simple .to(float32) misinterprets them as E4M3. Must reinterpret
# the raw uint8 bits as IEEE 754 exponent field.
# Diagnostic: histogram of raw weight_scale uint8 bytes
# E8M0 values cluster narrowly (120-130 = 2^-7 to 2^3)
# E4M3 values spread across 0-127 with mantissa noise
ws_u8 = mod.weight_scale.data.view(torch.uint8)
ws_u8_flat = ws_u8.flatten()[:256].cpu().numpy()
import numpy as _np
_hist, _edges = _np.histogram(ws_u8_flat, bins=16, range=(0, 255))
_prefix = getattr(mod, 'prefix', 'unknown')
print(f"[SCALE-FMT] {_prefix}: uint8 histogram (first 256 vals): "
f"{list(zip(_edges.astype(int), _hist))}")
print(f"[SCALE-FMT] {_prefix}: uint8 min={ws_u8.min().item()} max={ws_u8.max().item()} "
f"mean={ws_u8.float().mean().item():.1f} std={ws_u8.float().std().item():.1f}")
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]