debug: try/catch around mega_moe kernel with data diagnostics on crash
This commit is contained in:
@@ -364,15 +364,33 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
|
||||
("sym_l2_acts", sym_buffer.l2_acts), ("sym_l2_acts_sf", sym_buffer.l2_acts_sf)]:
|
||||
print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} contig={t.is_contiguous()}", flush=True)
|
||||
|
||||
_C.fp8_nvfp4_mega_moe(
|
||||
y,
|
||||
(l1_w, l1_w_sf), (l2_w, l2_w_sf),
|
||||
cumulative_local_expert_recv_stats,
|
||||
sym_buffer.buffer,
|
||||
sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
|
||||
sym_buffer.num_max_tokens_per_rank,
|
||||
sym_buffer.num_experts, sym_buffer.num_topk,
|
||||
recipe,
|
||||
activation, activation_clamp,
|
||||
fast_math
|
||||
)
|
||||
try:
|
||||
_C.fp8_nvfp4_mega_moe(
|
||||
y,
|
||||
(l1_w, l1_w_sf), (l2_w, l2_w_sf),
|
||||
cumulative_local_expert_recv_stats,
|
||||
sym_buffer.buffer,
|
||||
sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
|
||||
sym_buffer.num_max_tokens_per_rank,
|
||||
sym_buffer.num_experts, sym_buffer.num_topk,
|
||||
recipe,
|
||||
activation, activation_clamp,
|
||||
fast_math
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
except RuntimeError as e:
|
||||
print(f"[CRASH] fp8_nvfp4_mega_moe FAILED: {e}", flush=True)
|
||||
print(f"[CRASH] num_tokens={y.shape[0]} hidden={y.shape[1]}", flush=True)
|
||||
print(f"[CRASH] l1_w_sf shape={l1_w_sf.shape} strides={l1_w_sf.stride()}", flush=True)
|
||||
print(f"[CRASH] l2_w_sf shape={l2_w_sf.shape} strides={l2_w_sf.stride()}", flush=True)
|
||||
# Check activation data for anomalies
|
||||
x = sym_buffer.x[:y.shape[0]]
|
||||
print(f"[CRASH] x stats: min={x.min().item()} max={x.max().item()} nonzero={torch.count_nonzero(x).item()}", flush=True)
|
||||
x_sf = sym_buffer.x_sf[:y.shape[0]]
|
||||
print(f"[CRASH] x_sf stats: min={x_sf.min().item()} max={x_sf.max().item()} zeros={(x_sf==0).sum().item()}", flush=True)
|
||||
# Check weight data for anomalies
|
||||
print(f"[CRASH] l1_w stats: min={l1_w.min().item()} max={l1_w.max().item()}", flush=True)
|
||||
print(f"[CRASH] l1_w_sf stats: min={l1_w_sf.min().item()} max={l1_w_sf.max().item()}", flush=True)
|
||||
print(f"[CRASH] l2_w stats: min={l2_w.min().item()} max={l2_w.max().item()}", flush=True)
|
||||
print(f"[CRASH] l2_w_sf stats: min={l2_w_sf.min().item()} max={l2_w_sf.max().item()}", flush=True)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user