Folding block_sf (float8) * global_sf (float32) back into float8 loses ~25% precision: the product of a block_sf of ~56-448 and a global_sf of ~4.65e-05 lands in the low-precision zone of float8, where the step size is 25%. This makes the model output garbage even though all values are finite.

Fix: keep the block scales in their original float8 form and return the global scales separately as float32 per-expert vectors. Apply the global scale as a per-expert GEMM alpha in cutlass_grouped_nvfp4_gemm (which already iterates per expert). For L1, which has separate gate/up global scales, use gate_gs as alpha and apply the up_correction ratio to the up half post-GEMM.

Changes by file:
- weight_transform.py: no more _fold_global_scale; returns (w, sf, global_sf)
- nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
- kernel.py: per_expert_alpha parameter in grouped GEMM
- deepseek_v4.py: updated type hints and comments
40 lines
1.5 KiB
Python
40 lines
1.5 KiB
Python
"""List the scale tensors stored for layer 0 in a safetensors checkpoint.

Two reports are printed to stdout:

1. Every key that contains "layers.0", "experts.0" and (case-insensitively)
   "scale", with its shape, dtype and value range.
2. The first few keys that contain "layers.0" and "weight_scale_2", followed
   by the total number of such keys.

Checkpoint shards are the ``*.safetensors`` files directly under
``MODEL_PATH``.
"""

import glob
import os
import sys

from safetensors import safe_open

MODEL_PATH = "/model"

# Number of weight_scale_2 rows shown before the listing is truncated.
WS2_PREVIEW_ROWS = 10

ckpt_files = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors")))

if not ckpt_files:
    # Both reports would otherwise just say "Total: 0", which is
    # indistinguishable from a checkpoint that really has no scale keys.
    print(
        f"warning: no *.safetensors files found in {MODEL_PATH}",
        file=sys.stderr,
    )


def _describe(key, tensor):
    """Return ``(key, shape, dtype, min, max)`` for one checkpoint tensor.

    ``min`` and ``max`` are Python floats taken from the tensor upcast with
    ``.float()``; a zero-element tensor reports NaN for both because the
    reductions are undefined on empty input.
    """
    if tensor.numel() == 0:
        low = high = float("nan")
    else:
        # Upcast once and reuse the result for both reductions.
        as_float = tensor.float()
        low = as_float.min().item()
        high = as_float.max().item()
    return (key, list(tensor.shape), str(tensor.dtype), low, high)


def _print_row(row):
    """Print one summary tuple produced by ``_describe``."""
    k, s, d, mn, mx = row
    print(f" {k} shape={s} dtype={d} range=[{mn:.4e}, {mx:.4e}]")


# Collect both reports in a single pass so that each shard is opened once and
# each matching tensor is loaded once.
scale_keys = []  # layer 0 / expert 0 keys that mention "scale"
ws2_keys = []  # every layer 0 key that mentions "weight_scale_2"
for shard_path in ckpt_files:
    with safe_open(shard_path, framework="pt") as st:
        for key in st.keys():
            if "layers.0" not in key:
                continue
            is_expert0_scale = "experts.0" in key and "scale" in key.lower()
            is_ws2 = "weight_scale_2" in key
            if not (is_expert0_scale or is_ws2):
                continue
            row = _describe(key, st.get_tensor(key))
            if is_expert0_scale:
                scale_keys.append(row)
            if is_ws2:
                ws2_keys.append(row)

scale_keys.sort()
for row in scale_keys:
    _print_row(row)

print(f"\nTotal: {len(scale_keys)} scale keys for layer 0 expert 0")

# Second report: every weight_scale_2 key of layer 0. The filter does not
# restrict to particular projections; whatever gate/up/down keys the
# checkpoint contains show up here by virtue of matching "weight_scale_2".
print("\n--- All weight_scale_2 keys with gate/up/down for layer 0 ---")
ws2_keys.sort()
for row in ws2_keys[:WS2_PREVIEW_ROWS]:
    _print_row(row)
if len(ws2_keys) > WS2_PREVIEW_ROWS:
    print(f" ... and {len(ws2_keys)-WS2_PREVIEW_ROWS} more")
print(f"Total: {len(ws2_keys)} weight_scale_2 keys for layer 0")