Double-check that weird line
This commit is contained in:
@@ -128,8 +128,7 @@ def nvfp4_mega_moe_l1(
|
||||
topk_ids, topk_weights,
|
||||
alpha=alpha,
|
||||
)
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
|
||||
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
|
||||
return output # (num_tokens, 6144) bfloat16
|
||||
|
||||
|
||||
@@ -328,6 +327,12 @@ def nvfp4_mega_moe_full(
|
||||
experts_start_idx = symm_buffer.experts_start_idx
|
||||
topk_ids_local = topk_ids - experts_start_idx
|
||||
|
||||
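# Routing diagnostic (ungated, i.e. printed regardless of MEGA_MOE_DEBUG — needed to diagnose zero-GEMM on specific ranks)
# NOTE(review): the value printed as "unique_local_experts" is the sum of the local_topk mask, i.e. the number of locally routed (token, top-k slot) entries, not a count of distinct experts — confirm the intended label.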
|
||||
local_topk = (topk_ids >= experts_start_idx) & (topk_ids < experts_start_idx + num_experts_per_rank)
|
||||
tokens_routed_locally = local_topk.any(dim=-1).sum().item()
|
||||
print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{topk_ids.shape[0]} "
|
||||
f"unique_local_experts={local_topk.long().sum().item()}")
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
|
||||
f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
|
||||
|
||||
Reference in New Issue
Block a user