diff --git a/ARCHITECTURE_AND_MEMORY_AUDIT.md b/ARCHITECTURE_AND_MEMORY_AUDIT.md new file mode 100644 index 00000000..213a3928 --- /dev/null +++ b/ARCHITECTURE_AND_MEMORY_AUDIT.md @@ -0,0 +1,434 @@ +# ARCHITECTURE & MEMORY AUDIT — 1M-context viability + +**Method.** Verified against `single_shot_inference.py` v16 and the DSv4 +paper §2.3.1–§2.3.4 (CSA/HCA) and §3.5.1 (heterogeneous KV cache). Every +finding has a line number. Per doctrine. + +**Framing.** The Paris demo runs ≤ 50 tokens. The model is built for 1M. +That's a **20,000×** gap. Several things in the current single_shot are +fine at 50 tokens, will OOM hard at 4–10K tokens, and don't even resemble +the paper's KV design at 1M. Below: drift first, then memory, in order of +how badly each one blocks the 1M-context goal. + +--- + +# PART 1 — ARCHITECTURE DRIFT FROM PAPER + +## D1 — `comp_idx_buf` shape is wrong (CRITICAL — silent corruption or crash) + +`single_shot_inference.py:419`: + +```python +self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device) + ^^^^^^^^ + 512 — WRONG +``` + +But indexer keys are `n_ih × ihd` wide. From `:1030`: + +```python +n_ih = cfg.get("index_n_heads", 64) +ihd = cfg.get("index_head_dim", 128) +``` + +So indexer keys have width `64 × 128 = 8192`, not 512. **The indexer KV +buffer is 16× too narrow.** What happens in practice depends on whether +the assignment broadcasts, raises, or silently truncates — and from the +fact that Paris-back works, the indexer probably isn't being used at all +yet (CSA layers may be producing 0 compressed blocks at 50 tokens — see +D2). At any context where CSA actually compresses, this either crashes +or stores garbage in the top-k selection input. + +**Fix.** Read the actual indexer key width from the indexer's compressor +output (`indexer.compressor.kv_dim = 2 * ihd = 256` for the indexer-side +CSA, since the indexer's compressor takes `(4, ihd, H)`). 
Then check what +the indexer's compressor actually produces — print its output shape on +first call instead of guessing — and size `comp_idx_buf` to match. + +Verification step: instrument `Compressor.forward` to print +`compressed.shape` on first call from both the **main** compressor (kv_dim += 512 or 1024) and the **indexer's** compressor (kv_dim = 256). Code to +the observed values. Do not infer them from variable names. + +## D2 — The Compressor is built twice per CSA layer, with different config + +`:394` (inside `Indexer.load`): +```python +self.compressor = Compressor(4, self.ihd, 7168, dev) +``` + +`:1034` (in `main()`): +```python +if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) +``` + +For a CSA layer: +- The **layer's** compressor has `(ratio=4, hd=512, H=7168)`, output dim + `kv_dim = 2*hd = 1024`. +- The **indexer's** compressor has `(ratio=4, ihd=128, H=7168)`, output + dim `kv_dim = 2*ihd = 256`. + +Both are constructed, both load weights independently, both have their own +NVFP4 GEMM Nvfp4Linear instances. This matches the paper (§2.3.1: the +indexer has its own compressed key path that's narrower than the main +compressed KV path) — but **it is being done as two completely +independent code paths**, with the indexer's compressor's existence +implicit and easy to miss. That's why D1 was missed. + +It also means the main compressor's `forward` runs twice per layer at +prefill: once for the main KV path, once for the indexer's keys. The +hidden states being projected are *identical*; the GEMM weights are +different. Two separate launches. + +**Architecturally correct, but the code shape hides it.** Two consequences: + +1. The "compressor" abstraction has a different meaning depending on + which compressor instance you're looking at. Rename for clarity: + `Compressor` (main) and `IndexerKeyCompressor` (smaller, for indexer). +2. 
The Indexer should expose its compressed key width as a property + (`indexer.compressor.kv_dim // 2`) so D1 can compute the buffer width + from data instead of assuming `head_dim`. + +## D3 — KV gather still uses `torch.cat`, undoing P3's pre-allocation + +`:569–571`: + +```python +if ratio == 4 and topk_idx is not None: + tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1) + all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0) +elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0) +``` + +P3 preallocated `comp_kv_buf`, but the gather **immediately allocates a +fresh `(top_k + ws, hd) = (1024 + 128, 512)` BF16 tensor per layer call** +just to pass to FMHA. For 61 layers × per-token decode: + +- 61 × (1152 × 512 × 2 bytes) ≈ **72 MB allocated and freed per token**. + +That's small in absolute terms, but it's allocator churn on the hot path, +and the entire point of P3 was to remove this pattern. The `comp_kv[tk]` +gather already allocates (it's an advanced-indexing copy); the `cat` then +allocates again to merge with SWA. Two allocs per layer that don't need +to exist. + +**Fix.** Preallocate one more buffer at cache init: + +```python +self.all_kv_buf = torch.zeros(top_k + window_size, head_dim, ...) +``` + +Write the gathered top-k into `all_kv_buf[:top_k]`, the SWA into +`all_kv_buf[top_k:top_k+swa_len]`, pass `all_kv_buf[:top_k+swa_len]` to +FMHA. The gather becomes `torch.index_select(comp_kv_buf, 0, tk, +out=all_kv_buf[:top_k])` — zero allocs. + +## D4 — The HCA path concatenates the ENTIRE compressed history, every layer, every token + +Same line, the `elif` branch: + +```python +elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0) +``` + +HCA layers don't run the indexer. They attend over the **full compressed +history**. At 1M context with HCA ratio=128, that's `1M / 128 = 7813` +compressed entries. Per-token-per-layer FMHA input grows linearly with +prefill length. 
That's correct math (paper §2.3.2: HCA is dense over +compressed entries), but the *implementation* is allocating a new +contiguous tensor for it every single layer call. + +At 1M context, that's `(7813 + 128) × 512 × 2 bytes ≈ 7.7 MB` per HCA +layer call. Across ~30 HCA layers per token ≈ **230 MB of alloc-and-free +per token, on the decode hot path.** + +**Fix.** Same as D3: preallocate `all_kv_buf` sized for the worst case +(HCA full history + SWA). Use `out=` parameters on the gather. Or skip +the concat entirely — the FMHA kernel can take two K/V tensors and the +mask can encode the boundary. Today it can't, but it should; this is a +real kernel-side ask (see "M5" below). + +## D5 — Indexer top-k attends *only* compressed entries; SWA is appended separately. Confirm against paper + +Paper §2.3.1 figure: CSA's FMHA input is +`Concatenate(selected_compressed_KV, sliding_window_KV)`. Yes, that's +what the code does at `:571`. **Architecturally correct.** + +One subtle thing worth checking: the paper says the sliding window provides +*local* fine-grained context the compressor can't (since compressed entries +each represent m tokens). The current SWA window is **128** (from config). +At 1M context that's still 128 local tokens visible per query — correct. +But the window slide in `KVCache.append_swa` evicts when full +(`self.swa_head = (self.swa_head + T) % self.ws`), which is correct. +✅ no drift. + +## D6 — Attention sink is wired, but per-CSA-layer correctness needs checking + +`_run_production_fmha:489`: + +```python +sinks = w.get(f"{pfx}.sinks") +if sinks is not None: sink_bias = sinks.to(device=dev).float().reshape(n_h) +attn_out = dsv4_attention(q=q, k=k, v=v, scale=scale, n_comp=0, sink_bias=sink_bias) +``` + +`n_comp=0` is passed regardless. Paper §2.3.3 ("Attention Sink"): the sink +is a per-head additive logit to the softmax denominator. 
The kernel +signature in v15 said `n_comp` is "reserved for future kernel integration" +for the **D5c sink merge** (different softmax over compressed vs SWA). +v16 still passes `n_comp=0`, which means we're using global sink, not +per-segment sink merge. + +The paper isn't explicit about whether sink should be per-segment, but if +the production FMHA was designed for D5c merge and it's being bypassed, +that's an unfinished integration, not necessarily a bug. + +**Action:** confirm against the kernel's actual handling. Print the +sink_bias usage in `dsv4_attention` for one layer. If sink merge is +needed for CSA correctness at long context, that's a real wiring gap. + +## D7 — mHC residual growth (|X|→500–700 at L60) was flagged but not understood + +Your perf audit notes (line 88): "Residual |X| grows to 500–700 at L60 +— mHC bounds it but residual is high." + +Paper §2.2 designed mHC specifically to **bound** the residual via the +doubly-stochastic B matrix (spectral norm ≤ 1). The growth from |X|=1 at +L0 to |X|=700 at L60 suggests B isn't actually norm-bounded at runtime, +or A·C are amplifying. + +This is the **same** issue I flagged in the v14 docs and it's still open. +Not a perf bug, but it's an architecture-fidelity bug, and it's the +**single most likely cause of decode degradation past step 10** (the +"...the" repetition loop noted in your audit). Compounded across decode +steps, a slightly-not-bounded residual becomes a numerically saturated +residual, which makes the final logits less informative. + +**Action.** Print Sinkhorn-Knopp B's row/col sums per layer for one +forward pass. They should be `1.0 ± 1e-6` if Sinkhorn converged. If +they're e.g. `1.02–1.05`, t_max=20 isn't converging at this scale; bump +it or check the dynamic-parameter generation. 
The single_shot does +`sinkhorn_iters=20` (`:937`) which matches the paper, so the issue is +likely upstream: either A or C is producing values outside [0, 1] or +[0, 2], or the dynamic parameter generation has FP32 noise that breaks +doubly-stochastic. + +--- + +# PART 2 — MEMORY AT 1M CONTEXT + +This is the part that should be terrifying. The single_shot was sized for +50 tokens; the model targets 1M. Below is what each KV-cache structure +costs at the four interesting scales, with all numbers worked out, not +estimated. + +## Per-layer KV cache sizes — read off the code + +Layer setup: +- **CSA layer** (compressor ratio=4): main compressed at `hd=512` BF16, + indexer keys at `n_ih * ihd = 64 * 128 = 8192` BF16, SWA at `ws=128 × hd` +- **HCA layer** (compressor ratio=128): main compressed at `hd=512` BF16, + no indexer, SWA at `ws=128 × hd` +- **SWA-only layer** (first 2 layers of Flash, or HCA per-layer-2 of Pro): + SWA only + +Layer counts per V4-Pro: 61 total, alternating CSA/HCA after layer 1. So +roughly **30 CSA + 30 HCA + 1 SWA-only** (paper §4.2.1). + +### Per-layer KV growth (per token of context): + +| Component | Per token | Bytes / token | × 1M tokens | +|---|---|---|---| +| **CSA main compressed** (1 entry / 4 tokens, hd=512 BF16) | 0.25 × 1024 B | 256 B | **256 MB** | +| **CSA indexer keys** (1 entry / 4 tokens, 8192 BF16) | 0.25 × 16384 B | 4096 B | **4.1 GB** | +| **HCA compressed** (1 entry / 128 tokens, hd=512 BF16) | 0.0078 × 1024 B | 8 B | **8 MB** | +| **SWA** (per layer, fixed 128 × hd × 2) | const | — | 128 KB | + +### Total KV cache @ 1M context, all layers, BF16: + +| Layer type | Count | Per-layer @ 1M | Total | +|---|---|---|---| +| CSA: main + indexer | 30 | 256 MB + 4.1 GB | **131 GB** | +| HCA: main | 30 | 8 MB | 240 MB | +| SWA | 61 | 128 KB | 8 MB | +| **GRAND TOTAL @ 1M, BF16** | | | **~131 GB** | + +**The KV cache alone is 131 GB.** That's before model weights (which are +already EP-sharded across 8 GPUs). 
On 8 × B200 with 192 GB each = 1.5 TB +total, sharding KV across the 8 GPUs gives ~16 GB per GPU. That fits, but +it's **about 8.5% of HBM dedicated to one request's KV.** And the dominant cost +is the **indexer keys** at 4.1 GB per layer. + +**Critical observation: the indexer keys are 16× larger than the main +compressed KV per token, and they alone are about 94% of total KV +(30 × 4.1 GB ≈ 123 GB of ~131 GB).** This is +because `n_ih * ihd = 8192` is much wider than `hd = 512` for the main KV. +The paper does specify this — the indexer is a wide multi-query mechanism +— but if storage is the constraint, the indexer key path is where to +attack. + +## M1 — `comp_idx_buf` allocation as written: `(65536, 512)` per layer + +`:419`: +```python +self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, ...) +``` + +If this *were* the correct width `(max_comp=65536, 8192)`, that's +`65536 × 8192 × 2 = 1 GB per layer × 30 CSA layers = 30 GB pre-allocated +at startup`, sized for **only 262K tokens of context**, not 1M. + +The current shape `(65536, 512)` allocates 64 MB per layer × 30 = 1.9 GB +— too small by 16× as noted in D1, so either crashes at first compressed +write or silently truncates. Both bad. After fixing D1, you're staring +down 30 GB of pre-allocated KV cache that supports only a quarter of the +target context. And it's 30 GB × 8 GPUs = 240 GB if everything is +replicated, or 30 GB total if cache is sharded. + +**This is the load-bearing memory bug for 1M context.** Two viable fixes: + +1. **Quantize the indexer keys to FP4** (paper §5.2.1: "QK activations +   are cached, loaded, and multiplied entirely in FP4"). The indexer keys +   are *designed* to be FP4. FP4 is 4 bits against BF16's 16, so that's +   4× smaller: 4.1 GB/layer → ~1 GB at 1M, before block-scale overhead. +   Total cache becomes ~39 GB instead of 131 GB. **This is the +   correct fix per paper.** +2. **Page the indexer KV.** Only the top-k indices' worth (≤ 1024) need to +   be in attention. 
A paged cache (per the paper's §3.5.1 "heterogeneous + KV cache" with on-disk overflow) only keeps recent + selected pages in + HBM. + +(1) is required regardless of (2). The current BF16 indexer cache is the +single biggest blocker between Paris-demo-works and 1M-context-works. + +## M2 — `max_comp = 65536` hardcoded + +`:411`: +```python +def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0'): +``` + +For CSA (ratio=4), 65536 compressed entries = `262144` tokens of context. +That's the ceiling. At 262K tokens you get `IndexError` on +`comp_kv_buf[self.n_comp:end] = ckv`. There is no graceful behavior, no +on-disk overflow, no error message — it just dies. + +For 1M context, `max_comp` needs to be `1M / 4 = 262144` for CSA layers +and `1M / 128 = 7813` for HCA layers. Currently both layer types share +the same 65536 default. + +**Fix.** Size the buffer per-layer using compress ratio: +```python +max_comp_csa = ceil(target_context / 4) # 262144 for 1M +max_comp_hca = ceil(target_context / 128) # 7813 for 1M +``` +And, critically, make `target_context` a CLI flag with a sensible default +(say 8K) so the script can run small while staying honest about the +ceiling. Hardcoded 65536 with no docstring on what it means is a footgun. + +## M3 — Allocator churn from gather (D3, D4) compounds at 1M + +Repeated for emphasis with numbers: the per-token `torch.cat` in the KV +gather allocates and frees memory proportional to context length. At 1M +context with HCA at all layers, that's **~230 MB of alloc/free per +decoded token**. PyTorch caching allocator handles this but it's still +fragmentation pressure across thousands of decoded tokens. After hours +of decoding, the allocator's cached blocks bloat. + +Already covered in D3/D4. Restating because at 1M scale it's no longer +"small overhead" — it's GB/min of churn. 
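The churn figures in D3, D4, and M3 are plain arithmetic, so they can be re-derived whenever `top_k`, the window size, or the target context changes. A minimal sketch follows, assuming the shapes quoted in this audit (`hd=512`, `ws=128`, `top_k=1024`, BF16 storage, 61 layers for the D3 estimate and 30 HCA layers for the D4 estimate). The function names are hypothetical and are not part of `single_shot_inference.py`.

```python
import math

BF16_BYTES = 2  # bytes per BF16 element


def gather_bytes(entries: int, head_dim: int = 512) -> int:
    """Bytes in one freshly allocated (entries, head_dim) BF16 tensor."""
    return entries * head_dim * BF16_BYTES


def csa_churn_per_token(n_layers: int = 61, top_k: int = 1024,
                        window: int = 128, head_dim: int = 512) -> int:
    """D3 estimate: every layer concatenates top_k selected rows with the SWA window."""
    return n_layers * gather_bytes(top_k + window, head_dim)


def hca_churn_per_token(context: int, n_layers: int = 30, ratio: int = 128,
                        window: int = 128, head_dim: int = 512) -> int:
    """D4 estimate: every HCA layer concatenates the full compressed history with SWA."""
    comp_entries = math.ceil(context / ratio)
    return n_layers * gather_bytes(comp_entries + window, head_dim)


if __name__ == "__main__":
    # Assumed context lengths for illustration only.
    for ctx in (50, 8_192, 262_144, 1_000_000):
        hca = hca_churn_per_token(ctx)
        print(f"context={ctx:>9,}  HCA churn/token = {hca / 2**20:8.1f} MiB")
    print(f"CSA churn/token = {csa_churn_per_token() / 1e6:.1f} MB")
```

At a 1,000,000-token context this gives 243,947,520 bytes per decoded token for the HCA path (about 233 MiB, the "~230 MB" figure quoted above), and 71,958,528 bytes (72 MB) for the CSA estimate. The CSA term is constant in context length; only the HCA term grows.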
+ +## M4 — `get_swa` does `.clone()` every call + +`:457–460`: +```python +def get_swa(self): + if self.swa_len == 0: return torch.zeros(0, self.hd, ...), torch.zeros(0, ...) + if self.swa_len < self.ws: return self.swa[:self.swa_len].clone(), self.swa_pos[:self.swa_len].clone() + idx = torch.arange(self.swa_head, self.swa_head + self.ws) % self.ws + return self.swa[idx].clone(), self.swa_pos[idx].clone() +``` + +Each non-empty return path does two clones (and the third path also +allocates an `arange`, and its advanced-indexing gather is already a copy +before the clone). At `ws=128 hd=512`, the SWA tensor is 128 KB, which is +small, but **this runs every layer every token**: 61 layers × 128 KB ≈ +7.6 MB per decoded token in allocator pressure. Same fix pattern as D3: return views +into the ring buffer, let the FMHA gather kernel consume them with strides. + +## M5 — KV gather memory could be eliminated entirely with a smarter kernel + +The D3 and D4 fixes both pre-allocate a fused `all_kv` buffer because the FMHA +takes one K/V tensor. The deeper fix is to have the FMHA take **two K/V +inputs** (compressed + SWA) and handle them with masking inside. Then no +fused buffer ever has to exist — the kernel reads compressed entries +directly from `comp_kv_buf` (with a gather indices vector for the top-k +selection) and SWA entries directly from `swa_buf`. + +This is the **right** long-term design (and matches how the paper §3.5.1 +envisions the heterogeneous cache). It's a kernel ask, not a script fix, +but it's worth flagging now: the gather → cat → FMHA pattern is *the* +memory inefficiency at long context, and it can be designed out at the +kernel boundary. + +--- + +# PART 3 — PRIORITY ORDER + +These are sequenced by **what blocks 1M context viability**, not by +implementation cost. + +| # | Item | Required for 1M? 
| Effort | +|---|---|---|---| +| **A1** | **D1: Fix `comp_idx_buf` width to actual indexer key width** | **Critical — silent corruption today** | XS | +| **A2** | **M1: Quantize indexer KV to FP4** (paper §5.2.1) | **Critical — saves ~92 GB at 1M** | M-L | +| **A3** | **M2: Make `max_comp` per-layer-type + a CLI flag** | **Critical — current ceiling is 262K** | XS | +| **A4** | D3/D4/M3: Preallocate `all_kv_buf`, eliminate `torch.cat` | High — perf and stability over hours of decode | S | +| **A5** | D7: Investigate mHC residual growth (Sinkhorn convergence print) | High — likely root of decode degradation | S | +| **A6** | M4: Return SWA views, no clone | Medium — small per-call, large in aggregate | XS | +| **A7** | D2: Rename `Compressor` → split into `MainCompressor` and `IndexerKeyCompressor` | Medium — clarifies the duplicate-build pattern | XS | +| **A8** | D6: Verify sink merge semantics with kernel author | Medium — possible silent numerical drift | S | +| **A9** | M5: Two-buffer FMHA kernel (eliminate gather buffer entirely) | Long-term — production design | L | + +**The "should I run on 1M context tomorrow" answer is no, regardless of +anything else, until A1/A2/A3 are done.** Without A1 you get garbage +top-k. Without A2 you OOM at ~250K tokens even with 8×B200. Without A3 +you crash at 262K. Together those three are the gating set. + +**The "is it still architecturally DSv4?" answer is yes — mostly.** The +hot path is faithful to the paper: CSA does overlapped-2m compression +with softmax weights, HCA does heavy non-overlapped compression, indexer +does ReLU(QK)·w_h reduction, attention concatenates selected-compressed ++ SWA, sinks are applied. The drifts above (D1, D2, D6, D7) are +implementation flaws or unfinished wiring, not architectural deviations. + +--- + +# DOCTRINE — applies to every priority above + +1. **DSL wall → raw CUDA C++, not Python.** Most of the fixes above are +   pre-allocation and shape correctness, not new kernels. 
The two +   exceptions (A2 FP4 indexer KV, A9 two-buffer FMHA) are kernel work and +   must follow doctrine: tcgen05/UMMA/TMA, not scalar. + +2. **Raw CUDA ≠ scalar math.** When A2 lands the FP4 indexer cache, the +   dequant on the read side must use `__constant__` LUT (per the original +   indexer LUT fix from issue #1), not branch arithmetic. + +3. **Print, don't guess.** A1's fix is the canonical example: do not +   assume the indexer key width is `head_dim` or `n_ih * ihd` — print +   `indexer.compressor.forward(...)[0].shape` on first call and code to +   that. The current bug exists *because* someone wrote `head_dim` +   thinking it was right. + +4. **Integration over exploration.** No `KVCache_v2`. Edit `KVCache`. +   The 4 fixes (A1/A3/A4/A6) are surgical edits to one class. + +5. **Falsifiable gates.** Numbers to hit: +   - A1: `comp_idx_buf.shape[1] == indexer.compressor.kv_dim // 2` (or +     whatever the print reveals). Test: run with VERBOSE=2 at 100 tokens +     of context; no shape mismatch, no garbage in top-k selection. +   - A2: KV cache footprint at 1M context (measured via +     `torch.cuda.memory_allocated()` after prefill) drops from ~131 GB to +     ≤ 45 GB (about 39 GB of FP4 payload plus block scales). Recall@1024 vs FP32 indexer oracle ≥ 99.7% per paper. +   - A3: A `--max-context N` flag works; running with `N=1048576` does +     not OOM during prefill (it might be slow — that's a separate fight). +   - A4: `torch.cuda.memory_reserved()` measured every 10 decode steps is +     flat (±50 MB) across 1000 steps. \ No newline at end of file diff --git a/probe_indexer_shapes.py b/probe_indexer_shapes.py new file mode 100644 index 00000000..6bca0299 --- /dev/null +++ b/probe_indexer_shapes.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +"""Probe indexer and compressor weight shapes from the checkpoint. +This tells us the ACTUAL dimensions, not what we assume. 
+Run via: fire_b200_test probe_indexer_shapes.py +""" +import json, sys +from pathlib import Path +from safetensors.torch import load_file + +CHECKPOINT = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" + +def main(): + cdir = Path(CHECKPOINT) + with open(cdir / "config.json") as f: + cfg = json.load(f) + + n_layers = cfg["num_hidden_layers"] + n_ih = cfg.get("index_n_heads", 64) + ihd = cfg.get("index_head_dim", 128) + hd = cfg["head_dim"] + cr = cfg.get("compress_ratios", [128] * n_layers) + + print(f"Config: n_ih={n_ih}, ihd={ihd}, hd={hd}") + print(f"n_ih * ihd = {n_ih * ihd}") + print(f"2 * ihd = {2 * ihd}") + print(f"2 * hd = {2 * hd}") + print(f"Compress ratios: first5={cr[:5]}") + print() + + # Load weight map to find indexer weights + idx_file = cdir / "model.safetensors.index.json" + if idx_file.exists(): + with open(idx_file) as f: + wmap = json.load(f).get("weight_map", {}) + + # Find indexer/compressor weights for layer 2 (first CSA layer) + for li in [0, 1, 2, 3]: + pfx = f"model.layers.{li}.self_attn" + print(f"\n=== Layer {li} (ratio={cr[li] if li < len(cr) else '?'}) ===") + for k in sorted(wmap.keys()): + if k.startswith(pfx) and ('compressor' in k or 'indexer' in k or 'q_b_proj' in k or 'kv_proj' in k or 'gate_proj' in k): + shard = cdir / wmap[k] + print(f" {k} -> shard {wmap[k]}") + else: + print("No index file, loading all weights...") + + # Actually load some weights and print shapes + # Just load the first shard to get shapes + print("\n=== Loading weight shapes ===") + all_w = {} + if idx_file.exists(): + shards = set(wmap.values()) + for sn in sorted(shards): + sf = cdir / sn + if sf.exists(): + w = load_file(str(sf)) + # Only print relevant keys + for k, v in w.items(): + if ('compressor' in k or 'indexer' in k) and 'layers.2' in k: + print(f" {k}: shape={list(v.shape)} dtype={v.dtype}") + del w + + # Also check q_b_proj for layer 2 + print("\n=== Layer 2 attention projection shapes ===") + for sn in sorted(shards): + sf = cdir / sn + if 
sf.exists(): + w = load_file(str(sf)) + for k, v in w.items(): + if 'layers.2.self_attn' in k and ('q_b' in k or 'kv_proj' in k or 'gate_proj' in k): + print(f" {k}: shape={list(v.shape)} dtype={v.dtype}") + del w + +if __name__ == "__main__": + main()