From 65763a200cf728d4e9cd9cf10c70181ef5302f9d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 18 May 2026 11:33:29 +0000 Subject: [PATCH] Fix NaN check: wrap in @torch.compiler.disable to prevent Dynamo graph break MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The inline os.environ gate doesn't work — Dynamo still sees the data-dependent branching (torch.isnan().any()) and crashes with 'Unsupported: Data-dependent branching'. Extracting into a @torch.compiler.disable decorated function makes Dynamo skip it. Behavior change: the extracted helper cannot break out of the caller's layer loop, so the two `break` statements of the inline version are dropped; after a NaN/Inf is reported the remaining layers still run and are still checked. --- vllm/patches/deepseek_v4.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 60a040d2..0f60f9fa 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -1210,6 +1210,24 @@ class DeepseekV4DecoderLayer(nn.Module): return x +@torch.compiler.disable +def _clawmine_nan_check(hidden_states: torch.Tensor, layer_idx: int): + """NaN/Inf detection — only active when CLAWMINE_NAN_CHECK=1. + Decorated with @torch.compiler.disable so Dynamo never traces + the data-dependent branching (if tensor.any()).""" + if os.environ.get('CLAWMINE_NAN_CHECK', '0') != '1': + return + with torch.no_grad(): + if torch.isnan(hidden_states).any(): + nan_pct = torch.isnan(hidden_states).float().mean().item() * 100 + print(f"[CLAWMINE] NaN after layer {layer_idx}! {nan_pct:.2f}% NaN, amax={hidden_states.amax().item():.4f}") + elif torch.isinf(hidden_states).any(): + inf_pct = torch.isinf(hidden_states).float().mean().item() * 100 + print(f"[CLAWMINE] Inf after layer {layer_idx}! 
{inf_pct:.2f}% Inf, amax={hidden_states.amax().item():.4f}") + elif layer_idx % 10 == 0: + print(f"[CLAWMINE] Layer {layer_idx}: amax={hidden_states.amax().item():.4f} mean={hidden_states.mean().item():.6f}") + + @support_torch_compile class DeepseekV4Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -1316,21 +1334,9 @@ class DeepseekV4Model(nn.Module): positions, input_ids, ) - # NaN detection — only during prefill. Disabled via env var during cudagraph. - # os.environ is evaluated at trace time by Dynamo, so the entire - # NaN check block is skipped during compilation. - if os.environ.get('CLAWMINE_NAN_CHECK', '0') == '1': - with torch.no_grad(): - if torch.isnan(hidden_states).any(): - nan_pct = torch.isnan(hidden_states).float().mean().item() * 100 - print(f"[CLAWMINE] NaN after layer {layer_idx}! {nan_pct:.2f}% NaN, amax={hidden_states.amax().item():.4f}") - break - if torch.isinf(hidden_states).any(): - inf_pct = torch.isinf(hidden_states).float().mean().item() * 100 - print(f"[CLAWMINE] Inf after layer {layer_idx}! {inf_pct:.2f}% Inf, amax={hidden_states.amax().item():.4f}") - break - if layer_idx % 10 == 0: - print(f"[CLAWMINE] Layer {layer_idx}: amax={hidden_states.amax().item():.4f} mean={hidden_states.mean().item():.6f}") + # NaN detection — guarded by env var, wrapped in torch.compiler.disable + # so Dynamo never traces the data-dependent branching. + _clawmine_nan_check(hidden_states, layer_idx) # Stash pre-hc_head residual for the MTP draft (captured copy_). num_tokens = hidden_states.shape[0]