nvfp4-megamoe-kernel/ARCHITECTURE_AND_MEMORY_AUDIT.md
biondizzle 617da29a5b fix: assert topk_idx is not None in CSA layers — no silent fallback to SWA-only
The indexer silently returning None caused CSA layers to attend over only the
SWA window (128 tokens), not the compressed sparse KV. This went undetected
because the model still produced plausible output at short context. The assert
makes any future indexer regression immediately visible.
2026-06-02 06:14:23 +00:00


ARCHITECTURE & MEMORY AUDIT — Post-probe rewrite

Supersedes: the prior ARCHITECTURE_AND_MEMORY.md (its M1 overstated the indexer KV cache by 64×). Incorporates the indexer probe results from archived_plans/INDEXER_PROBE_RESULTS_20260602.md.

Method. Every claim verified against single_shot_inference.py v16 + the probe results. Per doctrine.


TL;DR — the picture is much better than the prior audit suggested

The architecture is faithful to the paper. The 1M-context memory story is fine on 8×B200. There is no looming OOM crisis.

That said, the probe surfaced a finding bigger than memory: the lightning indexer has never actually run in any production decode to date. Paris-back is real, but CSA layers never took the sparse-selection path. The indexer's internal compressor never loaded its weights, so the indexer returned no top-k and CSA silently fell back to some other path. Which fallback ran (dense attention over the full compressed KV history, or attention over only the SWA window) is still unconfirmed, and B1 below settles it. Either way, Paris-back validated a fallback, not the designed CSA path.

This is good news. It means:

  1. Fixing the indexer is the next correctness milestone. It unlocks the actual sparse path, which is what makes 1M context tractable at runtime (not memory-wise — speed-wise, since dense over 250K compressed entries per CSA layer per token is the actual perf wall, not KV storage).
  2. Total KV memory at 1M is ~10 GB across all CSA, HCA and SWA layers, most of it the CSA main compressed cache (~7.7 GB). That is small enough that the prior audit's "131 GB" panic was wrong. No FP4 quantization of the indexer cache is needed for memory reasons. (It is still wanted for throughput per paper §5.2.1, but that's a different fight.)
  3. Three small bugs are blocking the indexer from running correctly. Two are surface (weight-path + buffer-width); one is deeper (the scoring einsum's algebra is wrong, treating MQA-on-indexer as full multi-head). All three are easy fixes once seen.

PART 1 — WHAT THE PROBE REVEALED

The probe confirmed hypothesis A from the prerequisite doc and surfaced two collateral findings. The combined picture:

F1 — Indexer keys are c_I = 128-wide, MQA-on-indexer (paper-aligned)

comp_indexer_kv.shape == (n_comp, 128). One vector per compressed block, shared across all n_ih = 64 indexer query heads. This is the standard multi-query-attention shape, but applied to the indexer scoring path.

Per-block cost: 128 × 2 bytes = 256 B per compressed block per CSA layer. At 1M context (CSA ratio=4 → 250K compressed blocks):

  • Per CSA layer: 250K × 256 B = 64 MB
  • × 30 CSA layers = ~1.9 GB total for indexer KV at 1M context

That's small. ~6× smaller than the main compressed KV cache. The prior audit's M1 ("indexer KV is 125 GB at 1M, OOM at 250K tokens") was backwards — the indexer cache is the smallest of the three KV streams.
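The arithmetic above, and the 64× gap to the prior audit's M1, can be reproduced in a few lines. This is a sketch assuming 1M = 10^6 tokens, 30 CSA layers, ratio 4 and BF16 as stated above; csa_stream_bytes is a hypothetical helper, not part of single_shot_inference.py.

```python
BF16_BYTES = 2

def csa_stream_bytes(width, context_tokens=1_000_000, ratio=4, n_csa_layers=30):
    """Hypothetical helper: bytes for one per-block BF16 key stream across all CSA layers."""
    n_blocks = context_tokens // ratio          # 250_000 compressed blocks at 1M
    return n_csa_layers * n_blocks * width * BF16_BYTES

actual = csa_stream_bytes(128)        # c_I = 128: one MQA-shared key per block
prior = csa_stream_bytes(64 * 128)    # prior audit's assumption: n_ih * ihd = 8192 per block

print(f"indexer KV @1M: {actual / 1e9:.2f} GB")   # 1.92 GB
print(f"prior assumption: {prior / 1e9:.1f} GB")  # 122.9 GB
print(f"ratio: {prior // actual}x")               # 64x
```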

F2 — The indexer compressor never loaded weights (the real bug)

Indexer.load:392:

if f"{pfx}.compressor.kv_proj.weight" in w:
    self.compressor = Compressor(4, self.ihd, 7168, dev)

The checkpoint stores the indexer's compressor weights at *.indexer.kv_proj.weight, not *.indexer.compressor.kv_proj.weight. So this if was always False, self.compressor stayed None, and Indexer.forward always returned None at the early-return guard (line 397: if ... comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0: return None).

What this means for every Paris-back run to date:

  • CSA layers received topk_idx = None from the indexer.
  • The gather path at forward_attention:569-571 checks if ratio == 4 and topk_idx is not None:, which is False, so evaluation moves on to elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], ...). That branch only covers HCA (ratio > 4), so a CSA layer (ratio == 4) with topk_idx = None takes neither branch. What CSA actually did in that case still needs to be checked.

The agent should verify which fallback path CSA actually took, and confirm whether the existing test runs were:

  • (a) attending over just SWA (correct only at short context, since SWA window is 128 — would explain why Paris works but degrades past step 10),
  • (b) attending over the full compressed history as if it were HCA (correct but slow at scale), or
  • (c) producing no attention output at all and being saved by a downstream operation.

This is a 10-line print insertion at forward_attention, not an investigation campaign. Add it to the indexer-fix work below, do not spin up a separate probe.
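A minimal sketch of that probe, assuming the branch conditions quoted in this audit and the local names li, ratio, topk_idx and all_kv from the existing snippets; attention_path and its labels are hypothetical. Printing the path next to all_kv's shape should roughly separate the three cases: SWA-only shows at most 128 rows, the full compressed history shows about n_comp + 128 rows, and a missing all_kv shows None.

```python
def attention_path(ratio, topk_idx):
    """Hypothetical probe helper: which all_kv branch forward_attention would take."""
    if ratio == 4 and topk_idx is not None:
        return "csa_sparse"      # intended CSA path: top-k compressed blocks + SWA
    if ratio > 4:
        return "hca_dense"       # HCA: all compressed blocks + SWA
    if ratio == 4:
        return "csa_unhandled"   # CSA with topk_idx=None: neither branch fires
    return "other"               # layers without compression, if any

# Inside forward_attention, right after the branches (names assumed from this audit):
#     print(f"[probe] layer={li} ratio={ratio} path={attention_path(ratio, topk_idx)} "
#           f"all_kv={tuple(all_kv.shape) if 'all_kv' in locals() else None}")
```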

F3 — The scoring einsum has the wrong algebra (MQA vs per-head keys)

The current code at Indexer.forward:404:

k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd)
scores = torch.einsum('tnd,cnd->tnc', q_idx.float(), k_idx.float())

The reshape requires comp_indexer_kv to have n_ih × ihd = 8192 elements per block, but the probe shows each block actually holds ihd = 128 elements. So the reshape raises as soon as execution gets past the early return, i.e. as soon as F2 is fixed.

The temptation is to "fix" this by widening comp_idx_buf to 8192. That would let the reshape succeed and produce numerically plausible scores. It would be wrong. The paper's scoring formula (§2.3.1, eq. 16) is:

$$I_{t,s} = \sum_{h=1}^{n_{ih}} w^{I}_{t,h} \cdot \mathrm{ReLU}\left(q^{I}_{t,h} \cdot K^{I\mathrm{Comp}}_{s}\right)$$

K^IComp_s has no head subscript. It's one key vector per block, shared across all n_ih indexer query heads. The score is computed by dotting each of the 64 query heads against the same key, applying ReLU, then weighting and summing across heads. That's MQA — the same trick used for the main attention path in DSv4 (§2.3.1 "Shared Key-Value MQA").

The correct einsum:

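import torch
import torch.nn.functional as F

# Body of Indexer.forward after the fix, written as a standalone function.
def indexer_topk(q_idx, k_idx, w_h, top_k):
    # q_idx: (T, n_ih, ihd) = (T, 64, 128)
    # k_idx: (n_comp, ihd) = (n_comp, 128)        <-- no head dim
    # w_h:   (T, n_ih) = (T, 64)
    n_comp = k_idx.shape[0]
    scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())  # 'cd', not 'cnd'
    scores = F.relu(scores)
    total = (scores * w_h.unsqueeze(-1).float()).sum(1)  # (T, n_comp)
    tk = min(top_k, n_comp)
    _, idx = total.topk(tk, -1)
    return idx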

The k_idx.reshape(n_comp, self.n_ih, self.ihd) line goes away entirely — no reshape needed when keys are MQA-shared.

Why this matters beyond "the reshape stops crashing": without this correction, an agent fixing F2 (load the indexer compressor) and "fixing" F3 by widening the buffer would produce silently wrong top-k selections. Same shape as the original indexer LUT bug — code runs, produces plausible numbers, but the ranking of compressed blocks is corrupted because the math doesn't match the model. Recall@k drops from paper's 99.7% to something much lower, and we'd be back to debugging "model gets dumber at long context" by ripping apart the FMHA kernel that isn't broken.
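A small numerical check that 'tnd,cd->tnc' implements eq. 16 term for term. This is a NumPy stand-in for the torch code (the float/bfloat16 casts are left out) that compares the einsum path against a literal loop over t, s and h. Shapes use n_ih = 64 and ihd = 128 from F1; both function names are hypothetical.

```python
import numpy as np

def indexer_scores_einsum(q, k, w):
    # q: (T, n_ih, ihd), k: (n_comp, ihd) shared by all heads, w: (T, n_ih)
    s = np.einsum("tnd,cd->tnc", q, k)                        # each head dotted with the shared key
    return (np.maximum(s, 0.0) * w[:, :, None]).sum(axis=1)   # ReLU, weight, sum heads -> (T, n_comp)

def indexer_scores_loop(q, k, w):
    # Literal transcription of eq. 16: I[t,s] = sum_h w[t,h] * ReLU(q[t,h] . K[s])
    T, n_ih, _ = q.shape
    out = np.zeros((T, k.shape[0]))
    for t in range(T):
        for s in range(k.shape[0]):
            out[t, s] = sum(w[t, h] * max(float(q[t, h] @ k[s]), 0.0) for h in range(n_ih))
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((3, 64, 128))
k = rng.standard_normal((10, 128))
w = rng.standard_normal((3, 64))
assert np.allclose(indexer_scores_einsum(q, k, w), indexer_scores_loop(q, k, w))
```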

F4 — The buffer width is wrong but smaller than the prior audit claimed

KVCache:419:

self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, ...)
                                          ^^^^^^^^
                                          512  should be 128

head_dim = 512 (main attention head dim), but indexer keys want c_I = 128, so the buffer is 4× too wide. The prior audit assumed the opposite error: that indexer keys needed n_ih × ihd = 8192 elements, 16× more than the 512-wide buffer. Storage waste at 1M context (CSA only): 30 layers × 250K × (512 - 128) × 2 bytes ≈ 5.8 GB. Real, fixable, not catastrophic.

The fix needs a value to use, and that value should come from the indexer instance, not hard-coded:

# In __init__:
self.comp_idx_buf = torch.zeros(
    max_comp,
    indexer_key_dim,   # passed from caller, = indexer.ihd = 128
    dtype=torch.bfloat16, device=device,
)

The construction site at single_shot_inference.py (where KVCache is created per layer) needs to pass indexer.ihd for CSA layers and skip the buffer for HCA layers (which have no indexer).
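The construction-site rule fits in one function: CSA gets a buffer of width indexer.ihd, HCA gets none, and a missing width is an error rather than a silent fallback to head_dim. This is a sketch; comp_idx_buf_shape is a hypothetical helper, and the caller would pass its result to torch.zeros(..., dtype=torch.bfloat16, device=device).

```python
from typing import Optional, Tuple

INDEXER_KEY_DIM = 128  # indexer.ihd = c_I

def comp_idx_buf_shape(compress_ratio: int, max_comp: int,
                       indexer_key_dim: Optional[int]) -> Optional[Tuple[int, int]]:
    """Shape of comp_idx_buf for one layer, or None when the layer has no indexer."""
    if compress_ratio != 4:       # HCA (ratio 128) and uncompressed layers: no indexer
        return None
    if indexer_key_dim is None:   # fail loudly instead of defaulting to head_dim = 512
        raise ValueError("CSA layer needs indexer_key_dim (pass indexer.ihd)")
    return (max_comp, indexer_key_dim)
```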


PART 2 — MEMORY AT 1M CONTEXT, REVISED

The numbers below replace the prior audit's. They are conservative and worst-case.

Per-layer KV cache sizes — read off the (corrected) code

| Component | Per token (compressed) | Bytes / token | × 1M tokens |
|---|---|---|---|
| CSA main compressed (1 entry / 4 tokens, hd=512 BF16) | 0.25 × 1024 B | 256 B | 256 MB |
| CSA indexer keys (1 entry / 4 tokens, c_I=128 BF16) | 0.25 × 256 B | 64 B | 64 MB |
| HCA compressed (1 entry / 128 tokens, hd=512 BF16) | 0.0078 × 1024 B | 8 B | 8 MB |
| SWA per layer (ring buffer, 128 × hd × 2 B) | constant | n/a | 128 KB |

Total KV cache @ 1M context, all layers, BF16:

| Layer type | Count | Per-layer @ 1M | Total |
|---|---|---|---|
| CSA: main + indexer | 30 | 256 MB + 64 MB | 9.6 GB |
| HCA: main | 30 | 8 MB | 240 MB |
| SWA | 61 | 128 KB | 8 MB |
| GRAND TOTAL @ 1M, BF16 | | | ~9.9 GB |
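Both tables can be reproduced from their stated inputs, which makes it easy to re-run them if a width or layer count changes. This sketch assumes 1M = 10^6 tokens, BF16, hd = 512, c_I = 128, 30 CSA + 30 HCA layers and 61 SWA ring buffers of 128 entries, all taken from the tables; comp_bytes is a hypothetical helper.

```python
BF16 = 2
CTX = 1_000_000

def comp_bytes(ratio, width, ctx=CTX):
    # entries per token (1/ratio) x bytes per entry (width x BF16) x tokens
    return ctx / ratio * width * BF16

csa_main = comp_bytes(4, 512)     # 256 MB per CSA layer
csa_idx = comp_bytes(4, 128)      # 64 MB per CSA layer
hca_main = comp_bytes(128, 512)   # 8 MB per HCA layer
swa = 128 * 512 * BF16            # 128 KiB per layer, independent of context length

total = 30 * (csa_main + csa_idx) + 30 * hca_main + 61 * swa
print(f"total @1M: {total / 1e9:.2f} GB")
```

It prints 9.85 GB, which matches the ~9.9 GB grand total to rounding.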

~10 GB of KV state for a 1M-token context. On 8×B200 (192 GB each, 1.5 TB total) that's negligible — about 0.7% of total HBM, or ~1.25 GB per GPU if sharded EP-style alongside the experts. The system has plenty of memory headroom for the design target.

For comparison, DeepSeek-V3.2's KV cache at 1M context is ~92 GB (per V4 paper Figure 1). V4 at ~10 GB is roughly a 9× reduction (about 11% of V3.2), consistent with the paper's "~10% of V3.2's KV cache" claim. The implementation hits the design memory budget; the prior audit was wrong about how it gets there.

What this changes about priorities

  • "Quantize indexer KV to FP4 to save 121 GB" is gone. It was based on a wrong width. The indexer cache is 2 GB at 1M; FP4 would shrink it to 500 MB. Nice; not urgent.
  • "max_comp = 65536 is the ceiling at 262K tokens" is still real. That hardcoded buffer size hasn't changed. At 1M context CSA needs max_comp_csa = 262144. Still a config fix, just not paired with a quantization fight.
  • "Allocator churn from torch.cat in the gather" is still real and still gets worse with context length. Pre-allocation still matters at long context for perf and stability over hours of decoding. Just not urgent for "does it fit in memory."

PART 3 — PRIORITY ORDER (REVISED)

Sequenced by what unblocks correctness first, then performance, then memory. The big shift from the prior audit: the indexer fix is the gating correctness work; memory is no longer the crisis it was framed as.

Tier 1 — Make the indexer actually work (correctness)

These are all small edits but they have to land together. The agent should treat this as one atomic landing, not three independent fixes, because individually each one either does nothing or makes things worse.

A1 — Fix the indexer compressor weight path

Indexer.load:392. Change the check and the load prefix to match the checkpoint:

# Was:
if f"{pfx}.compressor.kv_proj.weight" in w:
    self.compressor = Compressor(4, self.ihd, 7168, dev)
    self.compressor.load(w, f"{pfx}.compressor", dev)
# Should be (read the actual key from the checkpoint, not assumed):
if f"{pfx}.kv_proj.weight" in w:
    self.compressor = Compressor(4, self.ihd, 7168, dev)
    self.compressor.load(w, f"{pfx}", dev)

The agent's probe already identified this. Verify the fix is in v17 by running a checkpoint-loaded forward and confirming self.compressor is not None for every CSA layer, not just one.
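"Read the actual key from the checkpoint" can be made mechanical: list the keys under the indexer prefix, then pick whichever layout actually contains kv_proj.weight. This is a sketch over a plain dict of weights; both helpers and the example key names are hypothetical, chosen only to match the *.indexer.kv_proj.weight layout reported in F2.

```python
def indexer_compressor_prefix(weights, pfx):
    """Return the prefix Compressor.load should use, or None if no kv_proj is present."""
    for cand in (f"{pfx}.compressor", pfx):   # assumed layout first, observed layout second
        if f"{cand}.kv_proj.weight" in weights:
            return cand
    return None

def keys_under(weights, pfx):
    """Checkpoint keys under pfx, for reading the real names instead of assuming them."""
    return sorted(k for k in weights if k.startswith(pfx + "."))

# Hypothetical key names in the observed layout:
w = {"layers.2.attn.indexer.kv_proj.weight": None,
     "layers.2.attn.indexer.wq_b.weight": None,
     "layers.2.attn.wq.weight": None}
print(indexer_compressor_prefix(w, "layers.2.attn.indexer"))
```

Asserting that the result is not None for every CSA layer before loading turns the silent skip in F2 into a hard failure.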

A2 — Fix comp_idx_buf width to c_I = 128

KVCache:419. Plumb indexer_key_dim through KVCache.__init__ (or better: derive it from a probe of the indexer's compressor on first call). Default for non-CSA layers: skip the buffer.

A3 — Fix the scoring einsum to MQA-on-indexer

Indexer.forward:404. Drop the head-axis reshape and use 'tnd,cd->tnc' as shown in F3 above. This is the deeper correctness fix and the easiest one to get wrong if A1+A2 land first and an agent "fixes" the reshape by widening the buffer.

Gate for Tier 1:

  1. Indexer.forward returns a non-None idx tensor for every CSA layer on a prompt of ≥ 4 tokens. Verify with a print on layer 0.
  2. forward_attention at CSA layers takes the if ratio == 4 and topk_idx is not None branch, not the fallback.
  3. Paris-back still works, with output identical to or better than v16's Paris-back. If v16 ran the dense fallback, it attended over a superset of the keys CSA selects, so output should be near-identical; if it ran SWA-only (see B1), long-context output should improve.
  4. Recall test: compare the top-k indices from the indexer against an FP32 oracle (just compute the scoring in FP32 outside the kernel and topk on that). Recall ≥ 99% at top_k=512 with n_comp ≥ 1024.
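Gate 4 takes only a few lines once both score tensors exist. In this NumPy sketch, oracle would be the FP32 scores from the corrected einsum and noisy would be the indexer's own scores. Here, random data and small additive noise stand in for both, so the printed number says nothing about the real indexer. Both function names are hypothetical.

```python
import numpy as np

def topk_indices(scores, k):
    # scores: (T, n_comp) -> indices of the k largest per row (order within the k not significant)
    return np.argpartition(-scores, k - 1, axis=-1)[:, :k]

def recall_at_k(pred_idx, oracle_idx):
    """Fraction of oracle top-k indices the indexer also selected, averaged over rows."""
    hits = [len(set(p.tolist()) & set(o.tolist())) / len(o) for p, o in zip(pred_idx, oracle_idx)]
    return float(np.mean(hits))

rng = np.random.default_rng(0)
oracle = rng.standard_normal((4, 2048))                     # stand-in for FP32 oracle scores
noisy = oracle + 1e-3 * rng.standard_normal(oracle.shape)   # stand-in for low-precision scores
print(f"recall@512 = {recall_at_k(topk_indices(noisy, 512), topk_indices(oracle, 512)):.4f}")
```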

Tier 2 — Verify what the fallback was actually doing (cleanup)

B1 — Find and document the v16 CSA fallback path

forward_attention:569-571: when topk_idx was always None, what actually happened in CSA layers? The branches as read:

if ratio == 4 and topk_idx is not None:        # never taken
    all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0)
elif ratio > 4:                                 # only HCA layers
    all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)

For CSA with topk_idx=None and ratio == 4, neither branch fires. What all_kv is at that point depends on what came before. The agent should run a 5-line probe in v16 (or look at the bisected behavior) to confirm whether v16 CSA layers:

  • attended over just SWA (would explain decode degradation past step 10),
  • attended over the full compressed history (would explain decode working but being slower than necessary),
  • crashed at this point and something downstream rescued the run (most likely if Paris-back still happened).

This is informational — it doesn't gate Tier 1 — but it answers "what exactly did 'Paris-back' validate?" and it tells you whether decode quality should jump (if v16 was on SWA-only) or stay flat (if v16 was on full compressed) when Tier 1 lands.

B2 — Once Tier 1 lands, add explicit error on topk_idx is None in CSA

The fact that the CSA fallback was silent for this long is the meta-bug. After Tier 1, the CSA path should require topk_idx is not None:

if ratio == 4:
    assert topk_idx is not None, f"CSA layer {li} got no top-k from indexer — indexer is broken"
    tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1)
    all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0)
elif ratio > 4:
    all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)

This is a tripwire for future regressions of the same shape.

Tier 3 — Memory & allocator hygiene (still real, just not urgent)

C1 — max_comp per-layer-type + CLI flag

KVCache.__init__:411. Make max_comp a function of context length and compress ratio:

def __init__(self, head_dim, indexer_key_dim, compress_ratio,
             window_size=128, target_context=8192, device='cuda:0'):
    self.max_comp = (target_context + compress_ratio - 1) // compress_ratio
    ...

And expose target_context as a CLI arg (--max-context). Default small (8192) so the script stays runnable.
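The sizing rule in C1 is plain ceiling division. Checking it against the numbers elsewhere in this audit shows how the old 65536 ceiling maps to 262K tokens and what 1M needs per layer type. Here max_comp is a hypothetical free function that mirrors the __init__ line.

```python
def max_comp(target_context: int, compress_ratio: int) -> int:
    # Ceiling division: round up so the buffer is never one slot short.
    return (target_context + compress_ratio - 1) // compress_ratio

assert max_comp(8192, 4) == 2048          # default --max-context
assert max_comp(262_144, 4) == 65_536     # today's hard-coded ceiling
assert max_comp(1 << 20, 4) == 262_144    # 1M context, CSA
assert max_comp(1 << 20, 128) == 8_192    # 1M context, HCA
```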

C2 — Pre-allocate all_kv_buf, eliminate torch.cat in gather

Same fix as D3/D4 in the prior audit — still valid:

# Once at init:
self.all_kv_buf = torch.zeros(max_top_k + window_size, head_dim, ...)

Gather writes into views of this buffer with out= arguments. FMHA consumes the prefix. Zero allocs on hot path.
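A minimal sketch of the pre-allocated gather, using NumPy in place of torch. In PyTorch the analogous calls would be torch.index_select(comp_kv, 0, tk, out=buf[:k]) and buf[k:k + w].copy_(swa_kv). The GatherBuffer class and its sizes are hypothetical. The returned prefix is a view, so it is only valid until the next gather overwrites the buffer.

```python
import numpy as np

class GatherBuffer:
    """Preallocated [top-k compressed | SWA] staging buffer; NumPy stand-in for all_kv_buf."""

    def __init__(self, max_top_k, window_size, head_dim, dtype=np.float32):
        # Allocated once; every gather writes into slices of this array.
        self.buf = np.zeros((max_top_k + window_size, head_dim), dtype=dtype)

    def gather(self, comp_kv, topk_idx, swa_kv):
        k, w = len(topk_idx), swa_kv.shape[0]
        # mode="clip" clamps out-of-range indices (like .clamp(0, n_comp - 1)) and lets
        # NumPy write straight into the view instead of buffering the result.
        np.take(comp_kv, topk_idx, axis=0, out=self.buf[:k], mode="clip")
        self.buf[k:k + w] = swa_kv          # SWA keys go in the tail
        return self.buf[:k + w]             # prefix view handed to the attention kernel

rng = np.random.default_rng(0)
comp_kv = rng.standard_normal((100, 8)).astype(np.float32)
swa_kv = rng.standard_normal((5, 8)).astype(np.float32)
gb = GatherBuffer(max_top_k=16, window_size=5, head_dim=8)
all_kv = gb.gather(comp_kv, np.array([3, 7, 42]), swa_kv)
```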

C3 — KVCache.get_swa returns views, not clones

KVCache:457-460. Drop the .clone() calls. Return slices.
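The trade-off in C3 is aliasing. A view costs nothing, but it reflects later writes to the ring buffer, so it is only safe if the caller consumes it before the next slot is written. Here is a NumPy illustration, with .copy() standing in for torch's .clone() and a small hypothetical array standing in for the SWA ring buffer.

```python
import numpy as np

ring = np.arange(12, dtype=np.float32).reshape(6, 2)   # stand-in for the SWA ring buffer

clone = ring[2:5].copy()   # what .clone() does: new allocation, detached from the ring
view = ring[2:5]           # what C3 proposes: no allocation, aliases the ring

ring[3] = -1.0             # a later decode step overwrites a slot
assert np.shares_memory(view, ring) and not np.shares_memory(clone, ring)
assert view[1, 0] == -1.0 and clone[1, 0] == 6.0   # the view sees the write; the clone does not
```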

C4 — Optional: Quantize indexer KV to FP4 (paper §5.2.1)

For throughput, not memory. Defer until E7 (Stage F indexer FP4 tensor-core scoring) lands — at that point the FP4 storage and FP4 MMA path are paired, which is the right shape. Don't quantize the cache without also upgrading the scoring kernel — that would be storage savings paid for with a dequant kernel that doesn't exist yet.

Tier 4 — Architecture fidelity nice-to-haves

D1 — Split Compressor class into MainCompressor and IndexerKeyCompressor

single_shot_inference.py:272. Same class is instantiated with totally different config in two places. Splitting documents the difference and prevents the "I assumed it was the same thing" bug class (which is how the buffer width bug happened in the first place).

D2 — Verify sink merge semantics (D6 from prior audit, unchanged)

_run_production_fmha:489 passes n_comp=0 always. The kernel may expect n_comp = len(compressed_kv) for the D5c sink merge. Print the kernel's actual handling, confirm or fix.

D3 — Understand mHC residual growth (D7 from prior audit, unchanged)

|X| → 500-700 at L60 still indicates Sinkhorn B isn't doubly-stochastic at runtime. Print B row/col sums, expect 1.0 ± 1e-6. This may also partly explain the decode degradation past step 10 (compounding non-bounded residuals → saturated logits → low-information argmax). Tier 1 fixing the indexer may improve decode behavior enough that this stops mattering — but worth still checking once the indexer is correct.
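A sketch of the D3 probe, assuming mHC's B is produced by a fixed number of Sinkhorn-Knopp iterations over exponentiated logits. The 4×4 shape and the iteration counts are placeholders, not values read from single_shot_inference.py. Too few iterations leave row sums away from 1.0, which is the signature the row/col print should catch.

```python
import numpy as np

def sinkhorn(logits, n_iters):
    """Sinkhorn-Knopp: alternately normalize rows and columns of exp(logits)."""
    B = np.exp(logits - logits.max())
    for _ in range(n_iters):
        B = B / B.sum(axis=1, keepdims=True)   # rows -> 1
        B = B / B.sum(axis=0, keepdims=True)   # columns -> 1
    return B

def stochasticity_error(B):
    """Max deviation of any row or column sum from 1.0 (what D3 asks to print)."""
    return max(np.abs(B.sum(axis=1) - 1).max(), np.abs(B.sum(axis=0) - 1).max())

rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 4))
for n in (1, 3, 20):
    print(n, stochasticity_error(sinkhorn(logits, n)))
```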


REVISED PRIORITY TABLE

| # | Item | What it unblocks | Effort | Blocks 1M? |
|---|---|---|---|---|
| A1 | Fix indexer compressor weight path | Indexer runs at all | XS | Yes (correctness) |
| A2 | comp_idx_buf width = 128 (not 512) | Indexer can store keys | XS | Yes (correctness) |
| A3 | Scoring einsum 'tnd,cd->tnc' | Top-k is correct | XS | Yes (correctness) |
| B1 | Document the v16 CSA fallback | Knowing what Paris validated | XS | No |
| B2 | Assert topk_idx is not None in CSA | Future regression tripwire | XS | No |
| C1 | Per-layer max_comp + --max-context | Long context doesn't crash at 262K | XS | Yes, but trivial |
| C2 | Pre-alloc all_kv_buf, kill cat | Stable decode over hours | S | No, but real perf |
| C3 | get_swa returns views | Small but everywhere | XS | No |
| C4 | FP4 indexer cache (paired with E7) | Throughput, paper compliance | M-L | No |
| D1 | Split Compressor classes for clarity | Prevents the same-class-confusion bug | XS | No |
| D2 | Sink merge semantics check | Subtle numerics | S | No |
| D3 | mHC Sinkhorn convergence check | Decode degradation | S | No |

Land A1+A2+A3 together as one atomic correctness fix. That is the critical path. Everything else is sequential and not gating.


DOCTRINE — applies to every priority

  1. DSL wall → raw CUDA C++, not Python. Doesn't apply much in this round — most fixes are 3-line edits to Python orchestration. The exception is C4 (FP4 indexer cache) which is a kernel fight and must follow doctrine: tcgen05/UMMA/TMA on the read side, __constant__ LUT for any dequant, paired with the E7 scoring kernel.

  2. Raw CUDA ≠ scalar math. Same — when C4 lands, the indexer's tcgen05.mma FP4 path replaces the scoring einsum. The current FP32 einsum (post-fix) is a correctness oracle, not a perf target.

  3. Print, don't guess. This entire round exists because of a probe that printed instead of assuming. The pattern works. Use it again for:

    • B1: probe what the v16 CSA fallback actually returned.
    • C2: print all_kv shape and dtype to verify the pre-allocated buffer is being sliced correctly.
    • D3: print Sinkhorn row/col sums per layer.

    Stop running new code until the probes have written their output to a .md next to this one.
  4. Integration over exploration. No Indexer_v2, no KVCache_v2. Edit the existing classes. Tier 1 is ~10 line-edits total across 3 functions.

  5. Falsifiable gates. Already listed per priority above. The meta-gate for the whole audit: after Tier 1, the indexer's top-k recall vs an FP32 oracle is ≥ 99% on a prompt with n_comp ≥ 1024. Until that number is measured and recorded, "the indexer works" is an assertion, not a fact.

  6. Don't optimize for a problem you don't have. The prior audit's biggest mistake was framing memory as a 1M-context crisis based on a wrong width. The real picture is: V4 hit its KV cache memory targets, the implementation is faithful, the actual blocker is a handful of small bugs in the sparse-selection path. Fix those first and re-measure before adding new infrastructure.