more debug

This commit is contained in:
2026-05-15 04:53:26 +00:00
parent c7f6a1dc4d
commit 912e4622d7

View File

@@ -114,6 +114,9 @@ def nvfp4_mega_moe_l1(
print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
f"experts={num_experts_per_rank} native=1")
# DEBUG: verify weight shapes after transpose
print(f"[L1-WT] l1_w shape={l1_weights.shape} l1_sf shape={l1_scales.shape} w_sf dtype={l1_scales.dtype}")
# Unpack uint32 packed UE4M3 scales to float8_e4m3fn
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
@@ -124,6 +127,7 @@ def nvfp4_mega_moe_l1(
topk_ids, topk_weights,
alpha=alpha,
)
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
return output # (num_tokens, 6144) bfloat16