Add granular attention diagnostics: pre/post attn, embed, dequant stats

This commit is contained in:
2026-05-18 14:24:14 +00:00
parent e0e0528778
commit 5c1dda10f6
2 changed files with 50 additions and 33 deletions

View File

@@ -1,55 +1,58 @@
# Current Bug: vLLM produces NaN from layer 0
**Status:** ROOT CAUSE IDENTIFIED
**Status:** Active debugging — BF16 dequant fix in progress
**Date:** 2026-05-18
## Symptom
- vLLM server starts, loads model, but every inference produces NaN logits
- Diagnostic prints show **NaN from layer 0 onward** — no layer ever produces valid output
- Empty content in chat completions, NaN in logprobs
## Root Cause
## Root Cause (in progress)
**The attention NVFP4 linear layers produce NaN immediately.**
The attention projections (`q_a_proj`, `q_b_proj`, `kv_proj`, `o_a_proj`, `o_b_proj`) go through vLLM's `FlashInferCutlassNvFp4LinearKernel` which calls:
```python
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale_inv, ...)
```
The attention projections go through vLLM's `FlashInferCutlassNvFp4LinearKernel` which uses checkpoint `input_scale` as the activation global scale for `scaled_fp4_quant()`. The checkpoint `input_scale` values are wrong for this use case, causing overflow → NaN.
`input_global_scale_inv` comes from the checkpoint `input_scale` field. For MoE, we override this with a warmup. For attention, there's **no warmup** — it uses the raw checkpoint value.
### What we've tried
The `CompressedTensorsW4A4Fp4.process_weights_after_loading` sets:
```python
input_global_scale_inv = layer.input_scale.max().to(torch.float32) # = 0.00025141
layer.alpha = input_global_scale * layer.weight_global_scale
```
1.**MoE kernel is NOT the problem**`test_runner_vllm_style.py` with warmup gs gives cosine 0.988, no NaN
2.**Dequant ALL attn projections to BF16** — crashed: `wo_a.weight_scale_inv` missing (fp8_einsum needs it)
3.**Dequant all except wo_a (keep wo_a as FP8)** — still NaN from layer 0. `wq_a` and `wkv` don't exist as separate attrs — they're **fused as `fused_wqa_wkv`**
4.**Changed to dequant `fused_wqa_wkv`** — still NaN from layer 0. Debug prints added to check if the attrs are actually found.
For q_a_proj: `input_scale = 0.00025141`, meaning `1/input_scale = 3977.6`. The activation quantization divides by 0.00025141 (multiplies by 3977.6). For typical activations with amax ~2-8, this produces values far beyond FP4 range (max 6.0), causing NaN via overflow.
### Current theory
## Evidence
The BF16 dequant code may not be finding `fused_wqa_wkv` on the attention module, so it silently skips the most important projection. Debug logging added in latest commit to verify.
1. **MoE kernel is fine**`test_runner_vllm_style.py` with warmup gs gives cosine 0.988
2. **NaN from layer 0** — diagnostic prints show ALL layers from 0 produce NaN
3. **Attention weights dequantize fine**`test_attn_weights.py` shows no NaN from dequantized BF16 matmul
4. **The problem is in the NVFP4 activation quantization**, not the weights
### Attention architecture (DeepSeek V4 MLA)
## Fix
- `fused_wqa_wkv` — MergedColumnParallelLinear (q_a + kv fused)
- `wq_b` — ColumnParallelLinear (second Q projection after RoPE)
- `wo_a` — ColumnParallelLinear (FP8 via fp8_einsum, weight-only, NO input_scale)
- `wo_b` — ColumnParallelLinear (final output projection)
- `compressor` — already handled (reconstructed to BF16 from checkpoint)
The attention `input_scale` needs the same warmup-based override we did for MoE, OR the `input_scale` values need to be validated/corrected.
### Why `wo_a` is safe as FP8
Options:
1. **Add warmup for attention `input_global_scale_inv`** — same pattern as MoE: run a dummy forward, capture actual activation amax, compute correct gs
2. **Dequantize attention weights to BF16** (like compressor weights) — avoids NVFP4 activation quantization entirely, at the cost of more memory
3. **Fix the checkpoint input_scale** — if the values are wrong, re-calibrate
`wo_a` uses `fp8_einsum` which does `output = fp8_act * fp8_weight * scale`. It's a **weight-only FP8** GEMM — no `input_scale` involved. The NaN comes from `scaled_fp4_quant(x, input_global_scale_inv)` in the other projections.
Option 2 is the quickest path — dequantize attention NVFP4 weights to BF16 at load time (the `_dequant_nvfp4_to_bf16` method already exists). This trades memory for correctness.
## Key evidence
## Progress
- `q_a_proj.input_scale = 0.00025141``1/input_scale = 3977.6` → quantizing activations with amax ~2-8 by 3977.6x = massive overflow
- `q_b_proj.input_scale = 0.00006140``1/input_scale = 16287.1` → even worse
- Embedding values: amax=1.27, std=0.09 — very small values that get multiplied by thousands during quantization
- [x] Removed NaN check (Dynamo incompatible)
- [x] vLLM container starts and loads model
- [x] Confirmed NaN logits from completions API
- [x] MoE kernel: cosine 0.988 with warmup gs — NOT the problem
- [x] NaN starts at layer 0 — attention is the source
- [x] Root cause: attention NVFP4 `input_scale` from checkpoint produces NaN during activation quantization
- [ ] **Next: Fix attention NVFP4 path — dequant to BF16 or add warmup**
## Next steps
1. Check debug logs to see which projections were actually dequantized
2. If `fused_wqa_wkv` wasn't found, fix the attribute path
3. If it was found and dequantized, the NaN source is elsewhere (wo_b? wq_b? something else?)
4. Consider: maybe the NaN is from the **KV cache FP8 quantization** or the **RoPE** implementation
## Docker/Build Notes
- Build: `screen -dmS build bash -c './build_and_run.sh 2>&1 | tee build.log'`
- Currently using `--enforce-eager` + `CLAWMINE_DEBUG=1` for diagnostics
- Don't hit the API with enforce-eager — JIT spikes crash the container
- For real testing: use compilation-config `{"cudagraph_mode": "NONE", "custom_ops": ["all"]}` instead of enforce-eager

View File

@@ -1197,7 +1197,11 @@ class DeepseekV4DecoderLayer(nn.Module):
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
)
x = self.attn_norm(x)
if os.environ.get('CLAWMINE_DEBUG', '0') == '1':
_print_tensor("pre_attn", x, self._layer_idx if hasattr(self, '_layer_idx') else -1)
x = self.attn(positions, x, None)
if os.environ.get('CLAWMINE_DEBUG', '0') == '1':
_print_tensor("post_attn", x, self._layer_idx if hasattr(self, '_layer_idx') else -1)
x = self.hc_post(x, residual, post, comb)
residual = x
@@ -1210,6 +1214,11 @@ class DeepseekV4DecoderLayer(nn.Module):
return x
def _print_tensor(label, t, layer_idx):
with torch.no_grad():
print(f"[CLAWMINE] L{layer_idx} {label}: amax={t.amax().item():.4f} NaN={torch.isnan(t).any().item()} shape={t.shape}")
def _diag_hidden_stats(hidden_states: torch.Tensor, layer_idx: int):
"""Print hidden state stats after each layer. Disabled unless
CLAWMINE_DEBUG=1. os.environ is evaluated at trace time, so
@@ -1323,9 +1332,12 @@ class DeepseekV4Model(nn.Module):
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.embed_input_ids(input_ids)
hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1)
if os.environ.get('CLAWMINE_DEBUG', '0') == '1':
_print_tensor("embed", hidden_states, -1)
if self.use_mega_moe:
input_ids = input_ids.to(torch.int64)
for layer_idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)):
layer._layer_idx = layer_idx
hidden_states = layer(
hidden_states,
positions,
@@ -1741,6 +1753,8 @@ class DeepseekV4Model(nn.Module):
if mod.weight.dtype in (torch.uint8, torch.int8):
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
if layer_idx == 0:
print(f"[CLAWMINE] Layer 0: {proj_name} AFTER dequant: dtype={mod.weight.dtype} amax={mod.weight.data.amax().item():.4f} NaN={torch.isnan(mod.weight.data).any().item()}")
bf16_converted += 1
# FP8 conversion: wo_a (used by fp8_einsum, no input_scale)