debug: add try/except with shape logging to _run_mega_moe

This commit is contained in:
2026-05-16 04:02:01 +00:00
parent b04bff7e8b
commit 37aa0cbeab

View File

@@ -558,11 +558,28 @@ class DeepseekV4MegaMoEExperts(nn.Module):
# Build expert indices list for this rank
expert_indices = list(range(self.num_local_experts))
result = self._cutedsl_runner.run(
hidden_states, topk_weights, topk_ids,
expert_indices=expert_indices,
)
y.copy_(result)
try:
result = self._cutedsl_runner.run(
hidden_states, topk_weights, topk_ids,
expert_indices=expert_indices,
)
y.copy_(result)
except Exception as exc:
import traceback
traceback.print_exc()
# Debug: print shapes
runner = self._cutedsl_runner
print(f"[NVFP4 DEBUG] num_local_experts={self.num_local_experts} "
f"hidden={self.hidden_size} intermediate={self.intermediate_size}")
if runner.l1_fp4:
print(f"[NVFP4 DEBUG] l1_fp4[0] shape={runner.l1_fp4[0].shape} "
f"l1_sf[0] shape={runner.l1_sf[0].shape} l1_gs[0]={runner.l1_gs[0]}")
if runner.l2_fp4:
print(f"[NVFP4 DEBUG] l2_fp4[0] shape={runner.l2_fp4[0].shape} "
f"l2_sf[0] shape={runner.l2_sf[0].shape} l2_gs[0]={runner.l2_gs[0]}")
print(f"[NVFP4 DEBUG] hidden_states shape={hidden_states.shape} "
f"topk_ids shape={topk_ids.shape}")
raise
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
torch.cuda.synchronize()