Double-check the [L1-GEMM-OUT] debug print line; add a [ROUTING] diagnostic

This commit is contained in:
2026-05-15 06:40:38 +00:00
parent beacc31569
commit 6b4b59c6a4

View File

@@ -128,8 +128,7 @@ def nvfp4_mega_moe_l1(
topk_ids, topk_weights,
alpha=alpha,
)
if MEGA_MOE_DEBUG:
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
return output # (num_tokens, 6144) bfloat16
@@ -328,6 +327,12 @@ def nvfp4_mega_moe_full(
experts_start_idx = symm_buffer.experts_start_idx
topk_ids_local = topk_ids - experts_start_idx
# Routing diagnostic (ungated — needed to diagnose zero-GEMM on specific ranks)
local_topk = (topk_ids >= experts_start_idx) & (topk_ids < experts_start_idx + num_experts_per_rank)
tokens_routed_locally = local_topk.any(dim=-1).sum().item()
print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{topk_ids.shape[0]} "
f"unique_local_experts={local_topk.long().sum().item()}")
if MEGA_MOE_DEBUG:
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "