HOTFIX: remove NaN checks from run() — torch.isnan().any() does CPU-GPU sync, breaks cudagraph
This commit is contained in:
@@ -466,11 +466,6 @@ class CuTeDSLMoERunner:
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
)
|
||||
|
||||
# DEBUG: Check for NaN after L1 GEMM
|
||||
if torch.isnan(l1_out).any():
|
||||
nan_count = torch.isnan(l1_out).sum().item()
|
||||
print(f"[CLAWMINE] NaN in L1 output! {nan_count}/{l1_out.numel()} values are NaN")
|
||||
|
||||
# Extract real token outputs from padded GEMM output
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
|
||||
@@ -507,11 +502,6 @@ class CuTeDSLMoERunner:
|
||||
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
||||
)
|
||||
|
||||
# DEBUG: Check for NaN after L2 GEMM
|
||||
if torch.isnan(l2_out).any():
|
||||
nan_count = torch.isnan(l2_out).sum().item()
|
||||
print(f"[CLAWMINE] NaN in L2 output! {nan_count}/{l2_out.numel()} values are NaN")
|
||||
|
||||
l2_out_real = l2_out[padded_dst]
|
||||
|
||||
# === Scatter -> final output ===
|
||||
|
||||
Reference in New Issue
Block a user