diag: add weight_scale uint8 histogram to verify E8M0 vs E4M3 format
This commit is contained in:
@@ -1950,6 +1950,18 @@ class DeepseekV4Model(nn.Module):
|
||||
# scale_fmt=ue8m0: weight_scale bytes are E8M0 format (power-of-2 only).
|
||||
# A simple .to(float32) misinterprets them as E4M3. Must reinterpret
|
||||
# the raw uint8 bits as the exponent field of an IEEE 754 float32.
|
||||
# Diagnostic: histogram of raw weight_scale uint8 bytes
|
||||
# E8M0 values cluster narrowly (bytes 120-130, i.e. 2^-7 to 2^3 with exponent bias 127)
|
||||
# E4M3 values spread across 0-127 with mantissa noise
|
||||
ws_u8 = mod.weight_scale.data.view(torch.uint8)
|
||||
ws_u8_flat = ws_u8.flatten()[:256].cpu().numpy()
|
||||
import numpy as _np
|
||||
_hist, _edges = _np.histogram(ws_u8_flat, bins=16, range=(0, 255))
|
||||
_prefix = getattr(mod, 'prefix', 'unknown')
|
||||
print(f"[SCALE-FMT] {_prefix}: uint8 histogram (first 256 vals): "
|
||||
f"{list(zip(_edges.astype(int), _hist))}")
|
||||
print(f"[SCALE-FMT] {_prefix}: uint8 min={ws_u8.min().item()} max={ws_u8.max().item()} "
|
||||
f"mean={ws_u8.float().mean().item():.1f} std={ws_u8.float().std().item():.1f}")
|
||||
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
|
||||
Reference in New Issue
Block a user