nvfp4-megamoe-kernel/archived_plans/CONSULTANT_RECCOMMENDATION_DOC.md
biondizzle d772885d7e single_shot_inference: proper mHC+RMSNorm+inverse RoPE pipeline
Major rewrite of single_shot_inference.py:
- Replace broken mHC (gentle normalization hack) with proper Sinkhorn-Knopp
- Add RMSNorm before each sub-block (attention + FFN)
- Add inverse RoPE on attention output (paper §2.3.3)
- Fix KV cache: RoPE applied before caching, K=V in DSV4 MQA
- Fix MoE: proper dense routing with e_bias, SwiGLU clamping
- Proper weight mapping: fn→W_stacked, base→S_pre/S_res/S_post, scale→alphas
- Add identity mHC fallback when weights missing
- No emergency normalization, no bandaids
2026-05-31 02:45:52 +00:00

# SINGLE-SHOT INFERENCE BASELINE — design and recommendation
**Use case (verbatim from you).** Deterministic prompt "the capital of France is",
greedy decode, expect "Paris" back. If we get Paris, we've ruled out kernel
correctness as a source of bugs when we later do the official vLLM integration.
The script also doubles as **the integration reference** for any inference
engine — vLLM, SGLang, custom, doesn't matter.
---
## TL;DR — recommendation
**Standalone Python script. Do not fork tiny-vllm.**
Put it at `scripts/single_shot.py` in the kernel repo. Have it do the full
orchestration itself (embedding → N layers → final norm → head → argmax). Make
it the *reference implementation* of how to drive the kernel, not a wrapper
around someone else's engine. Length target: 300 to 500 lines, one file, no
class hierarchy.
Two reasons. The first one is architectural; the second is doctrine.
### Reason 1 — tiny-vllm is built around Llama, not DSV4
tiny-vllm is a teaching repo for a **Llama 3.2 1B** inference engine. Its
README is explicit: "load a real LLM model from Safetensors (Llama 3.2 1B
Instruct), full LLM forward pass, KV cache, static batching, continuous
batching, PagedAttention." Its `python/reference.py` calls
`AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")`
and walks Llama's `model.layers[0].input_layernorm`, `q_proj`, etc.
Everything about its KV cache, attention shape, layer structure, and weight
naming assumes Llama. DSV4 violates basically all of those assumptions:
- **KV cache** — DSV4 has a *heterogeneous* cache (paged FP8 + state cache for
SWA tail + separate FP4 indexer pool). tiny-vllm's PagedAttention is uniform.
- **Attention** — DSV4 has hybrid CSA/HCA/SWA per-layer + lightning indexer
top-k + grouped output projection + attention sink. tiny-vllm has GQA.
- **Layer structure** — DSV4 has mHC (Sinkhorn-projected residuals,
`n_hc=4`, not plain `x + sublayer(x)`) and MoE with `sqrt(softplus)` routing,
hash routing on the first 3 layers, plus a shared expert. tiny-vllm is dense
MLP, plain residual.
- **Weight naming** — DSV4's HF checkpoint has its own naming convention
(`q_down`, `q_up`, `kv_down`, `indexer_q_up`, `indexer_head_weights`,
`wo_a` / `wo_b`, mHC `A`/`B`/`C` raw params, MoE `gate`/`up`/`down` per expert).
tiny-vllm's loader expects Llama keys.
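The layer-structure bullet mentions a Sinkhorn projection. Sinkhorn-Knopp takes a matrix of positive entries and makes it (approximately) doubly stochastic by alternately normalizing its rows and columns. The sketch below runs it in plain Python on an `n_hc × n_hc` matrix, here 4×4. Three details are illustrative assumptions, not DSV4's exact parameterization: exponentiating raw logits first, the iteration count, and the name `sinkhorn_knopp`.

```python
import math
from typing import List

Matrix = List[List[float]]

def sinkhorn_knopp(logits: Matrix, n_iters: int = 20) -> Matrix:
    """Alternately normalize rows and columns of exp(logits) toward a doubly stochastic matrix."""
    n = len(logits)
    # Subtracting the global max before exp() avoids overflow; the constant
    # factor it introduces is removed by the first normalization.
    peak = max(max(row) for row in logits)
    m = [[math.exp(v - peak) for v in row] for row in logits]
    for _ in range(n_iters):
        # Row step: every row sums to 1.
        row_sums = [sum(row) for row in m]
        m = [[v / s for v in row] for row, s in zip(m, row_sums)]
        # Column step: every column sums to 1. Rows drift slightly; the next row step corrects them.
        col_sums = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return m
```

Suppose the result is applied as `y = M @ x` over the four residual streams. Row sums of 1 make each output stream a convex combination of the inputs. Column sums of 1 preserve the total across streams.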
You'd be replacing 80%+ of tiny-vllm to get DSV4 to run in it. At that point
it's not "using tiny-vllm" — it's "rewriting tiny-vllm into DSV4," and the
"vllm" in the name becomes misleading. The fact that tiny-vllm *exists* doesn't
make it the right starting point.
### Reason 2 — a reference implementation should be minimal, not feature-rich
You said "use it like a reference for integration into any inference engine."
That sentence is doing a lot of work and it's worth being literal about it.
A reference implementation has one job: **show, in the simplest possible code,
what the model does so that a non-trivial inference engine knows what it needs
to wire up.** Anything beyond that is an engine-specific concern that the
integrator will replace anyway.
What tiny-vllm has that a reference does *not* need: static batching,
continuous batching, scheduler, PagedAttention, multi-request serving,
continuous decode loop. **All of those are exactly the things that differ
between vLLM and SGLang and your-future-engine.** Bundling them into the
reference makes the reference less useful, not more.
What a reference *does* need: model construction, weight loading, one forward
pass per token, simple flat KV growth, deterministic argmax. That's a single
script. It already exists as a known pattern — HF's `model.generate(do_sample=False)`
is the same shape, minus the kernel control.
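Here is a toy version of that loop, to make "one forward pass per token, flat KV growth, argmax" concrete. `step_fn` is a hypothetical stand-in for the model. Given the new token ids and the cache so far, it returns last-position logits plus the cache entries it produced. The prompt is prefilled in one call, and each later call feeds exactly one token.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Cache = List[Tuple[float, ...]]  # toy stand-in: one flat entry per processed token
StepFn = Callable[[Sequence[int], Cache], Tuple[List[float], Cache]]

def greedy_decode(step_fn: StepFn, prompt_ids: Sequence[int], n_new: int,
                  eos_id: Optional[int] = None) -> List[int]:
    """Prefill once, then one forward pass per generated token, always taking the argmax."""
    # Prefill: the whole prompt goes through the model in one call.
    logits, new_entries = step_fn(prompt_ids, [])
    cache: Cache = list(new_entries)
    out: List[int] = []
    for step in range(n_new):
        next_tok = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        out.append(next_tok)
        if next_tok == eos_id or step == n_new - 1:
            break  # no need to run the model again
        # Decode: exactly one new token per call; the cache grows by concatenation.
        logits, new_entries = step_fn([next_tok], cache)
        cache = cache + new_entries
    return out
```

The loop stops before the final, unused forward pass. That keeps the number of model calls equal to one prefill plus `n_new - 1` single-token steps.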
### What this means for your existing repo
The script is *both* the inference baseline *and* the work item E3
("`dsv4/model/dsv4.py` is a 2-line stub — build the actual model class") from
the Stage E roadmap. Collapse those into one artifact. The script imports
`DSV4Layer` from `dsv4/model/layer.py`, drives it explicitly, and serves as
the executable spec for what `DSV4Model.forward` should later become.
When you eventually integrate into vLLM, you extract the model orchestration
(layer loop, RoPE caches, KV growth) into a `DSV4Model` class and the script
shrinks to "construct DSV4Model, generate, print." The script *itself*
remains, forever, as the minimal reproducer for "if we get Paris, kernels are
fine."
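Here is a framework-free sketch of the orchestration order the script owns: embedding → N layers → final norm → head → argmax. It is a toy forward pass over Python lists. `embed_table`, `layers`, `final_norm`, and `head` are hypothetical stand-ins, not the `dsv4` API. Each layer here is any callable from hidden states to hidden states, whereas a real layer would also take token ids and a cache handle, as `DSV4Layer` does.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[List[Vector]], List[Vector]]

def forward_last_logits(token_ids: Sequence[int], embed_table: Sequence[Vector],
                        layers: Sequence[Layer], final_norm: Callable[[Vector], Vector],
                        head: Sequence[Vector]) -> Vector:
    """Embedding -> N layers -> final norm -> head; returns logits for the last position."""
    # Embedding lookup: one hidden vector per input token.
    hidden = [list(embed_table[t]) for t in token_ids]
    # Layer loop: each layer maps the whole sequence of hidden states to a new one.
    for layer in layers:
        hidden = layer(hidden)
    # Next-token prediction only needs the last position.
    last = final_norm(hidden[-1])
    # Prediction head: one dot product per vocabulary row.
    return [sum(w * x for w, x in zip(row, last)) for row in head]

def greedy_next_token(logits: Vector) -> int:
    # argmax; max() keeps the first maximum it sees, so ties go to the lowest token id.
    return max(range(len(logits)), key=logits.__getitem__)
```

Everything that later moves into `DSV4Model.forward` sits inside `forward_last_logits`. `greedy_next_token` stays in the script.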
---
## Two-tier strategy
The point of the script is **to rule out kernel issues.** That's only useful
if the script can actually fail in a way that points at the kernel and not at
some unrelated unfinished thing. Today, several DSV4 components are still
incomplete (the E1 cache gather has not landed, and `loader/hf_checkpoint.py` and
`model/dsv4.py` are 2-line stubs). If we wait for everything, the script
ships in 2 weeks. If we ship in two tiers, we get a useful signal *today*.
### Tier 1 — `single_shot_baseline.py` — BF16 everything, simple cache
Goal: produce "Paris" with the maximum possible signal-to-noise ratio. **No
quantization. No paging. No FP8.** Just the math.
- Load weights from HF safetensors, dequantize FP4 → BF16 at load time. (The
reference checkpoint stores FP4 MoE expert weights; cast them up. Linear
weights are already BF16.)
- KV cache: **one flat 4D tensor per layer**, grown by concatenation. No paged
pool, no state cache, no FP8 round-trip. Direct BF16 throughout.
- Compressor: produce dense compressed BF16 entries directly, store in the
flat cache. No FP8 quant on the cache write path.
- Indexer: compute index scores in FP32 (this matches the *post-LUT-fix*
scalar path you already have, so we know it's correct and we know it's not
fused-FP4 yet). Take the top-k indices. The lookup is a plain `cache[..., topk]`.
- FMHA: call the production multi-tile kernel via
`dsv4_attention(q, k_gathered, v_gathered, ...)`. The dense KV is already
materialized.
- Sink merge: handled by the FMHA kernel (D5c, already done).
- Inverse RoPE → wo_a → wo_b. mHC residual. FFN sub-block. Repeat for N layers.
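The tier-1 cache and indexer lookup above come down to three operations:

1. Append new entries by concatenation.
2. Pick the top-k positions by index score.
3. Gather those positions.

The plain-Python sketch below keeps one immutable `NamedTuple` per layer. The names `FlatLayerCache`, `append_kv`, `topk_positions`, and `gather` are hypothetical. Real entries would be BF16 tensor rows rather than tuples of floats. In torch, the gather step is ordinary advanced indexing along the position dimension.

```python
import heapq
from typing import List, NamedTuple, Sequence, Tuple

Entry = Tuple[float, ...]

class FlatLayerCache(NamedTuple):
    """Toy tier-1 cache for one layer: entries in position order, no paging, no FP8."""
    keys: Tuple[Entry, ...] = ()
    values: Tuple[Entry, ...] = ()

def append_kv(cache: FlatLayerCache, k: Sequence[Sequence[float]],
              v: Sequence[Sequence[float]]) -> FlatLayerCache:
    # Growth by concatenation: the new cache is the old one plus the new rows.
    return FlatLayerCache(cache.keys + tuple(tuple(r) for r in k),
                          cache.values + tuple(tuple(r) for r in v))

def topk_positions(scores: Sequence[float], k: int) -> List[int]:
    # Highest index scores win; ties go to the earlier position so the choice is deterministic.
    best = heapq.nlargest(k, range(len(scores)), key=lambda i: (scores[i], -i))
    return sorted(best)  # return in cache-position order

def gather(cache: FlatLayerCache, positions: Sequence[int]) -> FlatLayerCache:
    # Plain equivalent of cache[..., topk]: pick whole entries by position.
    return FlatLayerCache(tuple(cache.keys[p] for p in positions),
                          tuple(cache.values[p] for p in positions))
```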
**What this proves.** If Paris comes back here, the *architecture* is correct
and the *FMHA core kernel* is correct. That is the highest-leverage validation
you can do today, with zero dependence on E1.
**What this does not prove.** Anything about the FP8 paged cache, anything
about the FP4 fused SwiGLU MoE epilogue, anything about NVFP4 weight quant on
linears. Those are explicitly *not* in scope for tier 1 — separating them is
the whole point.
### Tier 2 — `single_shot_production.py` — full production path
Same script, but routes through the actual production cache + the NVFP4
quantized weights + fused FP4 SwiGLU MoE + (eventually) E7's FP4 tensor-core
indexer. This is the script that proves the *full integrated path* works, and
it depends on E1 (cache gather kernels) landing first.
Tier 2's job: produce the *same* generated tokens as tier 1, deterministically
(logits will not match bit for bit once weights and cache are quantized). If they
diverge, the diff is exactly the difference between BF16 ref and the
quantized production path. That diff is the signal you actually want when
debugging vLLM integration.
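The tier-2 check is on decoded output, so the first thing worth reporting is the decode step where the two token streams part ways. Here is a minimal helper. It assumes both scripts can emit their generated token ids as plain lists, and `first_divergence` is a hypothetical name.

```python
from typing import Optional, Sequence

def first_divergence(ref_ids: Sequence[int], prod_ids: Sequence[int]) -> Optional[int]:
    """Return the first position where the tier-1 and tier-2 token streams differ, or None."""
    for pos, (a, b) in enumerate(zip(ref_ids, prod_ids)):
        if a != b:
            return pos
    # Same shared prefix: a length mismatch still counts as a divergence at the shorter length.
    if len(ref_ids) != len(prod_ids):
        return min(len(ref_ids), len(prod_ids))
    return None
```

The returned index tells you which decode step to re-run with per-tensor dumps enabled.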
---
## Tier 1 script — concrete specification
This is the spec, not the code — but the spec is precise enough that an
agent can implement it without guessing.
### File structure
One file: `scripts/single_shot_baseline.py`. Top-to-bottom: imports, config
load, tokenizer, weight load, model orchestration helpers, the decode loop,
`if __name__ == "__main__"`. **No class definitions** (use functions and
named tuples where state needs to flow). The whole point is readability as a
reference.
### What it imports from `dsv4/`
- `dsv4.model.config.DSV4Config` — already exists and matches the paper.
- `dsv4.model.layer.DSV4Layer` — already exists, takes
`(X, token_ids, cache: LayerCacheHandle) → X`.
- `dsv4.layers.attention.AttentionSubBlock` (used internally by `DSV4Layer`).
- `dsv4.layers.ffn.FFNSubBlock`.
- `dsv4.layers.mhc.MHC` (already wired into `DSV4Layer`).
- `dsv4.kernels.attention.production.dsv4_attention` — the production FMHA
entry point.
- `dsv4.kernels.compressor.csa_compress_and_store` /
`dsv4.kernels.compressor.hca_compress_and_store`.
- `dsv4.kernels.indexer.compute_index_scores_topk`.
- `dsv4.ops.rope.apply_rope_bf16`, `dsv4.ops.rope.inverse_rope_bf16`.
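`apply_rope_bf16` and `inverse_rope_bf16` appear above without their math. This sketch shows the idea on one vector in plain Python. RoPE rotates coordinate pairs by a position-dependent angle. The inverse rotates by the negated angle at the same position, so it undoes the forward rotation up to rounding.

Two choices here are assumptions: interleaved pairing `(x[2i], x[2i+1])` and `base=10000`. DSV4 may use a different pairing or rotate only a slice of the head dimension, and the real ops work on BF16 tensors.

```python
import math
from typing import List, Sequence

def rope(x: Sequence[float], pos: int, base: float = 10000.0,
         inverse: bool = False) -> List[float]:
    """Rotate pairs (x[2i], x[2i+1]) by pos * base**(-2i/d); inverse=True rotates back."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even number of dimensions")
    sign = -1.0 if inverse else 1.0
    out = list(x)
    for i in range(d // 2):
        angle = sign * pos * base ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        # Standard 2-D rotation of the pair by `angle`.
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out
```

Each pair is rotated by a plain 2-D rotation. As a result, the dot product of a query rotated at position m with a key rotated at position n depends only on n - m, which is what makes RoPE a relative encoding.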
### What it has to implement itself (because the codebase doesn't have it yet)
These three are missing or stubbed. The script ships them inline; they get
promoted into `dsv4/` proper as part of E1 through E3.
1. **`load_dsv4_weights(model, hf_path)`** — HF safetensors loader. Walk
`model.safetensors.index.json`, map HF key names to `DSV4Layer` parameter
names, copy in. For FP4 weights (MoE experts), dequantize at load. This is
what `dsv4/loader/hf_checkpoint.py` should eventually be — keep the
function signature small so it can be promoted as-is.
2. **`SimpleLayerCache`** — a `LayerCacheHandle`-shaped object that uses flat
BF16 tensors instead of paged FP8. Implements the same surface
`dsv4/kernels/attention/__init__.py` calls
(`gather_compressed_kv`, `gather_all_compressed_kv`, `gather_swa_kv`,
`num_query_heads`, `head_dim`, `positions`, `request_slots`,
`read_classical_view`, `read_swa_view`, `read_indexer_view`, `write_swa`,
`flush_compression`). Eager torch ops are fine here — this is the
**reference**, not the fast path. **Each method gets one comment line
pointing at the production E1 kernel that replaces it.**
3. **The model loop.** Embedding → for layer in layers: `layer.forward(X, ids, cache)` →
final norm → prediction head → argmax. Each layer gets its own cache handle from
the simple cache manager. This is the body of what `DSV4Model.forward` becomes
in E3.
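Before touching any tensor data, `load_dsv4_weights` has to read the shard index and decide where each key goes. The sketch below covers those two steps, plus an FP4 decode helper for the dequantize-at-load step. The `weight_map` layout of `model.safetensors.index.json` is the standard Hugging Face sharded format. The E2M1 magnitude table is the OCP MX FP4 one.

The following are hypothetical or assumed:

- `NAME_RULES`, `load_plan`, and `dequant_fp4` are illustrative names.
- The low-nibble-first order is an assumption.
- The single `scale` argument stands in for whatever block-scale layout the checkpoint actually uses.

Actual reads could then go shard by shard through `safetensors.safe_open(path, framework="pt")` and `get_tensor(key)`.

```python
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

# Hypothetical HF-name -> local-name prefix rules; the real table depends on the checkpoint.
NAME_RULES: Sequence[Tuple[str, str]] = (
    ("model.layers.", "layers."),
    ("model.embed_tokens.", "embed."),
)

def remap(hf_key: str) -> str:
    """Apply the first matching prefix rule; unmatched keys pass through unchanged."""
    for hf_prefix, local_prefix in NAME_RULES:
        if hf_key.startswith(hf_prefix):
            return local_prefix + hf_key[len(hf_prefix):]
    return hf_key

def load_plan(hf_path: str) -> Dict[str, List[Tuple[str, str]]]:
    """Group (hf_key, local_name) pairs by shard so each shard file is opened once."""
    index = json.loads((Path(hf_path) / "model.safetensors.index.json").read_text())
    plan: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for hf_key, shard in sorted(index["weight_map"].items()):
        plan[shard].append((hf_key, remap(hf_key)))
    return dict(plan)

# OCP MX FP4 (E2M1) magnitudes for codes 0..7; bit 3 of each nibble is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)

def dequant_fp4(packed: bytes, scale: float) -> List[float]:
    """Decode two E2M1 codes per byte (low nibble first: an assumption) and apply one scale."""
    out: List[float] = []
    for byte in packed:
        for code in (byte & 0x0F, byte >> 4):
            magnitude = E2M1_MAGNITUDES[code & 0x7]
            out.append((-magnitude if code & 0x8 else magnitude) * scale)
    return out
```

Unmatched keys pass through unchanged. They therefore surface as unknown parameter names instead of being silently dropped.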
### Numerics requirements
- All math in BF16 except: softmax accumulator (FP32, already in the kernel),
RMSNorm reduction (FP32, already in the kernel), Sinkhorn iterations in mHC
(FP32, already in the layer), indexer scores (FP32, current scalar path).
- Determinism: torch seed fixed, `torch.use_deterministic_algorithms(True)` (on CUDA
this also requires `CUBLAS_WORKSPACE_CONFIG=:4096:8` in the environment), no cuDNN
benchmark, single CUDA stream. Decode is single-token so there's no batching
nondeterminism to worry about.
- Greedy: `next_tok = logits[-1].argmax()`. No temperature, no top-p, no
sampling.
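Why do the reductions stay in FP32 while everything else is BF16? BF16 keeps only 8 significant bits, so a running sum stops growing once each increment falls below half a unit in the last place.

The sketch below emulates BF16 rounding in pure Python, using round-to-nearest-even on the FP32 bit pattern. It compares an RMSNorm whose sum of squares is accumulated in BF16 against one accumulated at full precision. Python floats are FP64, so the "good" path here is wider than the kernel's FP32. `to_bf16`, `bf16_sum`, and `rms_norm` are illustrative names.

```python
import math
import struct
from typing import List, Sequence

def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value (round-to-nearest-even; NaN/inf not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # nearest-even rounding of the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def bf16_sum(values: Sequence[float]) -> float:
    # Running sum whose accumulator is rounded to BF16 after every add.
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)
    return acc

def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6,
             bf16_accumulate: bool = False) -> List[float]:
    """y = x / sqrt(mean(x^2) + eps) * weight, with BF16-rounded outputs."""
    squares = [v * v for v in x]
    # Default path reduces at FP64 (wider than an FP32 kernel); the other path mimics a BF16 accumulator.
    total = bf16_sum(squares) if bf16_accumulate else math.fsum(squares)
    inv_rms = 1.0 / math.sqrt(total / len(x) + eps)
    return [to_bf16(v * inv_rms * w) for v, w in zip(x, weight)]
```

With 300 ones, the BF16 accumulator saturates at 256. As a result, `rms_norm([1.0] * 300, [1.0] * 300, bf16_accumulate=True)` returns values near 1.086 instead of 1.0.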
### Pass/fail gate
```python
# `tokenizer` and `decode_n` are the script's own helpers, defined earlier.
prompt = "The capital of France is"
expected = "Paris"
# Decode 8 tokens (one is enough to see "Paris", but a few more catch
# off-by-one tokenization issues).
generated_ids = decode_n(prompt, n_new=8)
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
assert expected in generated_text, f"FAIL: got {generated_text!r}"
print(f"PASS: {prompt!r} -> {generated_text!r}")
```
Run it. If it prints PASS, the architecture and the FMHA core kernel are
known good. If it crashes before reaching the gate, the stack trace points at
the broken layer. If the assertion fails with wrong text, the diff against the
tier-0 HF reference (below) localizes which sub-block diverges.
### Recommended companion — `tier_0_hf_reference.py`
For debugging tier 1 when it produces wrong text instead of erroring, you
need a known-good reference to diff against. Write a 30-line script that
loads the HF model with `AutoModelForCausalLM.from_pretrained(..., dtype=torch.bfloat16)`,
greedily generates 8 tokens for the same prompt, and dumps:
- `embeds[0, t, :10]` per token
- `layer[0].input_layernorm(embeds)[0, t, :10]` per token
- `layer[0].attn.q_proj(normed)[0, t, :10]` per token
- ... continuing through the layer
This is exactly what tiny-vllm's `python/reference.py` does for Llama. We're
borrowing the pattern, not the code. When tier 1 diverges from this, you
have a per-tensor diff to localize the failure to a specific sub-block.
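The per-tensor diff itself can be tiny. The sketch assumes each script dumps an ordered list of `(name, values)` pairs with matching names, for example the `[0, t, :10]` slices listed above. `first_divergent_tensor` is a hypothetical helper. The default tolerances are a guess sized for BF16, whose spacing near 1.0 is 2^-7; they are not a project setting.

```python
from typing import Optional, Sequence, Tuple

Dump = Sequence[Tuple[str, Sequence[float]]]  # (tensor name, sampled values) in execution order

def first_divergent_tensor(ref: Dump, test: Dump, atol: float = 1e-2,
                           rtol: float = 1.6e-2) -> Optional[Tuple[str, float]]:
    """Return (name, max_abs_err) of the first tensor outside tolerance, or None if all match."""
    for (ref_name, ref_vals), (test_name, test_vals) in zip(ref, test):
        if ref_name != test_name:
            raise ValueError(f"dump order mismatch: {ref_name!r} vs {test_name!r}")
        if len(ref_vals) != len(test_vals):
            return ref_name, float("inf")
        worst, bad = 0.0, False
        for r, t in zip(ref_vals, test_vals):
            err = abs(r - t)
            worst = max(worst, err)
            # Mixed tolerance: absolute floor plus a term relative to the reference value.
            if err > atol + rtol * abs(r):
                bad = True
        if bad:
            return ref_name, worst
    return None
```

The dumps are walked in execution order. The returned name is therefore the first sub-block whose output disagrees.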
---
## What to put in `scripts/` and what's in scope
```
scripts/
├── single_shot_baseline.py # Tier 1: BF16 everything, simple cache
├── single_shot_production.py # Tier 2: full production path (depends on E1)
├── tier_0_hf_reference.py # HF eager reference for per-tensor diff
└── README.md # 1 page: how to run, what each script proves
```
Keep all three scripts small. The reference is single-shot. It's not a server, it's
not a benchmark, it's not a test harness. If those things end up wanted,
they live elsewhere.
---
## DOCTRINE NOTES — applies here too
1. **DSL wall → raw CUDA C++, not Python.** Doesn't apply to this script —
the script *is* the Python orchestration layer. But if a sub-call fails
inside a kernel and the agent's instinct is "let me reimplement that part
in Python to make the script work," the answer is no. Fix the kernel.
2. **Raw CUDA ≠ scalar math.** Doesn't apply directly, but a related
corollary: this script is allowed to use eager torch ops in
`SimpleLayerCache` because it is explicitly labeled as the *reference,
not the fast path.* When E1 lands, the simple cache is replaced wholesale
by the production cache. The reference's job is to be obviously correct,
not fast.
3. **Print, don't guess.** If the script fails:
- First step: run `tier_0_hf_reference.py` on the same prompt, dump the
per-tensor expected values.
- Second step: instrument the tier 1 script to print the *same* tensors
at the same points.
- Third step: diff. The first divergent tensor is the broken sub-block.
- Do **not** start "fixing things to see if it helps" before doing the
diff.
4. **Integration over exploration.** The script lives in `scripts/`, not in
`dsv4/`. It is a *consumer* of the library, not part of it. When functions
in the script are promoted into `dsv4/` (the loader, the model class,
etc.), the script should *get shorter*, not longer. If the script is
growing, that's a smell that something belongs in the library.
5. **Falsifiable gate.** The gate is `assert "Paris" in generated_text`. Not
"looks reasonable," not "produces something." Paris or fail.