Fix NaN check: wrap in @torch.compiler.disable so Dynamo never traces the data-dependent branching
The inline os.environ gate doesn't work — Dynamo still sees the data-dependent branching (torch.isnan().any()) and crashes with 'Unsupported: Data-dependent branching'. Extracting into a @torch.compiler.disable decorated function makes Dynamo skip it.
This commit is contained in:
@@ -1210,6 +1210,24 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def _clawmine_nan_check(hidden_states: torch.Tensor, layer_idx: int):
|
||||
"""NaN/Inf detection — only active when CLAWMINE_NAN_CHECK=1.
|
||||
Decorated with @torch.compiler.disable so Dynamo never traces
|
||||
the data-dependent branching (if tensor.any())."""
|
||||
if os.environ.get('CLAWMINE_NAN_CHECK', '0') != '1':
|
||||
return
|
||||
with torch.no_grad():
|
||||
if torch.isnan(hidden_states).any():
|
||||
nan_pct = torch.isnan(hidden_states).float().mean().item() * 100
|
||||
print(f"[CLAWMINE] NaN after layer {layer_idx}! {nan_pct:.2f}% NaN, amax={hidden_states.amax().item():.4f}")
|
||||
elif torch.isinf(hidden_states).any():
|
||||
inf_pct = torch.isinf(hidden_states).float().mean().item() * 100
|
||||
print(f"[CLAWMINE] Inf after layer {layer_idx}! {inf_pct:.2f}% Inf, amax={hidden_states.amax().item():.4f}")
|
||||
elif layer_idx % 10 == 0:
|
||||
print(f"[CLAWMINE] Layer {layer_idx}: amax={hidden_states.amax().item():.4f} mean={hidden_states.mean().item():.6f}")
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class DeepseekV4Model(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@@ -1316,21 +1334,9 @@ class DeepseekV4Model(nn.Module):
|
||||
positions,
|
||||
input_ids,
|
||||
)
|
||||
# NaN detection — only during prefill. Disabled via env var during cudagraph.
|
||||
# os.environ is evaluated at trace time by Dynamo, so the entire
|
||||
# NaN check block is skipped during compilation.
|
||||
if os.environ.get('CLAWMINE_NAN_CHECK', '0') == '1':
|
||||
with torch.no_grad():
|
||||
if torch.isnan(hidden_states).any():
|
||||
nan_pct = torch.isnan(hidden_states).float().mean().item() * 100
|
||||
print(f"[CLAWMINE] NaN after layer {layer_idx}! {nan_pct:.2f}% NaN, amax={hidden_states.amax().item():.4f}")
|
||||
break
|
||||
if torch.isinf(hidden_states).any():
|
||||
inf_pct = torch.isinf(hidden_states).float().mean().item() * 100
|
||||
print(f"[CLAWMINE] Inf after layer {layer_idx}! {inf_pct:.2f}% Inf, amax={hidden_states.amax().item():.4f}")
|
||||
break
|
||||
if layer_idx % 10 == 0:
|
||||
print(f"[CLAWMINE] Layer {layer_idx}: amax={hidden_states.amax().item():.4f} mean={hidden_states.mean().item():.6f}")
|
||||
# NaN detection — guarded by env var, wrapped in torch.compiler.disable
|
||||
# so Dynamo never traces the data-dependent branching.
|
||||
_clawmine_nan_check(hidden_states, layer_idx)
|
||||
|
||||
# Stash pre-hc_head residual for the MTP draft (captured copy_).
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user