Add granular attention diagnostics: pre/post attn, embed, dequant stats
This commit is contained in:
@@ -1,55 +1,58 @@
|
||||
# Current Bug: vLLM produces NaN from layer 0
|
||||
|
||||
**Status:** ROOT CAUSE IDENTIFIED
|
||||
**Status:** Active debugging — BF16 dequant fix in progress
|
||||
**Date:** 2026-05-18
|
||||
|
||||
## Symptom
|
||||
- vLLM server starts, loads model, but every inference produces NaN logits
|
||||
- Diagnostic prints show **NaN from layer 0 onward** — no layer ever produces valid output
|
||||
- Empty content in chat completions, NaN in logprobs
|
||||
|
||||
## Root Cause
|
||||
## Root Cause (in progress)
|
||||
|
||||
**The attention NVFP4 linear layers produce NaN immediately.**
|
||||
|
||||
The attention projections (`q_a_proj`, `q_b_proj`, `kv_proj`, `o_a_proj`, `o_b_proj`) go through vLLM's `FlashInferCutlassNvFp4LinearKernel` which calls:
|
||||
```python
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale_inv, ...)
|
||||
```
|
||||
The attention projections go through vLLM's `FlashInferCutlassNvFp4LinearKernel` which uses checkpoint `input_scale` as the activation global scale for `scaled_fp4_quant()`. The checkpoint `input_scale` values are wrong for this use case, causing overflow → NaN.
|
||||
|
||||
`input_global_scale_inv` comes from the checkpoint `input_scale` field. For MoE, we override this with a warmup. For attention, there's **no warmup** — it uses the raw checkpoint value.
|
||||
### What we've tried
|
||||
|
||||
The `CompressedTensorsW4A4Fp4.process_weights_after_loading` sets:
|
||||
```python
|
||||
input_global_scale_inv = layer.input_scale.max().to(torch.float32) # = 0.00025141
|
||||
layer.alpha = input_global_scale * layer.weight_global_scale
|
||||
```
|
||||
1. ✅ **MoE kernel is NOT the problem** — `test_runner_vllm_style.py` with warmup gs gives cosine 0.988, no NaN
|
||||
2. ❌ **Dequant ALL attn projections to BF16** — crashed: `wo_a.weight_scale_inv` missing (fp8_einsum needs it)
|
||||
3. ❌ **Dequant all except wo_a (keep wo_a as FP8)** — still NaN from layer 0. `wq_a` and `wkv` don't exist as separate attrs — they're **fused as `fused_wqa_wkv`**
|
||||
4. ❌ **Changed to dequant `fused_wqa_wkv`** — still NaN from layer 0. Debug prints added to check if the attrs are actually found.
|
||||
|
||||
For q_a_proj: `input_scale = 0.00025141`, meaning `1/input_scale = 3977.6`. The activation quantization divides by 0.00025141 (multiplies by 3977.6). For typical activations with amax ~2-8, this produces values far beyond FP4 range (max 6.0), causing NaN via overflow.
|
||||
### Current theory
|
||||
|
||||
## Evidence
|
||||
The BF16 dequant code may not be finding `fused_wqa_wkv` on the attention module, so it silently skips the most important projection. Debug logging added in latest commit to verify.
|
||||
|
||||
1. **MoE kernel is fine** — `test_runner_vllm_style.py` with warmup gs gives cosine 0.988
|
||||
2. **NaN from layer 0** — diagnostic prints show ALL layers from 0 produce NaN
|
||||
3. **Attention weights dequantize fine** — `test_attn_weights.py` shows no NaN from dequantized BF16 matmul
|
||||
4. **The problem is in the NVFP4 activation quantization**, not the weights
|
||||
### Attention architecture (DeepSeek V4 MLA)
|
||||
|
||||
## Fix
|
||||
- `fused_wqa_wkv` — MergedColumnParallelLinear (q_a + kv fused)
|
||||
- `wq_b` — ColumnParallelLinear (second Q projection after RoPE)
|
||||
- `wo_a` — ColumnParallelLinear (FP8 via fp8_einsum, weight-only, NO input_scale)
|
||||
- `wo_b` — ColumnParallelLinear (final output projection)
|
||||
- `compressor` — already handled (reconstructed to BF16 from checkpoint)
|
||||
|
||||
The attention `input_scale` needs the same warmup-based override we did for MoE, OR the `input_scale` values need to be validated/corrected.
|
||||
### Why `wo_a` is safe as FP8
|
||||
|
||||
Options:
|
||||
1. **Add warmup for attention `input_global_scale_inv`** — same pattern as MoE: run a dummy forward, capture actual activation amax, compute correct gs
|
||||
2. **Dequantize attention weights to BF16** (like compressor weights) — avoids NVFP4 activation quantization entirely, at the cost of more memory
|
||||
3. **Fix the checkpoint input_scale** — if the values are wrong, re-calibrate
|
||||
`wo_a` uses `fp8_einsum` which does `output = fp8_act * fp8_weight * scale`. It's a **weight-only FP8** GEMM — no `input_scale` involved. The NaN comes from `scaled_fp4_quant(x, input_global_scale_inv)` in the other projections.
|
||||
|
||||
Option 2 is the quickest path — dequantize attention NVFP4 weights to BF16 at load time (the `_dequant_nvfp4_to_bf16` method already exists). This trades memory for correctness.
|
||||
## Key evidence
|
||||
|
||||
## Progress
|
||||
- `q_a_proj.input_scale = 0.00025141` → `1/input_scale = 3977.6` → quantizing activations with amax ~2-8 by 3977.6x = massive overflow
|
||||
- `q_b_proj.input_scale = 0.00006140` → `1/input_scale = 16287.1` → even worse
|
||||
- Embedding values: amax=1.27, std=0.09 — very small values that get multiplied by thousands during quantization
|
||||
|
||||
- [x] Removed NaN check (Dynamo incompatible)
|
||||
- [x] vLLM container starts and loads model
|
||||
- [x] Confirmed NaN logits from completions API
|
||||
- [x] MoE kernel: cosine 0.988 with warmup gs — NOT the problem
|
||||
- [x] NaN starts at layer 0 — attention is the source
|
||||
- [x] Root cause: attention NVFP4 `input_scale` from checkpoint produces NaN during activation quantization
|
||||
- [ ] **Next: Fix attention NVFP4 path — dequant to BF16 or add warmup**
|
||||
## Next steps
|
||||
|
||||
1. Check debug logs to see which projections were actually dequantized
|
||||
2. If `fused_wqa_wkv` wasn't found, fix the attribute path
|
||||
3. If it was found and dequantized, the NaN source is elsewhere (wo_b? wq_b? something else?)
|
||||
4. Consider: maybe the NaN is from the **KV cache FP8 quantization** or the **RoPE** implementation
|
||||
|
||||
## Docker/Build Notes
|
||||
|
||||
- Build: `screen -dmS build bash -c './build_and_run.sh 2>&1 | tee build.log'`
|
||||
- Currently using `--enforce-eager` + `CLAWMINE_DEBUG=1` for diagnostics
|
||||
- Don't hit the API with enforce-eager — JIT spikes crash the container
|
||||
- For real testing: use compilation-config `{"cudagraph_mode": "NONE", "custom_ops": ["all"]}` instead of enforce-eager
|
||||
|
||||
@@ -1197,7 +1197,11 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
|
||||
)
|
||||
x = self.attn_norm(x)
|
||||
if os.environ.get('CLAWMINE_DEBUG', '0') == '1':
|
||||
_print_tensor("pre_attn", x, self._layer_idx if hasattr(self, '_layer_idx') else -1)
|
||||
x = self.attn(positions, x, None)
|
||||
if os.environ.get('CLAWMINE_DEBUG', '0') == '1':
|
||||
_print_tensor("post_attn", x, self._layer_idx if hasattr(self, '_layer_idx') else -1)
|
||||
x = self.hc_post(x, residual, post, comb)
|
||||
|
||||
residual = x
|
||||
@@ -1210,6 +1214,11 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _print_tensor(label, t, layer_idx):
|
||||
with torch.no_grad():
|
||||
print(f"[CLAWMINE] L{layer_idx} {label}: amax={t.amax().item():.4f} NaN={torch.isnan(t).any().item()} shape={t.shape}")
|
||||
|
||||
|
||||
def _diag_hidden_stats(hidden_states: torch.Tensor, layer_idx: int):
|
||||
"""Print hidden state stats after each layer. Disabled unless
|
||||
CLAWMINE_DEBUG=1. os.environ is evaluated at trace time, so
|
||||
@@ -1323,9 +1332,12 @@ class DeepseekV4Model(nn.Module):
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1)
|
||||
if os.environ.get('CLAWMINE_DEBUG', '0') == '1':
|
||||
_print_tensor("embed", hidden_states, -1)
|
||||
if self.use_mega_moe:
|
||||
input_ids = input_ids.to(torch.int64)
|
||||
for layer_idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)):
|
||||
layer._layer_idx = layer_idx
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
positions,
|
||||
@@ -1741,6 +1753,8 @@ class DeepseekV4Model(nn.Module):
|
||||
if mod.weight.dtype in (torch.uint8, torch.int8):
|
||||
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
if layer_idx == 0:
|
||||
print(f"[CLAWMINE] Layer 0: {proj_name} AFTER dequant: dtype={mod.weight.dtype} amax={mod.weight.data.amax().item():.4f} NaN={torch.isnan(mod.weight.data).any().item()}")
|
||||
bf16_converted += 1
|
||||
|
||||
# FP8 conversion: wo_a (used by fp8_einsum, no input_scale)
|
||||
|
||||
Reference in New Issue
Block a user