probe: indexer and compressor weight shapes from checkpoint

This commit is contained in:
2026-06-02 04:36:35 +00:00
parent 9254cb0b0d
commit 938e9079ce
2 changed files with 509 additions and 0 deletions

View File

@@ -0,0 +1,434 @@
# ARCHITECTURE & MEMORY AUDIT — 1M-context viability
**Method.** Verified against `single_shot_inference.py` v16 and the DSv4
paper §2.3.1§2.3.4 (CSA/HCA) and §3.5.1 (heterogeneous KV cache). Every
finding has a line number. Per doctrine.
**Framing.** The Paris demo runs ≤ 50 tokens. The model is built for 1M.
That's a **20,000×** gap. Several things in the current single_shot are
fine at 50 tokens, will OOM hard at 410K tokens, and don't even resemble
the paper's KV design at 1M. Below: drift first, then memory, in order of
how badly each one blocks the 1M-context goal.
---
# PART 1 — ARCHITECTURE DRIFT FROM PAPER
## D1 — `comp_idx_buf` shape is wrong (CRITICAL — silent corruption or crash)
`single_shot_inference.py:419`:
```python
self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
^^^^^^^^
512 WRONG
```
But indexer keys are `n_ih × ihd` wide. From `:1030`:
```python
n_ih = cfg.get("index_n_heads", 64)
ihd = cfg.get("index_head_dim", 128)
```
So indexer keys have width `64 × 128 = 8192`, not 512. **The indexer KV
buffer is 16× too narrow.** What happens in practice depends on whether
the assignment broadcasts, raises, or silently truncates — and from the
fact that Paris-back works, the indexer probably isn't being used at all
yet (CSA layers may be producing 0 compressed blocks at 50 tokens — see
D2). At any context where CSA actually compresses, this either crashes
or stores garbage in the top-k selection input.
**Fix.** Read the actual indexer key width from the indexer's compressor
output (`indexer.compressor.kv_dim = 2 * ihd = 256` for the indexer-side
CSA, since the indexer's compressor takes `(4, ihd, H)`). Then check what
the indexer's compressor actually produces — print its output shape on
first call instead of guessing — and size `comp_idx_buf` to match.
Verification step: instrument `Compressor.forward` to print
`compressed.shape` on first call from both the **main** compressor (kv_dim
= 512 or 1024) and the **indexer's** compressor (kv_dim = 256). Code to
the observed values. Do not infer them from variable names.
## D2 — The Compressor is built twice per CSA layer, with different config
`:394` (inside `Indexer.load`):
```python
self.compressor = Compressor(4, self.ihd, 7168, dev)
```
`:1034` (in `main()`):
```python
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
```
For a CSA layer:
- The **layer's** compressor has `(ratio=4, hd=512, H=7168)`, output dim
`kv_dim = 2*hd = 1024`.
- The **indexer's** compressor has `(ratio=4, ihd=128, H=7168)`, output
dim `kv_dim = 2*ihd = 256`.
Both are constructed, both load weights independently, both have their own
NVFP4 GEMM Nvfp4Linear instances. This matches the paper (§2.3.1: the
indexer has its own compressed key path that's narrower than the main
compressed KV path) — but **it is being done as two completely
independent code paths**, with the indexer's compressor's existence
implicit and easy to miss. That's why D1 was missed.
It also means the main compressor's `forward` runs twice per layer at
prefill: once for the main KV path, once for the indexer's keys. The
hidden states being projected are *identical*; the GEMM weights are
different. Two separate launches.
**Architecturally correct, but the code shape hides it.** Two consequences:
1. The "compressor" abstraction has a different meaning depending on
which compressor instance you're looking at. Rename for clarity:
`Compressor` (main) and `IndexerKeyCompressor` (smaller, for indexer).
2. The Indexer should expose its compressed key width as a property
(`indexer.compressor.kv_dim // 2`) so D1 can compute the buffer width
from data instead of assuming `head_dim`.
## D3 — KV gather still uses `torch.cat`, undoing P3's pre-allocation
`:569571`:
```python
if ratio == 4 and topk_idx is not None:
tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1)
all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0)
elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)
```
P3 preallocated `comp_kv_buf`, but the gather **immediately allocates a
fresh `(top_k + ws, hd) = (1024 + 128, 512)` BF16 tensor per layer call**
just to pass to FMHA. For 61 layers × per-token decode:
- 61 × (1152 × 512 × 2 bytes) ≈ **72 MB allocated and freed per token**.
That's small in absolute terms, but it's allocator churn on the hot path,
and the entire point of P3 was to remove this pattern. The `comp_kv[tk]`
gather already allocates (it's an advanced-indexing copy); the `cat` then
allocates again to merge with SWA. Two allocs per layer that don't need
to exist.
**Fix.** Preallocate one more buffer at cache init:
```python
self.all_kv_buf = torch.zeros(top_k + window_size, head_dim, ...)
```
Write the gathered top-k into `all_kv_buf[:top_k]`, the SWA into
`all_kv_buf[top_k:top_k+swa_len]`, pass `all_kv_buf[:top_k+swa_len]` to
FMHA. The gather becomes `torch.index_select(comp_kv_buf, 0, tk,
out=all_kv_buf[:top_k])` — zero allocs.
## D4 — The HCA path concatenates the ENTIRE compressed history, every layer, every token
Same line, the `elif` branch:
```python
elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)
```
HCA layers don't run the indexer. They attend over the **full compressed
history**. At 1M context with HCA ratio=128, that's `1M / 128 = 7813`
compressed entries. Per-token-per-layer FMHA input grows linearly with
prefill length. That's correct math (paper §2.3.2: HCA is dense over
compressed entries), but the *implementation* is allocating a new
contiguous tensor for it every single layer call.
At 1M context, that's `(7813 + 128) × 512 × 2 bytes ≈ 7.7 MB` per HCA
layer call. Across ~30 HCA layers per token ≈ **230 MB of alloc-and-free
per token, on the decode hot path.**
**Fix.** Same as D3: preallocate `all_kv_buf` sized for the worst case
(HCA full history + SWA). Use `out=` parameters on the gather. Or skip
the concat entirely — the FMHA kernel can take two K/V tensors and the
mask can encode the boundary. Today it can't, but it should; this is a
real kernel-side ask (see "M5" below).
## D5 — Indexer top-k attends *only* compressed entries; SWA is appended separately. Confirm against paper
Paper §2.3.1 figure: CSA's FMHA input is
`Concatenate(selected_compressed_KV, sliding_window_KV)`. Yes, that's
what the code does at `:571`. **Architecturally correct.**
One subtle thing worth checking: the paper says the sliding window provides
*local* fine-grained context the compressor can't (since compressed entries
each represent m tokens). The current SWA window is **128** (from config).
At 1M context that's still 128 local tokens visible per query — correct.
But the window slide in `KVCache.append_swa` evicts when full
(`self.swa_head = (self.swa_head + T) % self.ws`), which is correct.
✅ no drift.
## D6 — Attention sink is wired, but per-CSA-layer correctness needs checking
`_run_production_fmha:489`:
```python
sinks = w.get(f"{pfx}.sinks")
if sinks is not None: sink_bias = sinks.to(device=dev).float().reshape(n_h)
attn_out = dsv4_attention(q=q, k=k, v=v, scale=scale, n_comp=0, sink_bias=sink_bias)
```
`n_comp=0` is passed regardless. Paper §2.3.3 ("Attention Sink"): the sink
is a per-head additive logit to the softmax denominator. The kernel
signature in v15 said `n_comp` is "reserved for future kernel integration"
for the **D5c sink merge** (different softmax over compressed vs SWA).
v16 still passes `n_comp=0`, which means we're using global sink, not
per-segment sink merge.
The paper isn't explicit about whether sink should be per-segment, but if
the production FMHA was designed for D5c merge and it's being bypassed,
that's an unfinished integration, not necessarily a bug.
**Action:** confirm against the kernel's actual handling. Print the
sink_bias usage in `dsv4_attention` for one layer. If sink merge is
needed for CSA correctness at long context, that's a real wiring gap.
## D7 — mHC residual growth (|X|→500700 at L60) was flagged but not understood
Your perf audit notes (line 88): "Residual |X| grows to 500700 at L60
— mHC bounds it but residual is high."
Paper §2.2 designed mHC specifically to **bound** the residual via the
doubly-stochastic B matrix (spectral norm ≤ 1). The growth from |X|=1 at
L0 to |X|=700 at L60 suggests B isn't actually norm-bounded at runtime,
or A·C are amplifying.
This is the **same** issue I flagged in the v14 docs and it's still open.
Not a perf bug, but it's an architecture-fidelity bug, and it's the
**single most likely cause of decode degradation past step 10** (the
"...the" repetition loop noted in your audit). Compounded across decode
steps, a slightly-not-bounded residual becomes a numerically saturated
residual, which makes the final logits less informative.
**Action.** Print Sinkhorn-Knopp B's row/col sums per layer for one
forward pass. They should be `1.0 ± 1e-6` if Sinkhorn converged. If
they're e.g. `1.021.05`, t_max=20 isn't converging at this scale; bump
it or check the dynamic-parameter generation. The single_shot does
`sinkhorn_iters=20` (`:937`) which matches the paper, so the issue is
likely upstream: either A or C is producing values outside [0, 1] or
[0, 2], or the dynamic parameter generation has FP32 noise that breaks
doubly-stochastic.
---
# PART 2 — MEMORY AT 1M CONTEXT
This is the part that should be terrifying. The single_shot was sized for
50 tokens; the model targets 1M. Below is what each KV-cache structure
costs at the four interesting scales, with all numbers worked out, not
estimated.
## Per-layer KV cache sizes — read off the code
Layer setup:
- **CSA layer** (compressor ratio=4): main compressed at `hd=512` BF16,
indexer keys at `n_ih * ihd = 64 * 128 = 8192` BF16, SWA at `ws=128 × hd`
- **HCA layer** (compressor ratio=128): main compressed at `hd=512` BF16,
no indexer, SWA at `ws=128 × hd`
- **SWA-only layer** (first 2 layers of Flash, or HCA per-layer-2 of Pro):
SWA only
Layer counts per V4-Pro: 61 total, alternating CSA/HCA after layer 1. So
roughly **30 CSA + 30 HCA + 1 SWA-only** (paper §4.2.1).
### Per-layer KV growth (per token of context):
| Component | Per token | Bytes / token | × 1M tokens |
|---|---|---|---|
| **CSA main compressed** (1 entry / 4 tokens, hd=512 BF16) | 0.25 × 1024 B | 256 B | **256 MB** |
| **CSA indexer keys** (1 entry / 4 tokens, 8192 BF16) | 0.25 × 16384 B | 4096 B | **4.1 GB** |
| **HCA compressed** (1 entry / 128 tokens, hd=512 BF16) | 0.0078 × 1024 B | 8 B | **8 MB** |
| **SWA** (per layer, fixed 128 × hd × 2) | const | — | 128 KB |
### Total KV cache @ 1M context, all layers, BF16:
| Layer type | Count | Per-layer @ 1M | Total |
|---|---|---|---|
| CSA: main + indexer | 30 | 256 MB + 4.1 GB | **131 GB** |
| HCA: main | 30 | 8 MB | 240 MB |
| SWA | 61 | 128 KB | 8 MB |
| **GRAND TOTAL @ 1M, BF16** | | | **~131 GB** |
**The KV cache alone is 131 GB.** That's before model weights (which are
already EP-sharded across 8 GPUs). On 8 × B200 with 192 GB each = 1.5 TB
total, sharding KV across the 8 GPUs gives ~16 GB per GPU — fits, but
it's **15% of HBM dedicated to one request's KV.** And the dominant cost
is the **indexer keys** at 4.1 GB per layer.
**Critical observation: the indexer keys are 16× larger than the main
compressed KV per token, and they alone are 86% of total KV.** This is
because `n_ih * ihd = 8192` is much wider than `hd = 512` for the main KV.
The paper does specify this — the indexer is a wide multi-query mechanism
— but if storage is the constraint, the indexer key path is where to
attack.
## M1 — `comp_idx_buf` allocation as written: `(65536, 512)` per layer
`:419`:
```python
self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, ...)
```
If this *were* the correct width `(max_comp=65536, 8192)`, that's
`65536 × 8192 × 2 = 1 GB per layer × 30 CSA layers = 30 GB pre-allocated
at startup`, sized for **only 262K tokens of context**, not 1M.
The current shape `(65536, 512)` allocates 64 MB per layer × 30 = 1.9 GB
— too small by 16× as noted in D1, so either crashes at first compressed
write or silently truncates. Both bad. After fixing D1, you're staring
down 30 GB of pre-allocated KV cache that supports only a quarter of the
target context. And it's 30 GB × 8 GPUs = 240 GB if everything is
replicated, or 30 GB total if cache is sharded.
**This is the load-bearing memory bug for 1M context.** Two viable fixes:
1. **Quantize the indexer keys to FP4** (paper §5.2.1: "QK activations
are cached, loaded, and multiplied entirely in FP4"). The indexer keys
are *designed* to be FP4. That's 16× smaller: 4.1 GB/layer → 256 MB at
1M. Total cache becomes ~10 GB instead of 131 GB. **This is the
correct fix per paper.**
2. **Page the indexer KV.** Only the top-k indices' worth (≤ 1024) need to
be in attention. A paged cache (per the paper's §3.5.1 "heterogeneous
KV cache" with on-disk overflow) only keeps recent + selected pages in
HBM.
(1) is required regardless of (2). The current BF16 indexer cache is the
single biggest blocker between Paris-demo-works and 1M-context-works.
## M2 — `max_comp = 65536` hardcoded
`:411`:
```python
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0'):
```
For CSA (ratio=4), 65536 compressed entries = `262144` tokens of context.
That's the ceiling. At 262K tokens you get `IndexError` on
`comp_kv_buf[self.n_comp:end] = ckv`. There is no graceful behavior, no
on-disk overflow, no error message — it just dies.
For 1M context, `max_comp` needs to be `1M / 4 = 262144` for CSA layers
and `1M / 128 = 7813` for HCA layers. Currently both layer types share
the same 65536 default.
**Fix.** Size the buffer per-layer using compress ratio:
```python
max_comp_csa = ceil(target_context / 4) # 262144 for 1M
max_comp_hca = ceil(target_context / 128) # 7813 for 1M
```
And, critically, make `target_context` a CLI flag with a sensible default
(say 8K) so the script can run small while staying honest about the
ceiling. Hardcoded 65536 with no docstring on what it means is a footgun.
## M3 — Allocator churn from gather (D3, D4) compounds at 1M
Repeated for emphasis with numbers: the per-token `torch.cat` in the KV
gather allocates and frees memory proportional to context length. At 1M
context with HCA at all layers, that's **~230 MB of alloc/free per
decoded token**. PyTorch caching allocator handles this but it's still
fragmentation pressure across thousands of decoded tokens. After hours
of decoding, the allocator's cached blocks bloat.
Already covered in D3/D4. Restating because at 1M scale it's no longer
"small overhead" — it's GB/min of churn.
## M4 — `get_swa` does `.clone()` every call
`:457460`:
```python
def get_swa(self):
if self.swa_len == 0: return torch.zeros(0, self.hd, ...), torch.zeros(0, ...)
if self.swa_len < self.ws: return self.swa[:self.swa_len].clone(), self.swa_pos[:self.swa_len].clone()
idx = torch.arange(self.swa_head, self.swa_head + self.ws) % self.ws
return self.swa[idx].clone(), self.swa_pos[idx].clone()
```
Three clones in the second return path (and the third one allocates an
arange too). At `ws=128 hd=512`, the SWA tensor is 128 KB — small — but
**this runs every layer every token**. 61 × decoded tokens × 128 KB ≈ a
few MB/token in allocator pressure. Same fix pattern as D3: return views
into the ring buffer, let the FMHA gather kernel consume them with strides.
## M5 — KV gather memory could be eliminated entirely with a smarter kernel
D3 and D4 both pre-allocate a fused `all_kv` buffer because the FMHA
takes one K/V tensor. The deeper fix is to have the FMHA take **two K/V
inputs** (compressed + SWA) and handle them with masking inside. Then no
fused buffer ever has to exist — the kernel reads compressed entries
directly from `comp_kv_buf` (with a gather indices vector for the top-k
selection) and SWA entries directly from `swa_buf`.
This is the **right** long-term design (and matches how the paper §3.5.1
envisions the heterogeneous cache). It's a kernel ask, not a script fix,
but it's worth flagging now: the gather → cat → FMHA pattern is *the*
memory inefficiency at long context, and it can be designed out at the
kernel boundary.
---
# PART 3 — PRIORITY ORDER
These are sequenced by **what blocks 1M context viability**, not by
implementation cost.
| # | Item | Required for 1M? | Effort |
|---|---|---|---|
| **A1** | **D1: Fix `comp_idx_buf` width to actual indexer key width** | **Critical — silent corruption today** | XS |
| **A2** | **M1: Quantize indexer KV to FP4** (paper §5.2.1) | **Critical — saves 121 GB at 1M** | M-L |
| **A3** | **M2: Make `max_comp` per-layer-type + a CLI flag** | **Critical — current ceiling is 262K** | XS |
| **A4** | D3/D4/M3: Preallocate `all_kv_buf`, eliminate `torch.cat` | High — perf and stability over hours of decode | S |
| **A5** | D7: Investigate mHC residual growth (Sinkhorn convergence print) | High — likely root of decode degradation | S |
| **A6** | M4: Return SWA views, no clone | Medium — small per-call, large in aggregate | XS |
| **A7** | D2: Rename `Compressor` → split into `MainCompressor` and `IndexerKeyCompressor` | Medium — clarifies the duplicate-build pattern | XS |
| **A8** | D6: Verify sink merge semantics with kernel author | Medium — possible silent numerical drift | S |
| **A9** | M5: Two-buffer FMHA kernel (eliminate gather buffer entirely) | Long-term — production design | L |
**The "should I run on 1M context tomorrow" answer is no, regardless of
anything else, until A1/A2/A3 are done.** Without A1 you get garbage
top-k. Without A2 you OOM at ~250K tokens even with 8×B200. Without A3
you crash at 262K. Together those three are the gating set.
**The "is it still architecturally DSv4?" answer is yes — mostly.** The
hot path is faithful to the paper: CSA does overlapped-2m compression
with softmax weights, HCA does heavy non-overlapped compression, indexer
does ReLU(QK)·w_h reduction, attention concatenates selected-compressed
+ SWA, sinks are applied. The drifts above (D1, D2, D6, D7) are
implementation flaws or unfinished wiring, not architectural deviations.
---
# DOCTRINE — applies to every priority above
1. **DSL wall → raw CUDA C++, not Python.** Most of the fixes above are
pre-allocation and shape correctness, not new kernels. The two
exceptions (A2 FP4 indexer KV, A9 two-buffer FMHA) are kernel work and
must follow doctrine: tcgen05/UMMA/TMA, not scalar.
2. **Raw CUDA ≠ scalar math.** When A2 lands the FP4 indexer cache, the
dequant on the read side must use `__constant__` LUT (per the original
indexer LUT fix from issue #1), not branch arithmetic.
3. **Print, don't guess.** A1's fix is the canonical example: do not
assume the indexer key width is `head_dim` or `n_ih * ihd` — print
`indexer.compressor.forward(...)[0].shape` on first call and code to
that. The current bug exists *because* someone wrote `head_dim`
thinking it was right.
4. **Integration over exploration.** No `KVCache_v2`. Edit `KVCache`.
The 4 fixes (A1/A3/A4/A6) are surgical edits to one class.
5. **Falsifiable gates.** Numbers to hit:
- A1: `comp_idx_buf.shape[1] == indexer.compressor.kv_dim // 2` (or
whatever the print reveals). Test: run with VERBOSE=2 at 100 tokens
of context; no shape mismatch, no garbage in top-k selection.
- A2: KV cache footprint at 1M context (measured via
`torch.cuda.memory_allocated()` after prefill) drops from ~131 GB to
≤ 15 GB. Recall@1024 vs FP32 indexer oracle ≥ 99.7% per paper.
- A3: A `--max-context N` flag works; running with `N=1048576` does
not OOM during prefill (it might be slow — that's a separate fight).
- A4: `torch.cuda.memory_reserved()` measured every 10 decode steps is
flat (±50 MB) across 1000 steps.

75
probe_indexer_shapes.py Normal file
View File

@@ -0,0 +1,75 @@
#!/usr/bin/env python3
"""Probe indexer and compressor weight shapes from the checkpoint.
This tells us the ACTUAL dimensions, not what we assume.
Run via: fire_b200_test probe_indexer_shapes.py
"""
import json, sys
from pathlib import Path
from safetensors.torch import load_file
CHECKPOINT = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
def main():
cdir = Path(CHECKPOINT)
with open(cdir / "config.json") as f:
cfg = json.load(f)
n_layers = cfg["num_hidden_layers"]
n_ih = cfg.get("index_n_heads", 64)
ihd = cfg.get("index_head_dim", 128)
hd = cfg["head_dim"]
cr = cfg.get("compress_ratios", [128] * n_layers)
print(f"Config: n_ih={n_ih}, ihd={ihd}, hd={hd}")
print(f"n_ih * ihd = {n_ih * ihd}")
print(f"2 * ihd = {2 * ihd}")
print(f"2 * hd = {2 * hd}")
print(f"Compress ratios: first5={cr[:5]}")
print()
# Load weight map to find indexer weights
idx_file = cdir / "model.safetensors.index.json"
if idx_file.exists():
with open(idx_file) as f:
wmap = json.load(f).get("weight_map", {})
# Find indexer/compressor weights for layer 2 (first CSA layer)
for li in [0, 1, 2, 3]:
pfx = f"model.layers.{li}.self_attn"
print(f"\n=== Layer {li} (ratio={cr[li] if li < len(cr) else '?'}) ===")
for k in sorted(wmap.keys()):
if k.startswith(pfx) and ('compressor' in k or 'indexer' in k or 'q_b_proj' in k or 'kv_proj' in k or 'gate_proj' in k):
shard = cdir / wmap[k]
print(f" {k} -> shard {wmap[k]}")
else:
print("No index file, loading all weights...")
# Actually load some weights and print shapes
# Just load the first shard to get shapes
print("\n=== Loading weight shapes ===")
all_w = {}
if idx_file.exists():
shards = set(wmap.values())
for sn in sorted(shards):
sf = cdir / sn
if sf.exists():
w = load_file(str(sf))
# Only print relevant keys
for k, v in w.items():
if ('compressor' in k or 'indexer' in k) and 'layers.2' in k:
print(f" {k}: shape={list(v.shape)} dtype={v.dtype}")
del w
# Also check q_b_proj for layer 2
print("\n=== Layer 2 attention projection shapes ===")
for sn in sorted(shards):
sf = cdir / sn
if sf.exists():
w = load_file(str(sf))
for k, v in w.items():
if 'layers.2.self_attn' in k and ('q_b' in k or 'kv_proj' in k or 'gate_proj' in k):
print(f" {k}: shape={list(v.shape)} dtype={v.dtype}")
del w
if __name__ == "__main__":
main()