Fix NaN check: wrap in @torch.compiler.disable so Dynamo skips tracing it

The inline os.environ gate doesn't work — Dynamo still sees the
data-dependent branching (torch.isnan().any()) and crashes with
'Unsupported: Data-dependent branching'. Extracting into a
@torch.compiler.disable decorated function makes Dynamo skip it.
This commit is contained in:
2026-05-18 11:33:29 +00:00
parent 8758bc93ca
commit 65763a200c

View File

@@ -1210,6 +1210,24 @@ class DeepseekV4DecoderLayer(nn.Module):
return x
@torch.compiler.disable
def _clawmine_nan_check(hidden_states: torch.Tensor, layer_idx: int):
"""NaN/Inf detection — only active when CLAWMINE_NAN_CHECK=1.
Decorated with @torch.compiler.disable so Dynamo never traces
the data-dependent branching (if tensor.any())."""
if os.environ.get('CLAWMINE_NAN_CHECK', '0') != '1':
return
with torch.no_grad():
if torch.isnan(hidden_states).any():
nan_pct = torch.isnan(hidden_states).float().mean().item() * 100
print(f"[CLAWMINE] NaN after layer {layer_idx}! {nan_pct:.2f}% NaN, amax={hidden_states.amax().item():.4f}")
elif torch.isinf(hidden_states).any():
inf_pct = torch.isinf(hidden_states).float().mean().item() * 100
print(f"[CLAWMINE] Inf after layer {layer_idx}! {inf_pct:.2f}% Inf, amax={hidden_states.amax().item():.4f}")
elif layer_idx % 10 == 0:
print(f"[CLAWMINE] Layer {layer_idx}: amax={hidden_states.amax().item():.4f} mean={hidden_states.mean().item():.6f}")
@support_torch_compile
class DeepseekV4Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -1316,21 +1334,9 @@ class DeepseekV4Model(nn.Module):
positions,
input_ids,
)
# NaN detection — only during prefill. Disabled via env var during cudagraph.
# os.environ is evaluated at trace time by Dynamo, so the entire
# NaN check block is skipped during compilation.
if os.environ.get('CLAWMINE_NAN_CHECK', '0') == '1':
with torch.no_grad():
if torch.isnan(hidden_states).any():
nan_pct = torch.isnan(hidden_states).float().mean().item() * 100
print(f"[CLAWMINE] NaN after layer {layer_idx}! {nan_pct:.2f}% NaN, amax={hidden_states.amax().item():.4f}")
break
if torch.isinf(hidden_states).any():
inf_pct = torch.isinf(hidden_states).float().mean().item() * 100
print(f"[CLAWMINE] Inf after layer {layer_idx}! {inf_pct:.2f}% Inf, amax={hidden_states.amax().item():.4f}")
break
if layer_idx % 10 == 0:
print(f"[CLAWMINE] Layer {layer_idx}: amax={hidden_states.amax().item():.4f} mean={hidden_states.mean().item():.6f}")
# NaN detection — guarded by env var, wrapped in torch.compiler.disable
# so Dynamo never traces the data-dependent branching.
_clawmine_nan_check(hidden_states, layer_idx)
# Stash pre-hc_head residual for the MTP draft (captured copy_).
num_tokens = hidden_states.shape[0]