# SINGLE-SHOT INFERENCE BASELINE — design and recommendation

**Use case (verbatim from you).** Deterministic prompt "the capital of France is",
greedy decode, expect "Paris" back. If we get Paris, we've ruled out kernel
correctness as a source of bugs when we later do the official vLLM integration.
The script also doubles as **the integration reference** for any inference
engine — vLLM, SGLang, custom, doesn't matter.

---

## TL;DR — recommendation

**Standalone Python script. Do not fork tiny-vllm.**

Put it at `scripts/single_shot.py` in the kernel repo. Have it do the full
orchestration itself (embedding → N layers → final norm → head → argmax). Make
it the *reference implementation* of how to drive the kernel, not a wrapper
around someone else's engine. Length target: 300–500 lines, one file, no
class hierarchy.

Two reasons. The first is architectural; the second is doctrine.

### Reason 1 — tiny-vllm is built around Llama, not DSV4

tiny-vllm is a teaching repo for a **Llama 3.2 1B** inference engine. Its
README is explicit: "load a real LLM model from Safetensors (Llama 3.2 1B
Instruct), full LLM forward pass, KV cache, static batching, continuous
batching, PagedAttention." Its `python/reference.py` calls
`AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")`
and walks Llama's `model.layers[0].input_layernorm`, `q_proj`, etc.

Everything about its KV cache, attention shape, layer structure, and weight
naming assumes Llama. DSV4 violates basically all of those assumptions:

- **KV cache** — DSV4 has a *heterogeneous* cache (paged FP8 + state cache for
  SWA tail + separate FP4 indexer pool). tiny-vllm's PagedAttention is uniform.
- **Attention** — DSV4 has hybrid CSA/HCA/SWA per-layer + lightning indexer
  top-k + grouped output projection + attention sink. tiny-vllm has GQA.
- **Layer structure** — DSV4 has mHC (Sinkhorn-projected residuals,
  `n_hc=4`, not plain `x + sublayer(x)`) and MoE with `sqrt(softplus)` routing,
  hash routing on the first 3 layers, plus a shared expert. tiny-vllm is dense
  MLP, plain residual.
- **Weight naming** — DSV4's HF checkpoint has its own naming convention
  (`q_down`, `q_up`, `kv_down`, `indexer_q_up`, `indexer_head_weights`,
  `wo_a` / `wo_b`, mHC `A`/`B`/`C` raw params, MoE `gate`/`up`/`down` per expert).
  tiny-vllm's loader expects Llama keys.

You'd be replacing 80%+ of tiny-vllm to get DSV4 to run in it. At that point
it's not "using tiny-vllm" — it's "rewriting tiny-vllm into DSV4," and the
"vllm" in the name becomes misleading. The fact that tiny-vllm *exists* doesn't
make it the right starting point.

### Reason 2 — a reference implementation should be minimal, not feature-rich

You said "use it like a reference for integration into any inference engine."
That sentence is doing a lot of work and it's worth being literal about it.

A reference implementation has one job: **show, in the simplest possible code,
what the model does so that a non-trivial inference engine knows what it needs
to wire up.** Anything beyond that is engine-specific concerns that the
integrator will replace anyway.

What tiny-vllm has that a reference does *not* need: static batching,
continuous batching, scheduler, PagedAttention, multi-request serving,
continuous decode loop. **All of those are exactly the things that differ
between vLLM and SGLang and your-future-engine.** Bundling them into the
reference makes the reference less useful, not more.

What a reference *does* need: model construction, weight loading, one forward
pass per token, simple flat KV growth, deterministic argmax. That's a single
script. It already exists as a known pattern — HF's `model.generate(do_sample=False)`
is the same shape, minus the kernel control.

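The shape described above (one forward pass per token, flat KV growth, deterministic argmax) can be sketched in a few lines. This is a toy under stated assumptions: `greedy_decode`, `ForwardFn`, and `toy_forward` are hypothetical names, the "model" is a stand-in function, and plain Python lists stand in for tensors so the sketch runs without torch.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for one model forward pass: it receives the newest
# token and the cache so far, and returns (logits, new_cache_entry).
ForwardFn = Callable[[int, List[int]], Tuple[List[float], int]]


def greedy_decode(forward: ForwardFn, prompt_ids: List[int], n_new: int) -> List[int]:
    """Greedy decode with a flat cache that grows by one entry per token.

    Assumes prompt_ids is non-empty.
    """
    cache: List[int] = []
    logits: List[float] = []
    # Prefill: feed the prompt one token at a time, growing the cache.
    for tok in prompt_ids:
        logits, entry = forward(tok, cache)
        cache.append(entry)
    generated: List[int] = []
    for _ in range(n_new):
        # Deterministic argmax: ties resolve to the lowest token id.
        next_tok = max(range(len(logits)), key=lambda i: (logits[i], -i))
        generated.append(next_tok)
        logits, entry = forward(next_tok, cache)
        cache.append(entry)
    return generated


def toy_forward(tok: int, cache: List[int]) -> Tuple[List[float], int]:
    # Toy model over a 5-token vocabulary: always predicts (tok + 1) % 5.
    logits = [0.0] * 5
    logits[(tok + 1) % 5] = 1.0
    return logits, tok


print(greedy_decode(toy_forward, [0, 1], n_new=3))  # [2, 3, 4]
```

In the real script the stand-in `forward` would be the layer loop, and the cache entry would be the per-layer K/V appended to a flat tensor.
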
### What this means for your existing repo

The script is *both* the inference baseline *and* the work item E3
("`dsv4/model/dsv4.py` is a 2-line stub — build the actual model class") from
the Stage E roadmap. Collapse those into one artifact. The script imports
`DSV4Layer` from `dsv4/model/layer.py`, drives it explicitly, and serves as
the executable spec for what `DSV4Model.forward` should later become.

When you eventually integrate into vLLM, you extract the model orchestration
(layer loop, RoPE caches, KV growth) into a `DSV4Model` class and the script
shrinks to "construct DSV4Model, generate, print." The script *itself*
remains, forever, as the minimal reproducer for "if we get Paris, kernels are
fine."

---

## Two-tier strategy

The point of the script is **to rule out kernel issues.** That's only useful
if the script can actually fail in a way that points at the kernel and not at
some unrelated unfinished thing. Today, several DSV4 components are still
incomplete (E1 cache gather, `loader/hf_checkpoint.py` is a 2-line stub,
`model/dsv4.py` is a 2-line stub). If we wait for everything, the script
ships in 2 weeks. If we ship in two tiers, we get a useful signal *today*.

### Tier 1 — `single_shot_baseline.py` — BF16 everything, simple cache

Goal: produce "Paris" with the maximum possible signal-to-noise ratio. **No
quantization. No paging. No FP8.** Just the math.

- Load weights from HF safetensors, dequantize FP4 → BF16 at load time. (The
  reference checkpoint stores FP4 MoE expert weights; cast them up. Linear
  weights are already BF16.)
- KV cache: **one flat 4D tensor per layer**, grown by concatenation. No paged
  pool, no state cache, no FP8 round-trip. Direct BF16 throughout.
- Compressor: produce dense compressed BF16 entries directly, store in the
  flat cache. No FP8 quant on the cache write path.
- Indexer: compute index scores in FP32 (this matches the *post-LUT-fix*
  scalar path you already have, so we know it's correct and we know it's not
  fused-FP4 yet). Take the top-k indices. The lookup is a plain `cache[..., topk]`.
- FMHA: call the production multi-tile kernel via
  `dsv4_attention(q, k_gathered, v_gathered, ...)`. The dense KV is already
  materialized.
- Sink merge: handled by the FMHA kernel (D5c, already done).
- Inverse RoPE → wo_a → wo_b. mHC residual. FFN sub-block. Repeat for N layers.

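The indexer step in the list above (score, take top-k, gather from the flat cache) can be illustrated with a small sketch. Everything here is an assumption for illustration: `topk_indices` and `gather` are hypothetical helper names, the cache is a list of per-position entries rather than a 4D tensor, and returning the selected indices in position order is a choice made for this sketch, not something the spec states.

```python
from typing import List, Sequence


def topk_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest scores.

    Ties are broken toward the lower index so the result is deterministic;
    the selected indices are returned in ascending position order.
    """
    k = min(k, len(scores))
    best = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]
    return sorted(best)


def gather(cache: Sequence[List[float]], indices: Sequence[int]) -> List[List[float]]:
    """Plain lookup of the selected cache entries (no paging, no quantization)."""
    return [cache[i] for i in indices]


# Four cached positions, two features each (made-up values).
cache = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
scores = [0.2, 0.9, 0.1, 0.9]

idx = topk_indices(scores, k=2)
print(idx)                 # [1, 3]
print(gather(cache, idx))  # [[1.0, 1.1], [3.0, 3.1]]
```

The deterministic tie-break matters for the gate: two runs that see equal scores should select the same entries.
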
**What this proves.** If Paris comes back here, the *architecture* is correct
and the *FMHA core kernel* is correct. That is the highest-leverage validation
you can do today, with zero dependence on E1.

**What this does not prove.** Anything about the FP8 paged cache, anything
about the FP4 fused SwiGLU MoE epilogue, anything about NVFP4 weight quant on
linears. Those are explicitly *not* in scope for tier 1 — separating them is
the whole point.

### Tier 2 — `single_shot_production.py` — full production path

Same script, but routes through the actual production cache + the NVFP4
quantized weights + fused FP4 SwiGLU MoE + (eventually) E7's FP4 tensor-core
indexer. This is the script that proves the *full integrated path* works, and
it depends on E1 (cache gather kernels) landing first.

Tier 2's job: produce the *same* output as tier 1, deterministically. If they
diverge, the diff is exactly the difference between BF16 ref and the
quantized production path. That diff is the signal you actually want when
debugging vLLM integration.

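A minimal way to check "same output as tier 1" is to compare the two greedy token sequences and report the first step where they differ. The helper name `first_divergence` and the token ids below are made up for illustration.

```python
from typing import Optional, Sequence


def first_divergence(ref_ids: Sequence[int], test_ids: Sequence[int]) -> Optional[int]:
    """Return the first decode step at which the sequences differ, or None if identical."""
    for step, (a, b) in enumerate(zip(ref_ids, test_ids)):
        if a != b:
            return step
    if len(ref_ids) != len(test_ids):
        # One sequence is a strict prefix of the other.
        return min(len(ref_ids), len(test_ids))
    return None


tier1_ids = [101, 7, 42, 9]  # made-up token ids from the BF16 baseline
tier2_ids = [101, 7, 43, 9]  # made-up token ids from the production path
print(first_divergence(tier1_ids, tier2_ids))  # 2
```
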
---
## Tier 1 script — concrete specification

This is the spec, not the code — but the spec is precise enough that an
agent can implement it without guessing.

### File structure

One file: `scripts/single_shot_baseline.py`. Top-to-bottom: imports, config
load, tokenizer, weight load, model orchestration helpers, the decode loop,
`if __name__ == "__main__"`. **No class definitions** (use functions and
named tuples where state needs to flow). The whole point is readability as a
reference.

### What it imports from `dsv4/`

- `dsv4.model.config.DSV4Config` — already exists and matches the paper.
- `dsv4.model.layer.DSV4Layer` — already exists, takes
  `(X, token_ids, cache: LayerCacheHandle) → X`.
- `dsv4.layers.attention.AttentionSubBlock` (used internally by `DSV4Layer`).
- `dsv4.layers.ffn.FFNSubBlock`.
- `dsv4.layers.mhc.MHC` (already wired into `DSV4Layer`).
- `dsv4.kernels.attention.production.dsv4_attention` — the production FMHA
  entry point.
- `dsv4.kernels.compressor.csa_compress_and_store` /
  `dsv4.kernels.compressor.hca_compress_and_store`.
- `dsv4.kernels.indexer.compute_index_scores_topk`.
- `dsv4.ops.rope.apply_rope_bf16`, `dsv4.ops.rope.inverse_rope_bf16`.

### What it has to implement itself (because the codebase doesn't have it yet)

These three are missing or stubbed. The script ships them inline; they get
promoted into `dsv4/` proper as part of E1–E3.

1. **`load_dsv4_weights(model, hf_path)`** — HF safetensors loader. Walk
   `model.safetensors.index.json`, map HF key names to `DSV4Layer` parameter
   names, copy in. For FP4 weights (MoE experts), dequantize at load. This is
   what `dsv4/loader/hf_checkpoint.py` should eventually be — keep the
   function signature small so it can be promoted as-is.

2. **`SimpleLayerCache`** — a `LayerCacheHandle`-shaped object that uses flat
   BF16 tensors instead of paged FP8. Implements the same surface
   `dsv4/kernels/attention/__init__.py` calls
   (`gather_compressed_kv`, `gather_all_compressed_kv`, `gather_swa_kv`,
   `num_query_heads`, `head_dim`, `positions`, `request_slots`,
   `read_classical_view`, `read_swa_view`, `read_indexer_view`, `write_swa`,
   `flush_compression`). Eager torch ops are fine here — this is the
   **reference**, not the fast path. **Each method gets one comment line
   pointing at the production E1 kernel that replaces it.**

3. **The model loop.** Embedding → for layer in layers: `layer.forward(X, ids, cache)` →
   final norm → prediction head → argmax. Per-layer cache handle from the
   simple cache manager. This is the body of what `DSV4Model.forward` becomes
   in E3.

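For item 1, the "walk the index, map key names" step can be sketched with the standard library alone. The `weight_map` field is the standard layout of a sharded safetensors index; everything else is an assumption for illustration: `RENAMES`, `map_key`, and `plan_load` are hypothetical names, and the key strings in the example are invented, not taken from the real checkpoint. Reading tensors and FP4 dequantization are deliberately left out.

```python
import json
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical rename table: HF key fragment -> DSV4Layer parameter fragment.
# The real mapping has to be read off the checkpoint and the layer code.
RENAMES: List[Tuple[str, str]] = [
    ("model.layers.", "layers."),
]


def map_key(hf_key: str) -> str:
    """Apply every rename rule to one HF key."""
    for old, new in RENAMES:
        hf_key = hf_key.replace(old, new)
    return hf_key


def plan_load(index_json: str) -> Dict[str, List[Tuple[str, str]]]:
    """Group tensors by shard file so each shard only needs to be opened once.

    Returns {shard_filename: [(hf_key, mapped_name), ...]}.
    """
    weight_map = json.loads(index_json)["weight_map"]
    plan: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for hf_key, shard in sorted(weight_map.items()):
        plan[shard].append((hf_key, map_key(hf_key)))
    return dict(plan)


# Invented index contents, shaped like model.safetensors.index.json.
example_index = json.dumps({
    "weight_map": {
        "model.layers.0.attn.q_down.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.attn.kv_down.weight": "model-00001-of-00002.safetensors",
        "model.layers.1.attn.q_down.weight": "model-00002-of-00002.safetensors",
    }
})

for shard, entries in plan_load(example_index).items():
    print(shard, entries)
```

Keeping the plan as plain data makes it easy to print and diff before any tensor is copied, which fits the "print, don't guess" approach when a key fails to map.
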
### Numerics requirements

- All math in BF16 except: softmax accumulator (FP32, already in the kernel),
  RMSNorm reduction (FP32, already in the kernel), Sinkhorn iterations in mHC
  (FP32, already in the layer), indexer scores (FP32, current scalar path).
- Determinism: torch seed fixed, `torch.use_deterministic_algorithms(True)`,
  no cuDNN benchmark, single CUDA stream. Decode is single-token so there's
  no batching nondeterminism to worry about.
- Greedy: `next_tok = logits[-1].argmax()`. No temperature, no top-p, no
  sampling.

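One way to see why the reductions listed above stay in FP32 is to emulate BF16 rounding and compare accumulators. This is an illustrative sketch, not the kernel's code: `round_to_bf16`, `sum_bf16_accumulator`, and `sum_wide_accumulator` are hypothetical helpers, BF16 is emulated by rounding a float32 bit pattern to its top 16 bits, and a Python float stands in for the wider accumulator.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round a finite float to the nearest BF16 value (round-to-nearest-even).

    NaN and overflow handling are omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)   # ties go to the even 16-bit prefix
    bits = (bits + bias) & 0xFFFF0000    # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def sum_bf16_accumulator(values, start=0.0):
    """Accumulate with the running sum rounded to BF16 after every add."""
    acc = round_to_bf16(start)
    for v in values:
        acc = round_to_bf16(acc + round_to_bf16(v))
    return acc


def sum_wide_accumulator(values, start=0.0):
    """Accumulate BF16 inputs in a wider accumulator."""
    acc = start
    for v in values:
        acc += round_to_bf16(v)
    return acc


# 512 addends of 2**-9: each is below half a BF16 ulp at 1.0 (2**-8),
# so a BF16 accumulator starting at 1.0 never moves.
values = [2.0 ** -9] * 512
print(sum_bf16_accumulator(values, start=1.0))  # 1.0
print(sum_wide_accumulator(values, start=1.0))  # 2.0
```

The same stall applies to a sum of squares in RMSNorm or a softmax denominator over a long sequence.
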
### Pass/fail gate

```
prompt = "The capital of France is"
expected = "Paris"

# Decode 8 tokens (one is enough to see "Paris" but a few more catch
# off-by-one tokenization issues).
generated_ids = decode_n(prompt, n_new=8)
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

assert "Paris" in generated_text, f"FAIL: got {generated_text!r}"
print(f"PASS: {prompt!r} -> {generated_text!r}")
```

Run it. If it prints PASS, the architecture and the FMHA core kernel are
known good. If it raises an exception before reaching the gate, the stack
trace points at the broken layer. If the assertion fails with wrong text, the
diff against the tier-0 HF reference (below) localizes which sub-block diverges.

### Recommended companion — `tier_0_hf_reference.py`

For debugging tier 1 when it produces wrong text instead of erroring, you
need a known-good reference to diff against. Write a 30-line script that
loads the HF model with `AutoModelForCausalLM.from_pretrained(..., dtype=torch.bfloat16)`,
greedily generates 8 tokens for the same prompt, and dumps:

- `embeds[0, t, :10]` per token
- `layer[0].input_layernorm(embeds)[0, t, :10]` per token
- `layer[0].attn.q_proj(normed)[0, t, :10]` per token
- ... continuing through the layer

This is exactly what tiny-vllm's `python/reference.py` does for Llama. We're
borrowing the pattern, not the code. When tier 1 diverges from this, you
have a per-tensor diff to localize the failure to a specific sub-block.

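The per-tensor diff can be as simple as walking both dumps in forward-pass order and stopping at the first mismatch. The sketch below assumes each script writes an ordered list of `(name, first_few_values)` pairs; `first_divergent_tensor`, the tensor names, the values, and the `atol` default are all hypothetical and would need tuning against real BF16 noise.

```python
from typing import Dict, List, Optional, Tuple

# Ordered dump: (tensor_name, first few values), in forward-pass order.
Dump = List[Tuple[str, List[float]]]


def first_divergent_tensor(ref: Dump, test: Dump, atol: float = 1e-2) -> Optional[str]:
    """Return the name of the first reference tensor that the test dump fails to match.

    A tensor fails if it is missing, has a different length, or any element
    differs by more than atol. Returns None when every tensor matches.
    """
    test_by_name: Dict[str, List[float]] = dict(test)
    for name, ref_vals in ref:
        test_vals = test_by_name.get(name)
        if test_vals is None or len(test_vals) != len(ref_vals):
            return name
        if any(abs(a - b) > atol for a, b in zip(ref_vals, test_vals)):
            return name
    return None


# Made-up dumps: the norm output is the first tensor to go wrong, so the
# attention mismatch after it is a downstream symptom, not the cause.
ref_dump = [("embeds", [0.10, -0.20]), ("layer0.norm", [0.50, 0.25]), ("layer0.attn", [1.00, 2.00])]
test_dump = [("embeds", [0.10, -0.20]), ("layer0.norm", [0.50, 0.75]), ("layer0.attn", [9.00, 9.00])]
print(first_divergent_tensor(ref_dump, test_dump))  # layer0.norm
```
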
---
## What to put in `scripts/` and what's in scope

```
scripts/
├── single_shot_baseline.py     # Tier 1: BF16 everything, simple cache
├── single_shot_production.py   # Tier 2: full production path (depends on E1)
├── tier_0_hf_reference.py      # HF eager reference for per-tensor diff
└── README.md                   # 1 page: how to run, what each script proves
```

Keep all three small. The reference is single-shot. It's not a server, it's
not a benchmark, it's not a test harness. If those things end up wanted,
they live elsewhere.

---

## DOCTRINE NOTES — applies here too

1. **DSL wall → raw CUDA C++, not Python.** Doesn't apply to this script —
   the script *is* the Python orchestration layer. But if a sub-call fails
   inside a kernel and the agent's instinct is "let me reimplement that part
   in Python to make the script work," the answer is no. Fix the kernel.

2. **Raw CUDA ≠ scalar math.** Doesn't apply directly, but a related
   corollary: this script is allowed to use eager torch ops in
   `SimpleLayerCache` because it is explicitly labeled as the *reference,
   not the fast path.* When E1 lands, the simple cache is replaced wholesale
   by the production cache. The reference's job is to be obviously correct,
   not fast.

3. **Print, don't guess.** If the script fails:
   - First step: run `tier_0_hf_reference.py` on the same prompt, dump the
     per-tensor expected values.
   - Second step: instrument the tier 1 script to print the *same* tensors
     at the same points.
   - Third step: diff. The first divergent tensor is the broken sub-block.
   - Do **not** start "fixing things to see if it helps" before doing the
     diff.

4. **Integration over exploration.** The script lives in `scripts/`, not in
   `dsv4/`. It is a *consumer* of the library, not part of it. When functions
   in the script are promoted into `dsv4/` (the loader, the model class,
   etc.), the script should *get shorter*, not longer. If the script is
   growing, that's a smell that something belongs in the library.

5. **Falsifiable gate.** The gate is `assert "Paris" in generated_text`. Not
   "looks reasonable," not "produces something." Paris or fail.