The indexer silently returning None caused CSA layers to attend over only the SWA window (128 tokens), not the compressed sparse KV. This went undetected because the model still produced plausible output at short context. The assert makes any future indexer regression immediately visible.
ARCHITECTURE & MEMORY AUDIT — Post-probe rewrite
Supersedes: the prior ARCHITECTURE_AND_MEMORY.md (its M1 estimate overstated the indexer KV cache by 64×). Incorporates the indexer probe results from
archived_plans/INDEXER_PROBE_RESULTS_20260602.md.
Method. Every claim verified against single_shot_inference.py v16 + the
probe results. Per doctrine.
TL;DR — the picture is much better than the prior audit suggested
The architecture is faithful to the paper. The 1M-context memory story is fine on 8×B200. There is no looming OOM crisis.
That said, the probe surfaced a finding bigger than memory: the lightning indexer has never actually run in any production decode to date. Paris-back is real, but CSA layers reached it through a fallback, not the designed sparse-selection path, which was silently bypassed because the indexer's internal compressor never loaded its weights. Which fallback CSA actually ran (dense attention over the full compressed KV history, or SWA only) is still being pinned down in B1. Either way, the system looked correct because the fallback was good enough at the context lengths tested, not because the designed CSA path was working.
This is good news. It means:
- Fixing the indexer is the next correctness milestone. It unlocks the actual sparse path, which is what makes 1M context tractable at runtime (not memory-wise — speed-wise, since dense over 250K compressed entries per CSA layer per token is the actual perf wall, not KV storage).
- Memory at 1M is dominated by the main compressed KV cache; the total across all CSA, HCA and SWA layers is ~10 GB, small enough that the prior audit's "131 GB" panic was wrong. No FP4 quantization of the indexer cache is needed for memory reasons. (It is still wanted for throughput per paper §5.2.1, but that's a different fight.)
- Three small bugs are blocking the indexer from running correctly. Two are surface (weight-path + buffer-width); one is deeper (the scoring einsum's algebra is wrong, treating MQA-on-indexer as full multi-head). All three are easy fixes once seen.
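To see why speed rather than storage is the wall, count the keys the main attention reads per decode token in one CSA layer. This is a back-of-envelope sketch that assumes 1M = 10^6 tokens (matching the 250K figure), top_k = 512 (the value used in the Tier 1 recall gate; the production setting may differ) and a 128-token SWA window. It counts main-attention keys only. The indexer still scores all n_comp blocks, but against a single 128-wide key per block.

```python
# Main-attention keys per decode token, per CSA layer, at 1M context.
# Assumptions: 1M = 10**6 tokens, top_k = 512, SWA window = 128.
CONTEXT = 10**6
CSA_RATIO = 4
SWA_WINDOW = 128
TOP_K = 512

n_comp = CONTEXT // CSA_RATIO                   # compressed blocks per CSA layer
dense_keys = n_comp + SWA_WINDOW                # dense fallback: every block plus SWA
sparse_keys = min(TOP_K, n_comp) + SWA_WINDOW   # designed path: top-k blocks plus SWA

print(f"dense {dense_keys:,} vs sparse {sparse_keys:,} keys "
      f"({dense_keys / sparse_keys:.0f}x fewer on the sparse path)")
# dense 250,128 vs sparse 640 keys (391x fewer on the sparse path)
```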
PART 1 — WHAT THE PROBE REVEALED
The probe confirmed hypothesis A from the prerequisite doc and surfaced two collateral findings. The combined picture:
F1 — Indexer keys are c_I = 128-wide, MQA-on-indexer (paper-aligned)
comp_indexer_kv.shape == (n_comp, 128). One vector per compressed block,
shared across all n_ih = 64 indexer query heads. This is the standard
multi-query-attention shape, but applied to the indexer scoring path.
Per-block cost: 128 × 2 bytes = 256 B per compressed block per CSA layer. At 1M context (CSA ratio=4 → 250K compressed blocks):
- Per CSA layer: 250K × 256 B = 64 MB
- × 30 CSA layers = ~1.9 GB total for indexer KV at 1M context
That's small: 4× smaller than the main compressed KV cache (128 vs. 512 values per block) and about a fifth of the ~9.9 GB total. The prior audit's M1 ("indexer KV is 125 GB at 1M, OOM at 250K tokens") was backwards: the indexer cache is a minor share of KV memory, not the dominant one.
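The F1 arithmetic, written out so the 4× ratio (vs. the main compressed cache) and the 64× ratio (vs. the per-head width the prior audit assumed) can be checked directly. A sketch assuming 1M = 10^6 tokens and BF16 storage throughout; `layer_bytes` is just a local helper.

```python
# Indexer KV footprint at 1M context (BF16), checking F1's arithmetic.
BYTES_BF16 = 2
N_CSA_LAYERS = 30
n_comp = 10**6 // 4          # 250K compressed blocks per CSA layer

def layer_bytes(width):
    """Bytes for one CSA layer holding n_comp rows of `width` BF16 values."""
    return n_comp * width * BYTES_BF16

indexer = layer_bytes(128)        # c_I = 128, one shared (MQA) key per block
main = layer_bytes(512)           # main compressed KV, head_dim = 512
per_head = layer_bytes(64 * 128)  # n_ih * ihd = 8192, the per-head misreading

print(f"{indexer / 1e6:.0f} MB/layer, {N_CSA_LAYERS * indexer / 1e9:.2f} GB total")
print(f"main/indexer = {main // indexer}x, per-head/indexer = {per_head // indexer}x")
# 64 MB/layer, 1.92 GB total
# main/indexer = 4x, per-head/indexer = 64x
```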
F2 — The indexer compressor never loaded weights (the real bug)
Indexer.load:392:

```python
if f"{pfx}.compressor.kv_proj.weight" in w:
    self.compressor = Compressor(4, self.ihd, 7168, dev)
```
The checkpoint stores the indexer's compressor weights at
`*.indexer.kv_proj.weight`, not `*.indexer.compressor.kv_proj.weight`.
So this `if` was always False, `self.compressor` stayed `None`, and
`Indexer.forward` always returned `None` at the early-return guard (line
397: `if ... comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0: return None`).
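The cheapest way to catch this class of bug is to print the keys under the prefix instead of guessing the layout. A sketch with a hypothetical `keys_under` helper; the prefix string and every key except `kv_proj.weight` are made up for illustration, and `w` stands in for the loaded state dict (tensor values elided).

```python
def keys_under(w, pfx):
    """Checkpoint keys under `pfx`, with the prefix stripped, sorted."""
    head = pfx + "."
    return sorted(k[len(head):] for k in w if k.startswith(head))

# Hypothetical prefix and key names (only "kv_proj.weight" comes from the probe).
pfx = "layers.0.attn.indexer"
w = {f"{pfx}.kv_proj.weight": None, f"{pfx}.q_proj.weight": None,
     "layers.0.attn.wkv.weight": None}

print(keys_under(w, pfx))                        # ['kv_proj.weight', 'q_proj.weight']
print(f"{pfx}.compressor.kv_proj.weight" in w)   # False: the v16 check
print(f"{pfx}.kv_proj.weight" in w)              # True: the layout actually stored
```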
What this means for every Paris-back run to date:
- CSA layers received `topk_idx = None` from the indexer.
- The gather path at `forward_attention:569–571` checks `if ratio == 4 and topk_idx is not None:`, which is False, so control moves on to `elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], ...)`. That branch is for HCA (`ratio > 4`), not CSA (`ratio == 4`), so it does not fire either. What CSA actually did with `topk_idx=None` still needs to be checked.
The agent should verify which fallback path CSA actually took, and confirm whether the existing test runs were:
- (a) attending over just SWA (correct only at short context, since SWA window is 128 — would explain why Paris works but degrades past step 10),
- (b) attending over the full compressed history as if it were HCA (correct but slow at scale), or
- (c) producing no attention output at all and being saved by a downstream operation.
This is a 10-line print insertion at forward_attention, not an
investigation campaign. Add it to the indexer-fix work below; do not
spin up a separate probe.
F3 — The scoring einsum has the wrong algebra (MQA vs per-head keys)
The current code at Indexer.forward:404:

```python
k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd)
scores = torch.einsum('tnd,cnd->tnc', q_idx.float(), k_idx.float())
```
The reshape requires comp_indexer_kv to have n_ih × ihd = 8192 elements
per block. The probe shows it actually has ihd = 128 elements. So the
reshape raises today.
The temptation is to "fix" this by widening comp_idx_buf to 8192.
That would let the reshape succeed and produce numerically plausible
scores. It would be wrong. The paper's scoring formula (§2.3.1, eq.
16) is:
$$I_{t,s} = \sum_{h} w^{I}_{t,h} \cdot \mathrm{ReLU}\left(q^{I}_{t,h} \cdot K^{IComp}_{s}\right)$$
K^IComp_s has no head subscript. It's one key vector per block, shared
across all n_ih indexer query heads. The score is computed by dotting
each of the 64 query heads against the same key, applying ReLU, then
weighting and summing across heads. That's MQA — the same trick used for
the main attention path in DSv4 (§2.3.1 "Shared Key-Value MQA").
The correct einsum:
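```python
# Module-level imports this excerpt relies on:
import torch
import torch.nn.functional as F

# Body of Indexer.forward, after q_idx, w_h and comp_indexer_kv are computed.
# q_idx: (T, n_ih, ihd) = (T, 64, 128)
# k_idx: (n_comp, ihd)  = (n_comp, 128)   <-- no head dim
# w_h:   (T, n_ih)      = (T, 64)
k_idx = comp_indexer_kv                   # one shared (MQA) key per block, used as-is
n_comp = k_idx.shape[0]
scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())  # 'cd', not 'cnd'
scores = F.relu(scores)
total = (scores * w_h.unsqueeze(-1).float()).sum(1)  # weight each head, sum over heads: (T, n_comp)
tk = min(self.top_k, n_comp)
_, idx = total.topk(tk, -1)
return idx
```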
The k_idx.reshape(n_comp, self.n_ih, self.ihd) line goes away entirely —
no reshape needed when keys are MQA-shared.
Why this matters beyond "the reshape stops crashing": without this correction, an agent fixing F2 (load the indexer compressor) and "fixing" F3 by widening the buffer would produce silently wrong top-k selections. Same shape as the original indexer LUT bug — code runs, produces plausible numbers, but the ranking of compressed blocks is corrupted because the math doesn't match the model. Recall@k drops from paper's 99.7% to something much lower, and we'd be back to debugging "model gets dumber at long context" by ripping apart the FMHA kernel that isn't broken.
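A pure-Python reference for eq. 16 makes the MQA point concrete (double precision, toy sizes of 2 heads × 2 dims × 3 blocks standing in for 64 × 128). `scores_mqa` uses one key per block for every head; tiling that key across heads reproduces it exactly. The `padded` case models the "widen the buffer" non-fix, assuming the extra columns keep their `torch.zeros` initialization: read per head, only head 0 sees the real key, and the top-k ranking flips. All function names here are illustrative.

```python
def relu(x):
    return x if x > 0.0 else 0.0

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def scores_mqa(q, w, k):
    """I[t][s] = sum_h w[t][h] * ReLU(q[t][h] . k[s]); one key k[s] shared by all heads."""
    return [[sum(w_t[h] * relu(dot(q_t[h], k_s)) for h in range(len(q_t))) for k_s in k]
            for q_t, w_t in zip(q, w)]

def scores_per_head(q, w, k_heads):
    """Same sum, but block s has a separate key k_heads[s][h] for each head."""
    return [[sum(w_t[h] * relu(dot(q_t[h], k_s[h])) for h in range(len(q_t))) for k_s in k_heads]
            for q_t, w_t in zip(q, w)]

def topk(row, k):
    """Indices of the k largest scores, best first (stable on ties)."""
    return sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]

q = [[[1.0, 0.0], [0.0, 1.0]]]                # T=1 query, n_ih=2 heads, d=2
w = [[1.0, 2.0]]                              # per-head weights w^I
k = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]    # n_comp=3 shared keys

mqa = scores_mqa(q, w, k)                                        # [[1.0, 2.0, 0.0]]
tiled = scores_per_head(q, w, [[ks, ks] for ks in k])            # key repeated per head
padded = scores_per_head(q, w, [[ks, [0.0, 0.0]] for ks in k])   # [[1.0, 0.0, 0.0]]

assert tiled == mqa                           # MQA == per-head with the key repeated
print(topk(mqa[0], 2), topk(padded[0], 2))    # [1, 0] [0, 1]
```

The padded buffer produces plausible-looking numbers and a valid index list, which is exactly why this failure mode is silent.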
F4 — The buffer width is wrong, but off by less than the prior audit claimed
KVCache:419:

```python
self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, ...)  # head_dim is 512; should be 128
```
head_dim = 512 (main attention head dim). Indexer keys want c_I = 128.
The buffer is 4× too wide, not 16× as the prior audit assumed. Storage
waste at 1M context (CSA only): 30 layers × 250K × (512 − 128) × 2 bytes
= 5.76 GB wasted. Real, fixable, not catastrophic.
The fix needs a value to use, and that value should come from the indexer instance, not hard-coded:
```python
# In KVCache.__init__:
self.comp_idx_buf = torch.zeros(
    max_comp,
    indexer_key_dim,  # passed from caller, = indexer.ihd = 128
    dtype=torch.bfloat16, device=device,
)
```
The construction site at single_shot_inference.py (where KVCache is
created per layer) needs to pass indexer.ihd for CSA layers and skip
the buffer for HCA layers (which have no indexer).
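A minimal sketch of the construction-site decision, assuming CSA layers are exactly those with compress_ratio == 4 (as in forward_attention's branches) and that the caller passes `indexer.ihd`. `KVBufferPlan` and `plan_layer` are hypothetical names; they only compute shapes, so the same logic would feed the `torch.zeros` calls in `KVCache.__init__`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class KVBufferPlan:
    """Buffer shapes a per-layer KVCache would allocate (hypothetical helper)."""
    comp_kv: Tuple[int, int]
    comp_idx: Optional[Tuple[int, int]]   # None: no indexer, so no buffer

def plan_layer(compress_ratio: int, max_comp: int, head_dim: int,
               indexer_key_dim: Optional[int] = None) -> KVBufferPlan:
    """CSA (ratio 4) gets an indexer-key buffer of width indexer.ihd; others get none."""
    comp_idx = None
    if compress_ratio == 4:
        if indexer_key_dim is None:
            raise ValueError("CSA layer needs indexer_key_dim (pass indexer.ihd)")
        comp_idx = (max_comp, indexer_key_dim)
    return KVBufferPlan(comp_kv=(max_comp, head_dim), comp_idx=comp_idx)

print(plan_layer(4, 2048, 512, indexer_key_dim=128))   # CSA layer
print(plan_layer(128, 64, 512))                         # HCA layer: comp_idx=None
```

Raising on a missing key dim for CSA, rather than defaulting to head_dim, is what would have stopped the original 512-wide buffer from being created silently.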
PART 2 — MEMORY AT 1M CONTEXT, REVISED
The numbers below replace the prior audit's. They are conservative and worst-case.
Per-layer KV cache sizes — read off the (corrected) code
| Component | Entries/token × bytes/entry | Bytes / token | Per layer @ 1M tokens |
|---|---|---|---|
| CSA main compressed (1 entry / 4 tokens, hd=512 BF16) | 0.25 × 1024 B | 256 B | 256 MB |
| CSA indexer keys (1 entry / 4 tokens, c_I=128 BF16) | 0.25 × 256 B | 64 B | 64 MB |
| HCA compressed (1 entry / 128 tokens, hd=512 BF16) | 0.0078 × 1024 B | 8 B | 8 MB |
| SWA per layer (ring buffer, 128 × hd × 2) | constant | — | 128 KB |
Total KV cache @ 1M context, all layers, BF16:
| Layer type | Count | Per-layer @ 1M | Total |
|---|---|---|---|
| CSA: main + indexer | 30 | 256 MB + 64 MB | 9.6 GB |
| HCA: main | 30 | 8 MB | 240 MB |
| SWA | 61 | 128 KB | 8 MB |
| GRAND TOTAL @ 1M, BF16 | | | ~9.9 GB |
~10 GB of KV state for a 1M-token context. On 8×B200 (192 GB each, 1.5 TB total) that's negligible: about 0.65% of total HBM, or ~1.25 GB per GPU if sharded EP-style alongside the experts. The system has plenty of memory headroom for the design target.
For comparison, DeepSeek-V3.2's KV cache at 1M context is ~92 GB (per V4 paper Figure 1). V4 at ~10 GB is a 9× reduction — which is exactly the "~10% of V3.2's KV cache" claim from the paper. The implementation hits the design memory budget; the prior audit was wrong about how it gets there.
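The two tables can be rechecked in a few lines. This sketch uses the tables' own inputs (BF16, 30 CSA, 30 HCA and 61 SWA buffers, 8 × 192 GB of HBM) plus two assumptions of its own: 1M = 10^6 tokens, and compressed-block counts rounded up, as the C1 max_comp formula does.

```python
# Recompute the Part 2 totals (BF16, worst case).
CONTEXT, BF16 = 10**6, 2
HEAD_DIM, C_I, WINDOW = 512, 128, 128

def blocks(ratio):
    return -(-CONTEXT // ratio)              # ceil(CONTEXT / ratio)

csa_main = blocks(4) * HEAD_DIM * BF16       # 256 MB per CSA layer
csa_idx = blocks(4) * C_I * BF16             # 64 MB per CSA layer
hca_main = blocks(128) * HEAD_DIM * BF16     # ~8 MB per HCA layer
swa = WINDOW * HEAD_DIM * BF16               # 128 KiB per layer, independent of context

total = 30 * (csa_main + csa_idx) + 30 * hca_main + 61 * swa
hbm = 8 * 192e9                              # 8x B200 at 192 GB each
print(f"{total / 1e9:.2f} GB total, {100 * total / hbm:.2f}% of HBM, "
      f"{total / 8 / 1e9:.2f} GB per GPU")
# 9.85 GB total, 0.64% of HBM, 1.23 GB per GPU
```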
What this changes about priorities
- "Quantize indexer KV to FP4 to save 121 GB" is gone. It was based on a wrong width. The indexer cache is 2 GB at 1M; FP4 would shrink it to 500 MB. Nice; not urgent.
- "max_comp = 65536 is the ceiling at 262K tokens" is still real. That hardcoded buffer size hasn't changed; at 1M context CSA needs max_comp_csa = 262144. Still a config fix, just not paired with a quantization fight.
- "Allocator churn from torch.cat in the gather" is still real and still gets worse with context length. Pre-allocation still matters at long context for perf and stability over hours of decoding. Just not urgent for "does it fit in memory."
PART 3 — PRIORITY ORDER (REVISED)
Sequenced by what unblocks correctness first, then performance, then memory. The big shift from the prior audit: the indexer fix is the gating correctness work; memory is no longer the crisis it was framed as.
Tier 1 — Make the indexer actually work (correctness)
These are all small edits but they have to land together. The agent should treat this as one atomic landing, not three independent fixes, because individually each one either does nothing or makes things worse.
A1 — Fix the indexer compressor weight path
Indexer.load:392. Change the check and the load prefix to match the
checkpoint:
```python
# Was:
if f"{pfx}.compressor.kv_proj.weight" in w:
    self.compressor = Compressor(4, self.ihd, 7168, dev)
    self.compressor.load(w, f"{pfx}.compressor", dev)

# Should be (read the actual key from the checkpoint, not assumed):
if f"{pfx}.kv_proj.weight" in w:
    self.compressor = Compressor(4, self.ihd, 7168, dev)
    self.compressor.load(w, pfx, dev)
```
The agent's probe already identified this. Verify the fix is in v17 by
running a checkpoint-loaded forward and confirming self.compressor is not None for every CSA layer.
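The `if ... in w:` guard is what made the v16 failure silent: a missing key just skipped the load. A sketch of a fail-loud alternative, using a hypothetical `indexer_compressor_prefix` helper that checks the layout the probe found first, then the nested layout v16 assumed, and raises if neither exists. It assumes it is only called for CSA layers (HCA layers have no indexer).

```python
def indexer_compressor_prefix(w, pfx):
    """Prefix under which this indexer's compressor weights live in checkpoint `w`.

    Checks `{pfx}.kv_proj.weight` (the layout the probe found) first, then the
    nested `{pfx}.compressor.kv_proj.weight` layout, and raises instead of
    leaving the compressor as None.
    """
    for candidate in (pfx, f"{pfx}.compressor"):
        if f"{candidate}.kv_proj.weight" in w:
            return candidate
    raise KeyError(f"no indexer compressor weights under {pfx!r}")

# Hypothetical checkpoint fragment in the layout the probe found.
w = {"layers.0.attn.indexer.kv_proj.weight": None}
print(indexer_compressor_prefix(w, "layers.0.attn.indexer"))   # layers.0.attn.indexer
```

Inside `Indexer.load`, the returned prefix would be what gets passed to `self.compressor.load(w, ..., dev)`, so a future checkpoint-layout change surfaces as a `KeyError` at load time instead of a `None` top-k at decode time.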
A2 — Fix comp_idx_buf width to c_I = 128
KVCache:419. Plumb indexer_key_dim through KVCache.__init__ (or
better: derive it from a probe of the indexer's compressor on first
call). Default for non-CSA layers: skip the buffer.
A3 — Fix the scoring einsum to MQA-on-indexer
Indexer.forward:404. Drop the head-axis reshape and use 'tnd,cd->tnc'
as shown in F3 above. This is the deeper correctness fix and the easiest
one to get wrong if A1+A2 land first and an agent "fixes" the reshape by
widening the buffer.
Gate for Tier 1:
- `Indexer.forward` returns a non-`None` `idx` tensor for every CSA layer on a prompt of ≥ 4 tokens. Verify with a print on layer 0.
- `forward_attention` at CSA layers takes the `if ratio == 4 and topk_idx is not None` branch, not the fallback.
- Paris-back still works. Output should match or beat v16's Paris-back. If v16 was running a dense fallback over the full compressed history (a superset of the keys CSA selects), expect near-identical output; if it was attending over SWA only, expect better output.
- Recall test: compare the top-k indices from the indexer against an FP32 oracle (compute the scoring in FP32 outside the kernel and take topk on that). Recall ≥ 99% at top_k=512 with n_comp ≥ 1024.
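A minimal recall@k for the gate, assuming the indexer's `idx` for one query row and the oracle's per-block scores have been pulled out as Python sequences (e.g. via `.tolist()`). `topk_indices` and `recall_at_k` are illustrative names. The metric is set overlap, so ordering inside the top-k does not count against the kernel, though exact ties at the k-th boundary can cost a little recall even when the kernel is correct. The last line shows why the gate needs n_comp ≥ 1024: when n_comp ≤ top_k every block is selected and recall is trivially 1.0.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

def recall_at_k(pred_idx, oracle_scores, k):
    """Fraction of the oracle's top-k blocks that the indexer also selected."""
    oracle = set(topk_indices(oracle_scores, k))
    return len(oracle & set(pred_idx[:k])) / len(oracle)

oracle_scores = [0.1, 0.9, 0.5, 0.7]                  # toy oracle scores for 4 blocks
print(recall_at_k([3, 1], oracle_scores, 2))          # 1.0: same set, order ignored
print(recall_at_k([1, 2], oracle_scores, 2))          # 0.5
print(recall_at_k([0, 1, 2, 3], oracle_scores, 512))  # 1.0: n_comp <= top_k, trivially
```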
Tier 2 — Verify what the fallback was actually doing (cleanup)
B1 — Find and document the v16 CSA fallback path
forward_attention:569–571: when topk_idx was always None, what
actually happened in CSA layers? The branches as read:
```python
if ratio == 4 and topk_idx is not None:  # never taken
    all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0)
elif ratio > 4:  # only HCA layers
    all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)
```
For CSA with topk_idx=None and ratio == 4, neither branch fires.
What all_kv is at that point depends on what came before. The agent
should run a 5-line probe in v16 (or look at the bisected behavior) to
confirm whether v16 CSA layers:
- attended over just SWA (would explain decode degradation past step 10),
- attended over the full compressed history (would explain decode working but being slower than necessary),
- crashed at this point and something downstream rescued the run (most likely if Paris-back still happened).
This is informational — it doesn't gate Tier 1 — but it answers "what exactly did 'Paris-back' validate?" and it tells you whether decode quality should jump (if v16 was on SWA-only) or stay flat (if v16 was on full compressed) when Tier 1 lands.
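One way to keep the B1 probe to a few lines: print `all_kv.shape[0]`, `swa_kv.shape[0]` and `kv_cache.n_comp` at a CSA layer, then classify the lengths offline. A sketch with a hypothetical `classify_csa_attention` helper. It assumes the candidate fallbacks differ only in how many rows `all_kv` has, and it can only separate dense from top-k when the prompt is long enough that n_comp > top_k. The crash case shows up as the print never firing (for example an `UnboundLocalError` at the first use of `all_kv`, if nothing assigned it earlier), not as a length.

```python
def classify_csa_attention(n_all_kv, n_swa, n_comp, top_k):
    """Label what a CSA layer attended over, from the printed all_kv.shape[0].

    n_swa: swa_kv.shape[0] at that step; n_comp: kv_cache.n_comp.
    """
    if n_comp == 0:
        return "ambiguous"          # nothing compressed yet: every case looks like SWA
    if n_all_kv == n_swa:
        return "swa_only"
    if n_all_kv == n_comp + n_swa:
        return "dense_compressed" if n_comp > top_k else "dense_or_sparse"
    if n_comp > top_k and n_all_kv == top_k + n_swa:
        return "sparse_topk"
    return "unknown"

print(classify_csa_attention(128, 128, 3000, 512))    # swa_only
print(classify_csa_attention(3128, 128, 3000, 512))   # dense_compressed
```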
B2 — Once Tier 1 lands, add explicit error on topk_idx is None in CSA
The fact that the CSA fallback was silent for this long is the meta-bug.
After Tier 1, the CSA path should require topk_idx is not None:
```python
if ratio == 4:
    # Tripwire: a CSA layer must always get top-k indices from the indexer.
    assert topk_idx is not None, f"CSA layer {li} got no top-k from indexer: indexer is broken"
    tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1)
    all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0)
elif ratio > 4:
    all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)
```
This is a tripwire for future regressions of the same shape.
Tier 3 — Memory & allocator hygiene (still real, just not urgent)
C1 — max_comp per-layer-type + CLI flag
KVCache.__init__:411. Make max_comp a function of context length and
compress ratio:
```python
def __init__(self, head_dim, indexer_key_dim, compress_ratio,
             window_size=128, target_context=8192, device='cuda:0'):
    self.max_comp = (target_context + compress_ratio - 1) // compress_ratio  # ceil division
    ...
```
And expose target_context as a CLI arg (--max-context). Default
small (8192) so the script stays runnable.
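A sketch of the sizing rule and the flag, pulled out of `KVCache.__init__` so it can be checked on its own. `max_comp_for` is a hypothetical helper name. It shows that 1M = 2^20 tokens needs 262144 CSA blocks, and that a single token past 262,144 already exceeds the current 65536 ceiling. The flag is parsed from an explicit argv list only for the demo.

```python
import argparse

def max_comp_for(target_context: int, compress_ratio: int) -> int:
    """Compressed-block capacity for a context length: ceil(context / ratio)."""
    return (target_context + compress_ratio - 1) // compress_ratio

parser = argparse.ArgumentParser()
parser.add_argument("--max-context", type=int, default=8192,
                    help="target context length used to size compressed KV buffers")
args = parser.parse_args(["--max-context", "1048576"])   # explicit argv for the demo

print(max_comp_for(args.max_context, 4), max_comp_for(args.max_context, 128))  # 262144 8192
print(max_comp_for(262_145, 4))                                               # 65537
```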
C2 — Pre-allocate all_kv_buf, eliminate torch.cat in gather
Same fix as D3/D4 in the prior audit — still valid:
```python
# Once at init:
self.all_kv_buf = torch.zeros(max_top_k + window_size, head_dim, ...)
```
Gather writes into views of this buffer with out= arguments. FMHA
consumes the prefix. Zero allocs on hot path.
C3 — KVCache.get_swa returns views, not clones
KVCache:457–460. Drop the .clone() calls. Return slices.
C4 — Optional: Quantize indexer KV to FP4 (paper §5.2.1)
For throughput, not memory. Defer until E7 (Stage F indexer FP4 tensor-core scoring) lands — at that point the FP4 storage and FP4 MMA path are paired, which is the right shape. Don't quantize the cache without also upgrading the scoring kernel — that would be storage savings paid for with a dequant kernel that doesn't exist yet.
Tier 4 — Architecture fidelity nice-to-haves
D1 — Split Compressor class into MainCompressor and IndexerKeyCompressor
single_shot_inference.py:272. Same class is instantiated with totally
different config in two places. Splitting documents the difference and
prevents the "I assumed it was the same thing" bug class (which is how
the buffer width bug happened in the first place).
D2 — Verify sink merge semantics (D6 from prior audit, unchanged)
_run_production_fmha:489 passes n_comp=0 always. The kernel may
expect n_comp = len(compressed_kv) for the D5c sink merge. Print the
kernel's actual handling, confirm or fix.
D3 — Understand mHC residual growth (D7 from prior audit, unchanged)
|X| → 500-700 at L60 still indicates Sinkhorn B isn't doubly-stochastic at runtime. Print B row/col sums, expect 1.0 ± 1e-6. This may also partly explain the decode degradation past step 10 (compounding non-bounded residuals → saturated logits → low-information argmax). Tier 1 fixing the indexer may improve decode behavior enough that this stops mattering — but worth still checking once the indexer is correct.
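A sketch of the D3 check in pure Python. `max_marginal_error` is the number to print per layer, and `sinkhorn` is a generic Sinkhorn-Knopp normalization used here only to produce a matrix that should pass. It assumes B is a small square matrix with positive entries; the real mHC code's input (for example whether it exponentiates logits first) and its iteration count are not shown in this audit. Too few iterations is one way a Sinkhorn output misses the 1e-6 tolerance, as the one-iteration call shows.

```python
def sinkhorn(m, n_iters=20):
    """Sinkhorn-Knopp: alternately rescale rows and columns of a positive matrix."""
    b = [row[:] for row in m]
    for _ in range(n_iters):
        b = [[x / sum(row) for x in row] for row in b]              # rows sum to 1
        col_sums = [sum(col) for col in zip(*b)]
        b = [[x / c for x, c in zip(row, col_sums)] for row in b]   # columns sum to 1
    return b

def max_marginal_error(b):
    """Largest |row or column sum - 1|; D3 expects this to be <= 1e-6."""
    sums = [sum(r) for r in b] + [sum(c) for c in zip(*b)]
    return max(abs(s - 1.0) for s in sums)

m = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
print(max_marginal_error(sinkhorn(m)))      # tiny: converged
print(max_marginal_error(sinkhorn(m, 1)))   # ~0.045: one iteration is not enough
```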
REVISED PRIORITY TABLE
| # | Item | What it unblocks | Effort | Blocks 1M? |
|---|---|---|---|---|
| A1 | Fix indexer compressor weight path | Indexer runs at all | XS | Yes (correctness) |
| A2 | comp_idx_buf width = 128 (not 512) | Indexer can store keys | XS | Yes (correctness) |
| A3 | Scoring einsum 'tnd,cd->tnc' | Top-k is correct | XS | Yes (correctness) |
| B1 | Document the v16 CSA fallback | Knowing what Paris validated | XS | No |
| B2 | Assert topk_idx is not None in CSA | Future regression tripwire | XS | No |
| C1 | Per-layer max_comp + --max-context | Long context doesn't crash at 262K | XS | Yes, but trivial |
| C2 | Pre-alloc all_kv_buf, kill cat | Stable decode over hours | S | No, but real perf |
| C3 | get_swa returns views | Small but everywhere | XS | No |
| C4 | FP4 indexer cache (paired with E7) | Throughput, paper compliance | M-L | No |
| D1 | Split Compressor classes for clarity | Prevents the same-class-confusion bug | XS | No |
| D2 | Sink merge semantics check | Subtle numerics | S | No |
| D3 | mHC Sinkhorn convergence check | Decode degradation | S | No |
Land A1+A2+A3 together as one atomic correctness fix. That is the critical path. Everything else is sequential and not gating.
DOCTRINE — applies to every priority
- DSL wall → raw CUDA C++, not Python. Doesn't apply much in this round: most fixes are 3-line edits to Python orchestration. The exception is C4 (FP4 indexer cache), which is a kernel fight and must follow doctrine: tcgen05/UMMA/TMA on the read side, a `__constant__` LUT for any dequant, paired with the E7 scoring kernel.
- Raw CUDA ≠ scalar math. Same here: when C4 lands, the indexer's `tcgen05.mma` FP4 path replaces the scoring einsum. The current FP32 einsum (post-fix) is a correctness oracle, not a perf target.
- Print, don't guess. This entire round exists because of a probe that printed instead of assuming. The pattern works. Use it again for:
  - B1: probe what the v16 CSA fallback actually returned.
  - C2: print `all_kv` shape and dtype to verify the pre-allocated buffer is being sliced correctly.
  - D3: print Sinkhorn row/col sums per layer.

  Stop running new code until the probes have written their output to a `.md` next to this one.
- Integration over exploration. No `Indexer_v2`, no `KVCache_v2`. Edit the existing classes. Tier 1 is ~10 line-edits total across 3 functions.
- Falsifiable gates. Already listed per priority above. The meta-gate for the whole audit: after Tier 1, the indexer's top-k recall vs an FP32 oracle is ≥ 99% on a prompt with n_comp ≥ 1024. Until that number is measured and recorded, "the indexer works" is an assertion, not a fact.
- Don't optimize for a problem you don't have. The prior audit's biggest mistake was framing memory as a 1M-context crisis based on a wrong width. The real picture is: V4 hit its KV cache memory targets, the implementation is faithful, and the actual blocker is a handful of small bugs in the sparse-selection path. Fix those first and re-measure before adding new infrastructure.