HOTFIX: remove NaN checks from run() — torch.isnan().any() does CPU-GPU sync, breaks cudagraph

This commit is contained in:
2026-05-17 22:28:32 +00:00
parent 8717e0e411
commit 87582fc9f7

View File

@@ -466,11 +466,6 @@ class CuTeDSLMoERunner:
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
# DEBUG: Check for NaN after L1 GEMM
if torch.isnan(l1_out).any():
nan_count = torch.isnan(l1_out).sum().item()
print(f"[CLAWMINE] NaN in L1 output! {nan_count}/{l1_out.numel()} values are NaN")
# Extract real token outputs from padded GEMM output
l1_out_real = l1_out[padded_dst]
@@ -507,11 +502,6 @@ class CuTeDSLMoERunner:
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
)
# DEBUG: Check for NaN after L2 GEMM
if torch.isnan(l2_out).any():
nan_count = torch.isnan(l2_out).sum().item()
print(f"[CLAWMINE] NaN in L2 output! {nan_count}/{l2_out.numel()} values are NaN")
l2_out_real = l2_out[padded_dst]
# === Scatter -> final output ===