SINGLE-SHOT INFERENCE BASELINE — design and recommendation
Use case (verbatim from you). Deterministic prompt "the capital of France is", greedy decode, expect "Paris" back. If we get Paris, we've ruled out kernel correctness as a source of bugs when we later do the official vLLM integration. The script also doubles as the integration reference for any inference engine — vLLM, SGLang, custom, doesn't matter.
TL;DR — recommendation
Standalone Python script. Do not fork tiny-vllm.
Put it under scripts/ in the kernel repo (the two-tier plan below splits it into
scripts/single_shot_baseline.py and scripts/single_shot_production.py). Have it do the full
orchestration itself (embedding → N layers → final norm → head → argmax). Make
it the reference implementation of how to drive the kernel, not a wrapper
around someone else's engine. Length target: 300–500 lines, one file, no
class hierarchy.
Two reasons. The first is architectural; the second is doctrine.
Reason 1 — tiny-vllm is built around Llama, not DSV4
tiny-vllm is a teaching repo for a Llama 3.2 1B inference engine. Its
README is explicit: "load a real LLM model from Safetensors (Llama 3.2 1B
Instruct), full LLM forward pass, KV cache, static batching, continuous
batching, PagedAttention." Its python/reference.py calls
AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
and walks Llama's model.layers[0].input_layernorm, q_proj, etc.
Everything about its KV cache, attention shape, layer structure, and weight naming assumes Llama. DSV4 violates basically all of those assumptions:
- KV cache: DSV4 has a heterogeneous cache (paged FP8 + state cache for SWA tail + separate FP4 indexer pool). tiny-vllm's PagedAttention is uniform.
- Attention: DSV4 has hybrid CSA/HCA/SWA per-layer + lightning indexer top-k + grouped output projection + attention sink. tiny-vllm has GQA.
- Layer structure: DSV4 has mHC (Sinkhorn-projected residuals, n_hc=4, not plain x + sublayer(x)) and MoE with sqrt(softplus) routing, hash routing on the first 3 layers, plus a shared expert. tiny-vllm uses a dense MLP with a plain residual.
- Weight naming: DSV4's HF checkpoint has its own naming convention (q_down, q_up, kv_down, indexer_q_up, indexer_head_weights, wo_a/wo_b, mHC A/B/C raw params, MoE gate/up/down per expert). tiny-vllm's loader expects Llama keys.
You'd be replacing 80%+ of tiny-vllm to get DSV4 to run in it. At that point it's not "using tiny-vllm" — it's "rewriting tiny-vllm into DSV4," and the "vllm" in the name becomes misleading. The fact that tiny-vllm exists doesn't make it the right starting point.
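To make "Sinkhorn-projected residuals" concrete: as we understand mHC, the residual carries n_hc parallel streams, and the n_hc × n_hc matrix that mixes them is pushed onto the doubly stochastic set (every row and column sums to 1) by Sinkhorn-Knopp iteration. The sketch below shows only that projection, in plain Python floats. How DSV4 forms the raw logits from its mHC parameters and how many iterations it runs are assumptions here, not taken from the codebase.

```python
import math
from typing import List

Matrix = List[List[float]]


def sinkhorn_knopp(logits: Matrix, n_iters: int = 20) -> Matrix:
    """Push a square matrix of raw logits toward the doubly stochastic set
    (every row and every column sums to 1) by alternating normalizations."""
    n = len(logits)
    # exp() makes every entry strictly positive, which the iteration requires.
    m = [[math.exp(x) for x in row] for row in logits]
    for _ in range(n_iters):
        # Row normalization.
        row_sums = [sum(row) for row in m]
        m = [[x / s for x in row] for row, s in zip(m, row_sums)]
        # Column normalization.
        col_sums = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return m


n_hc = 4  # number of residual streams
raw = [[0.1 * i * j - 0.2 * j for j in range(n_hc)] for i in range(n_hc)]  # arbitrary example logits
h_res = sinkhorn_knopp(raw, n_iters=50)
print([round(sum(row), 6) for row in h_res])  # row sums, all close to 1
```

A plain x + sublayer(x) residual corresponds to the identity mixing matrix, which is itself doubly stochastic; the projection keeps the learned mixing in that same well-conditioned family.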
Reason 2 — a reference implementation should be minimal, not feature-rich
You said "use it like a reference for integration into any inference engine." That sentence is doing a lot of work and it's worth being literal about it.
A reference implementation has one job: show, in the simplest possible code, what the model does so that a non-trivial inference engine knows what it needs to wire up. Anything beyond that is an engine-specific concern that the integrator will replace anyway.
What tiny-vllm has that a reference does not need: static batching, continuous batching, scheduler, PagedAttention, multi-request serving, continuous decode loop. All of those are exactly the things that differ between vLLM and SGLang and your-future-engine. Bundling them into the reference makes the reference less useful, not more.
What a reference does need: model construction, weight loading, one forward
pass per token, simple flat KV growth, deterministic argmax. That's a single
script. It already exists as a known pattern — HF's model.generate(do_sample=False)
is the same shape, minus the kernel control.
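The loop shape described above (one forward pass per token, flat KV growth, deterministic argmax) can be sketched independently of DSV4. In this sketch, `forward` is a hypothetical stand-in for "embedding → N layers → final norm → head": it takes the new token ids plus a flat cache that it extends in place and returns the logits for the last position. Plain Python lists stand in for tensors, and `toy_forward` exists only to make the loop runnable.

```python
from typing import Callable, List

# Hypothetical signature: forward(new_ids, kv_cache) runs the whole stack on
# new_ids, appends their keys/values to kv_cache in place, and returns the
# logits for the last position.
ForwardFn = Callable[[List[int], list], List[float]]


def argmax(logits: List[float]) -> int:
    # max() returns the first maximal index, so ties break deterministically.
    return max(range(len(logits)), key=lambda i: logits[i])


def greedy_decode(forward: ForwardFn, prompt_ids: List[int], n_new: int) -> List[int]:
    kv_cache: list = []  # flat cache, grown by concatenation; no paging
    logits = forward(prompt_ids, kv_cache)  # prefill: the whole prompt in one pass
    generated: List[int] = []
    for step in range(n_new):
        next_tok = argmax(logits)  # greedy: no temperature, no top-p, no sampling
        generated.append(next_tok)
        if step + 1 < n_new:
            logits = forward([next_tok], kv_cache)  # decode: one token per pass
    return generated


# Toy forward for illustration: the "cache" is just the token history, and the
# logits always favor (last token + 1) mod 10.
def toy_forward(new_ids: List[int], kv_cache: list) -> List[float]:
    kv_cache.extend(new_ids)
    logits = [0.0] * 10
    logits[(kv_cache[-1] + 1) % 10] = 1.0
    return logits


print(greedy_decode(toy_forward, [3, 4], n_new=3))  # [5, 6, 7]
```

Skipping the forward pass after the last emitted token is a small design choice: its logits would never be read.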
What this means for your existing repo
The script is both the inference baseline and the work item E3
("dsv4/model/dsv4.py is a 2-line stub — build the actual model class") from
the Stage E roadmap. Collapse those into one artifact. The script imports
DSV4Layer from dsv4/model/layer.py, drives it explicitly, and serves as
the executable spec for what DSV4Model.forward should later become.
When you eventually integrate into vLLM, you extract the model orchestration
(layer loop, RoPE caches, KV growth) into a DSV4Model class and the script
shrinks to "construct DSV4Model, generate, print." The script itself
remains, forever, as the minimal reproducer for "if we get Paris, kernels are
fine."
Two-tier strategy
The point of the script is to rule out kernel issues. That's only useful
if the script can actually fail in a way that points at the kernel and not at
some unrelated unfinished thing. Today, several DSV4 components are still
incomplete: the E1 cache gather is not done, and loader/hf_checkpoint.py and
model/dsv4.py are both 2-line stubs. If we wait for everything, the script
ships in 2 weeks. If we ship in two tiers, we get a useful signal today.
Tier 1 — single_shot_baseline.py — BF16 everything, simple cache
Goal: produce "Paris" with the maximum possible signal-to-noise ratio. No quantization. No paging. No FP8. Just the math.
- Load weights from HF safetensors and dequantize FP4 → BF16 at load time. (The reference checkpoint stores the MoE expert weights in FP4; dequantize them to BF16. Linear weights are already BF16.)
- KV cache: one flat 4D tensor per layer, grown by concatenation. No paged pool, no state cache, no FP8 round-trip. Direct BF16 throughout.
- Compressor: produce dense compressed BF16 entries directly, store in the flat cache. No FP8 quant on the cache write path.
- Indexer: compute index scores in FP32 (this matches the post-LUT-fix scalar path you already have, so we know it's correct and we know it's not fused-FP4 yet). Take the top-k indices. The lookup is a plain cache[..., topk].
- FMHA: call the production multi-tile kernel via dsv4_attention(q, k_gathered, v_gathered, ...). The dense KV is already materialized.
- Sink merge: handled by the FMHA kernel (D5c, already done).
- Inverse RoPE → wo_a → wo_b. mHC residual. FFN sub-block. Repeat for N layers.
What this proves. If Paris comes back here, the architecture is correct and the FMHA core kernel is correct. That is the highest-leverage validation you can do today, with zero dependence on E1.
What this does not prove. Anything about the FP8 paged cache, anything about the FP4 fused SwiGLU MoE epilogue, anything about NVFP4 weight quant on linears. Those are explicitly not in scope for tier 1 — separating them is the whole point.
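The tier 1 indexer step (FP32 scores, top-k, plain gather from the flat cache) can be sketched as follows. The scoring formula assumes a DeepSeek-V3.2-style lightning indexer, score[s] = Σ_j w_j · ReLU(q_j · k_s) over indexer heads j; whether DSV4's compute_index_scores_topk uses exactly this form is an assumption. Python floats are float64, so the sketch is at least as precise as an FP32 path. The indexer keys and the KV rows being gathered are kept as separate lists, mirroring the separate indexer pool.

```python
from typing import List, Sequence

Vector = Sequence[float]


def index_scores(q_heads: Sequence[Vector], head_weights: Vector,
                 keys: Sequence[Vector]) -> List[float]:
    """score[s] = sum_j w_j * relu(q_j . k_s) over indexer heads j (assumed form)."""
    scores = []
    for key in keys:
        total = 0.0
        for w, q in zip(head_weights, q_heads):
            dot = sum(a * b for a, b in zip(q, key))
            total += w * max(dot, 0.0)
        scores.append(total)
    return scores


def topk_indices(scores: Sequence[float], k: int) -> List[int]:
    # Stable sort: equal scores keep ascending index order, so ties are deterministic.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def gather_rows(cache_rows: Sequence[Vector], idx: Sequence[int]) -> List[Vector]:
    # The plain cache[..., topk] lookup: no paging, no FP8 dequant.
    return [cache_rows[i] for i in idx]


q_heads = [[1.0, 0.0], [0.0, 1.0]]  # 2 indexer heads, head dim 2 (toy sizes)
head_weights = [1.0, 0.5]
indexer_keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, -1.0]]  # 4 cached entries
kv_rows = [[10.0], [11.0], [12.0], [13.0]]  # compressed KV entries, same positions
top = topk_indices(index_scores(q_heads, head_weights, indexer_keys), k=2)
selected = gather_rows(kv_rows, top)  # [[12.0], [10.0]]
```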
Tier 2 — single_shot_production.py — full production path
Same script structure as tier 1, but it routes through the actual production cache + the NVFP4 quantized weights + fused FP4 SwiGLU MoE + (eventually) E7's FP4 tensor-core indexer. This is the script that proves the full integrated path works, and it depends on E1 (cache gather kernels) landing first.
Tier 2's job: produce the same output as tier 1, deterministically. If they diverge, the diff is exactly the difference between BF16 ref and the quantized production path. That diff is the signal you actually want when debugging vLLM integration.
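Because both tiers decode greedily and deterministically, the tier 1 vs tier 2 comparison can be exact token equality with no tolerance. A minimal, hypothetical helper that reports where the two runs part ways:

```python
from typing import Optional, Sequence


def first_divergence(tier1_ids: Sequence[int], tier2_ids: Sequence[int]) -> Optional[int]:
    """Return the first decode position where the two greedy outputs differ,
    or None if they are identical."""
    for pos, (a, b) in enumerate(zip(tier1_ids, tier2_ids)):
        if a != b:
            return pos
    if len(tier1_ids) != len(tier2_ids):
        # One sequence is a strict prefix of the other.
        return min(len(tier1_ids), len(tier2_ids))
    return None


print(first_divergence([791, 6864, 315], [791, 6864, 422]))  # 2
```

Only the first divergent step is worth instrumenting: every later step is conditioned on a different prefix, so its tensors are no longer comparable between the tiers.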
Tier 1 script — concrete specification
This is the spec, not the code — but the spec is precise enough that an agent can implement it without guessing.
File structure
One file: scripts/single_shot_baseline.py. Top-to-bottom: imports, config
load, tokenizer, weight load, model orchestration helpers, the decode loop,
if __name__ == "__main__". No class definitions beyond SimpleLayerCache,
which has to expose the LayerCacheHandle surface; everything else is functions,
with named tuples where state needs to flow. The whole point is readability as a
reference.
What it imports from dsv4/
- dsv4.model.config.DSV4Config: already exists and matches the paper.
- dsv4.model.layer.DSV4Layer: already exists; takes (X, token_ids, cache: LayerCacheHandle) → X.
- dsv4.layers.attention.AttentionSubBlock (used internally by DSV4Layer).
- dsv4.layers.ffn.FFNSubBlock.
- dsv4.layers.mhc.MHC (already wired into DSV4Layer).
- dsv4.kernels.attention.production.dsv4_attention: the production FMHA entry point.
- dsv4.kernels.compressor.csa_compress_and_store / dsv4.kernels.compressor.hca_compress_and_store.
- dsv4.kernels.indexer.compute_index_scores_topk.
- dsv4.ops.rope.apply_rope_bf16, dsv4.ops.rope.inverse_rope_bf16.
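apply_rope_bf16 and inverse_rope_bf16 are a rotation and its inverse. The sketch below shows that relationship in plain Python: RoPE rotates each pair of channels by an angle proportional to the position, and the inverse rotates by the negative angle, so applying one after the other returns the input. The interleaved (2i, 2i+1) pairing, the base of 10000, and rotating the full vector rather than a RoPE sub-slice of the head dim are all assumptions to check against dsv4.ops.rope, not facts from this document.

```python
import math
from typing import List, Sequence


def rope_angles(pos: int, dim: int, base: float = 10000.0) -> List[float]:
    # One angle per channel pair: pos * base^(-2i/dim).
    return [pos * base ** (-2.0 * i / dim) for i in range(dim // 2)]


def _rotate_pairs(x: Sequence[float], angles: Sequence[float], sign: float) -> List[float]:
    out = list(x)
    for i, theta in enumerate(angles):
        # sin is odd, so sign * sin(theta) == sin(sign * theta).
        c, s = math.cos(theta), sign * math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def apply_rope(x: Sequence[float], pos: int) -> List[float]:
    return _rotate_pairs(x, rope_angles(pos, len(x)), sign=1.0)


def inverse_rope(x: Sequence[float], pos: int) -> List[float]:
    # Rotating by -theta undoes the rotation by +theta.
    return _rotate_pairs(x, rope_angles(pos, len(x)), sign=-1.0)


x = [1.0, 2.0, 3.0, 4.0]
roundtrip = inverse_rope(apply_rope(x, pos=7), pos=7)  # equal to x up to float rounding
```

A useful sanity check for the real kernels is the same roundtrip in BF16, with a tolerance sized for BF16 rather than float64.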
What it has to implement itself (because the codebase doesn't have it yet)
These three are missing or stubbed. The script ships them inline; they get
promoted into dsv4/ proper as part of E1–E3.
- load_dsv4_weights(model, hf_path): HF safetensors loader. Walk model.safetensors.index.json, map HF key names to DSV4Layer parameter names, copy in. For FP4 weights (MoE experts), dequantize at load. This is what dsv4/loader/hf_checkpoint.py should eventually be; keep the function signature small so it can be promoted as-is.
- SimpleLayerCache: a LayerCacheHandle-shaped object that uses flat BF16 tensors instead of paged FP8. It implements the same surface that dsv4/kernels/attention/__init__.py calls (gather_compressed_kv, gather_all_compressed_kv, gather_swa_kv, num_query_heads, head_dim, positions, request_slots, read_classical_view, read_swa_view, read_indexer_view, write_swa, flush_compression). Eager torch ops are fine here; this is the reference, not the fast path. Each method gets one comment line pointing at the production E1 kernel that replaces it.
- The model loop: embedding → for layer in layers: layer.forward(X, ids, cache) → final norm → prediction head → argmax. Each layer gets its own cache handle from the simple cache manager. This is the body of what DSV4Model.forward becomes in E3.
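The FP4 → BF16 step in load_dsv4_weights is a table lookup followed by a scale multiply. The sketch below assumes the OCP MX FP4 (E2M1) value encoding, two codes packed per byte with the low nibble first, and one scale per block of 32 elements. The checkpoint's actual packing order, block size, and scale format are assumptions to verify against the safetensors metadata. Plain Python lists stand in for tensors; in the loader this would be a vectorized lookup on the GPU followed by a cast to BF16.

```python
from typing import List, Sequence

# E2M1 code -> value. Bit 3 is the sign; codes 0..7 are the non-negative
# magnitudes, codes 8..15 are their negatives.
FP4_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                   -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def dequant_fp4(packed: bytes, scales: Sequence[float], block_size: int = 32) -> List[float]:
    """Unpack two FP4 codes per byte (low nibble first, assumed) and multiply
    each value by the scale of the block it belongs to."""
    values: List[float] = []
    for byte in packed:
        values.append(FP4_E2M1_VALUES[byte & 0x0F])  # element 2k
        values.append(FP4_E2M1_VALUES[byte >> 4])    # element 2k + 1
    expected_blocks = (len(values) + block_size - 1) // block_size
    if len(scales) != expected_blocks:
        raise ValueError(f"expected {expected_blocks} scales, got {len(scales)}")
    return [v * scales[i // block_size] for i, v in enumerate(values)]


print(dequant_fp4(bytes([0x21, 0xF7]), [2.0]))  # [1.0, 2.0, 12.0, -12.0]
```

Checking the scale count up front turns a wrong block-size assumption into a load-time error instead of silently mis-scaled experts.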
Numerics requirements
- All math in BF16 except: softmax accumulator (FP32, already in the kernel), RMSNorm reduction (FP32, already in the kernel), Sinkhorn iterations in mHC (FP32, already in the layer), indexer scores (FP32, current scalar path).
- Determinism: fixed torch seed, torch.use_deterministic_algorithms(True), no cuDNN benchmark, single CUDA stream. Decode is single-token, so there's no batching nondeterminism to worry about.
- Greedy: next_tok = logits[-1].argmax(). No temperature, no top-p, no sampling.
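The determinism settings above map onto a handful of PyTorch calls. A minimal sketch, assuming it runs at the top of the script before any CUDA work; the helper name is hypothetical. The single-stream requirement needs no code: it holds as long as the script never creates extra streams.

```python
import os
import random

import torch


def configure_determinism(seed: int = 0) -> None:
    """Apply the tier 1 determinism settings; call before any CUDA work."""
    # With use_deterministic_algorithms(True), PyTorch raises a RuntimeError on
    # affected cuBLAS calls (CUDA >= 10.2) unless this is ":4096:8" or ":16:8".
    # It must be set before the first cuBLAS call.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    torch.use_deterministic_algorithms(True)
    # Disable cuDNN autotuning, whose algorithm choice can vary between runs.
    torch.backends.cudnn.benchmark = False


configure_determinism(0)
```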
Pass/fail gate
prompt = "The capital of France is"
expected = "Paris"
# Decode 8 tokens (one is enough to see "Paris" but a few more catch
# off-by-one tokenization issues).
generated_ids = decode_n(prompt, n_new=8)
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
assert "Paris" in generated_text, f"FAIL: got {generated_text!r}"
print(f"PASS: {prompt!r} -> {generated_text!r}")
Run it. If it prints PASS, the architecture and the FMHA core kernel are known good. If it raises an exception before reaching the gate, the stack trace points at the broken layer. If the assertion fails with wrong text, the diff against the tier-0 HF reference (below) localizes which sub-block diverges.
Recommended companion — tier_0_hf_reference.py
For debugging tier 1 when it produces wrong text instead of erroring, you
need a known-good reference to diff against. Write a 30-line script that
loads the HF model with AutoModelForCausalLM.from_pretrained(..., dtype=bfloat16),
greedily generates 8 tokens for the same prompt, and dumps:
- embeds[0, t, :10] per token
- layer[0].input_layernorm(embeds)[0, t, :10] per token
- layer[0].attn.q_proj(normed)[0, t, :10] per token
- ... continuing through the layer
This is exactly what tiny-vllm's python/reference.py does for Llama. We're
borrowing the pattern, not the code. When tier 1 diverges from this, you
have a per-tensor diff to localize the failure to a specific sub-block.
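The dumps do not require editing the model code: PyTorch forward hooks can record the first few values of chosen submodule outputs, on the HF model and on tier 1 alike. A minimal sketch; the names passed in are whatever model.named_modules() reports (any DSV4 path you would pass, such as a per-layer norm, is hypothetical here), and the [batch, seq, hidden] output layout is an assumption that holds for typical decoder blocks. The toy nn.Sequential only makes the sketch runnable.

```python
from typing import Dict, List, Tuple

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


def attach_dump_hooks(
    model: nn.Module, names: List[str], n_vals: int = 10
) -> Tuple[Dict[str, List[torch.Tensor]], List[RemovableHandle]]:
    """Record output[0, :, :n_vals] (every token t, first n_vals channels)
    for each named submodule, once per forward call."""
    modules = dict(model.named_modules())
    dumps: Dict[str, List[torch.Tensor]] = {name: [] for name in names}
    handles: List[RemovableHandle] = []
    for name in names:
        def hook(module, inputs, output, name=name):
            # Many HF blocks return tuples; the hidden states come first.
            out = output[0] if isinstance(output, tuple) else output
            dumps[name].append(out[0, :, :n_vals].detach().float().cpu())
        handles.append(modules[name].register_forward_hook(hook))
    return dumps, handles


toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
dumps, handles = attach_dump_hooks(toy, ["0", "1"], n_vals=4)
with torch.no_grad():
    toy(torch.randn(1, 5, 16))  # batch 1, 5 tokens
for h in handles:
    h.remove()  # detach hooks so later forwards are not recorded
```

Upcasting to float32 before storing keeps the comparison itself from adding BF16 rounding on top of whatever divergence is being hunted.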
What to put in scripts/ and what's in scope
scripts/
├── single_shot_baseline.py # Tier 1: BF16 everything, simple cache
├── single_shot_production.py # Tier 2: full production path (depends on E1)
├── tier_0_hf_reference.py # HF eager reference for per-tensor diff
└── README.md # 1 page: how to run, what each script proves
Keep all three small. The reference is single-shot. It's not a server, it's not a benchmark, it's not a test harness. If those things end up wanted, they live elsewhere.
DOCTRINE NOTES — applies here too
- DSL wall → raw CUDA C++, not Python. Doesn't apply to this script: the script is the Python orchestration layer. But if a sub-call fails inside a kernel and the agent's instinct is "let me reimplement that part in Python to make the script work," the answer is no. Fix the kernel.
- Raw CUDA ≠ scalar math. Doesn't apply directly, but a related corollary: this script is allowed to use eager torch ops in SimpleLayerCache because it is explicitly labeled as the reference, not the fast path. When E1 lands, the simple cache is replaced wholesale by the production cache. The reference's job is to be obviously correct, not fast.
- Print, don't guess. If the script fails:
  - First step: run tier_0_hf_reference.py on the same prompt and dump the per-tensor expected values.
  - Second step: instrument the tier 1 script to print the same tensors at the same points.
  - Third step: diff. The first divergent tensor is the broken sub-block.
  - Do not start "fixing things to see if it helps" before doing the diff.
- Integration over exploration. The script lives in scripts/, not in dsv4/. It is a consumer of the library, not part of it. When functions in the script are promoted into dsv4/ (the loader, the model class, etc.), the script should get shorter, not longer. If the script is growing, that's a smell that something belongs in the library.
- Falsifiable gate. The gate is assert "Paris" in generated_text. Not "looks reasonable," not "produces something." Paris or fail.
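The diff step in "Print, don't guess" can be mechanized. Given the tier 0 and tier 1 dumps as name → first-values lists recorded in forward order, the first entry outside tolerance names the broken sub-block. The helper name and the default tolerances are placeholder assumptions sized for BF16 (roughly 2 to 3 significant decimal digits), not values from this document; error accumulates layer over layer, so later tensors may need looser bounds.

```python
import math
from typing import Dict, List, Optional, Tuple


def first_divergent_tensor(
    expected: Dict[str, List[float]],
    actual: Dict[str, List[float]],
    atol: float = 1e-2,
    rtol: float = 1e-2,
) -> Optional[Tuple[str, float]]:
    """Walk the dumps in forward order (dict insertion order) and return
    (name, worst_abs_err) for the first tensor outside tolerance, or None."""
    for name, ref in expected.items():
        got = actual.get(name)
        if got is None or len(got) != len(ref):
            return name, math.inf  # a missing or wrong-shaped dump counts as divergence
        worst = 0.0
        diverged = False
        for r, g in zip(ref, got):
            err = abs(r - g)
            if math.isnan(err):
                err = math.inf  # a NaN in either dump always counts as divergence
            if math.isinf(err) or err > atol + rtol * abs(r):
                diverged = True
            worst = max(worst, err)
        if diverged:
            return name, worst
    return None


tier0 = {"embeds": [0.1, 0.2], "layer0.norm": [1.0, -1.0], "layer0.q": [0.5, 0.5]}
tier1 = {"embeds": [0.1, 0.2], "layer0.norm": [1.0, -0.5], "layer0.q": [9.0, 9.0]}
print(first_divergent_tensor(tier0, tier1))  # ('layer0.norm', 0.5)
```

Stopping at the first failure is deliberate: layer0.q is wildly off too, but only because its input already diverged.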