more debug
This commit is contained in:
@@ -114,6 +114,9 @@ def nvfp4_mega_moe_l1(
|
||||
print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
|
||||
f"experts={num_experts_per_rank} native=1")
|
||||
|
||||
# DEBUG: verify weight shapes after transpose
|
||||
print(f"[L1-WT] l1_w shape={l1_weights.shape} l1_sf shape={l1_scales.shape} w_sf dtype={l1_scales.dtype}")
|
||||
|
||||
# Unpack uint32 packed UE4M3 scales to float8_e4m3fn
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
|
||||
@@ -124,6 +127,7 @@ def nvfp4_mega_moe_l1(
|
||||
topk_ids, topk_weights,
|
||||
alpha=alpha,
|
||||
)
|
||||
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
|
||||
return output # (num_tokens, 6144) bfloat16
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user