nvfp4-megamoe-kernel/diag_keys.py
biondizzle fd59222fc0 fix: stop folding global scale into float8 block scales
Folding block_sf (float8) * global_sf (float32) back into float8 loses ~25% precision.
The product of a ~56-448 block_sf and a ~4.65e-05 global_sf lands in the float8
low-precision zone, where the relative step size is on the order of 25%. This turns
the model output into garbage even though every value stays finite.
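
The precision loss described above can be reproduced without a GPU. The following is a minimal sketch, assuming the float8 block scales use an E4M3-style layout (3 mantissa bits, smallest normal 2**-6, largest finite value 448, which matches the 448 upper bound quoted above). `quantize_e4m3` is a hypothetical helper written for this illustration, not a function from the repository, and its saturating behaviour above 448 is a simplification.

```python
import math

def quantize_e4m3(x: float) -> float:
    """Round x to the nearest value of an E4M3-style float8 (assumed format)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), 448.0)        # saturate at the largest finite value (assumption)
    _, ex = math.frexp(a)         # a = m * 2**ex with 0.5 <= m < 1
    e = max(ex - 1, -6)           # below 2**-6 the format is subnormal: exponent is pinned
    step = 2.0 ** (e - 3)         # 3 mantissa bits -> 8 steps per binade
    # round() is round-half-to-even, matching the usual float rounding mode
    return sign * min(round(a / step) * step, 448.0)

global_sf = 4.65e-05              # value quoted in the commit message
for block_sf in (56.0, 128.0, 448.0):
    exact = block_sf * global_sf  # what the GEMM should see
    folded = quantize_e4m3(exact) # what survives after folding into float8
    rel_err = abs(folded - exact) / exact
    print(f"block_sf={block_sf:6.1f} exact={exact:.6e} folded={folded:.6e} rel_err={rel_err:.1%}")
```

With these inputs the folded product sits at or near the subnormal range of the assumed format, where neighbouring representable values are 2**-9 apart regardless of magnitude, so the smaller block scales take the largest relative error.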

Fix: keep block scales as original float8, return global scales separately as
float32 per-expert vectors. Apply global scale as per-expert GEMM alpha in
cutlass_grouped_nvfp4_gemm (already iterates per-expert). For L1 with separate
gate/up global scales, use gate_gs as alpha and apply up_correction ratio to
the up half post-GEMM.
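
The alpha and up-correction scheme above can be sketched in plain Python. This is an illustration only, not the code in kernel.py or nvfp4_mega_moe.py: it assumes the fused L1 weight stores the gate columns first and the up columns second, and that the correction ratio is `up_gs / gate_gs`. The function names `matmul` and `l1_expert_gemm` are hypothetical.

```python
def matmul(a, b):
    """Naive dense matmul on nested lists: a is [m][k], b is [k][n]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def l1_expert_gemm(x, w_gate_up, act_gs, gate_gs, up_gs):
    """One expert's L1 GEMM with the global scales applied outside the weights.

    x:         [tokens][k] activations, global scale not yet applied
    w_gate_up: [k][2n] weights dequantized with block scales only,
               gate columns first, up columns second (assumed layout)
    """
    alpha = act_gs * gate_gs            # single per-expert GEMM alpha
    up_correction = up_gs / gate_gs     # assumed definition of the ratio
    n = len(w_gate_up[0]) // 2
    out = []
    for row in matmul(x, w_gate_up):
        scaled = [alpha * v for v in row]
        # the gate half is already correct; the up half was scaled by
        # gate_gs, so rescale it to up_gs after the GEMM
        out.append(scaled[:n] + [v * up_correction for v in scaled[n:]])
    return out

x = [[1.0, 2.0]]
w = [[1.0, 3.0],
     [2.0, 4.0]]                        # column 0 = gate, column 1 = up
print(l1_expert_gemm(x, w, act_gs=0.5, gate_gs=0.25, up_gs=0.125))
# prints [[0.625, 0.6875]]
```

The result matches scaling each half by its own `act_gs * weight_gs`, while the GEMM itself only needs one scalar alpha per expert.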

weight_transform.py: no more _fold_global_scale, returns (w, sf, global_sf)
nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
kernel.py: per_expert_alpha parameter in grouped GEMM
deepseek_v4.py: updated type hints and comments
2026-05-15 12:42:53 +00:00

40 lines
1.5 KiB
Python

"""Find ALL weight_scale_2 keys in the checkpoint for layer 0 experts."""
import glob
import os

from safetensors import safe_open

MODEL_PATH = "/model"
ckpt_files = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors")))

# Collect ALL keys that mention layer 0, expert 0 and "scale"
scale_keys = []
for f in ckpt_files:
    with safe_open(f, framework="pt") as st:
        for key in st.keys():
            if "layers.0" in key and "experts.0" in key and "scale" in key.lower():
                val = st.get_tensor(key)
                # Record (key, shape, dtype, min, max); cast to float32 so float8 tensors reduce cleanly
                scale_keys.append((key, list(val.shape), str(val.dtype), val.float().min().item(), val.float().max().item()))

scale_keys.sort()
for k, s, d, mn, mx in scale_keys:
    print(f" {k} shape={s} dtype={d} range=[{mn:.4e}, {mx:.4e}]")
print(f"\nTotal: {len(scale_keys)} scale keys for layer 0 expert 0")

# Also list every weight_scale_2 key for layer 0 (all experts, any projection)
print("\n--- All weight_scale_2 keys with gate/up/down for layer 0 ---")
ws2_keys = []
for f in ckpt_files:
    with safe_open(f, framework="pt") as st:
        for key in st.keys():
            if "layers.0" in key and "weight_scale_2" in key:
                val = st.get_tensor(key)
                ws2_keys.append((key, list(val.shape), str(val.dtype), val.float().min().item(), val.float().max().item()))

ws2_keys.sort()
# Print only the first 10 entries to keep the output readable
for k, s, d, mn, mx in ws2_keys[:10]:
    print(f" {k} shape={s} dtype={d} range=[{mn:.4e}, {mx:.4e}]")
if len(ws2_keys) > 10:
    print(f" ... and {len(ws2_keys)-10} more")
print(f"Total: {len(ws2_keys)} weight_scale_2 keys for layer 0")