Compare commits
60 Commits
v-b1-b2-do
...
v-official
| Author | SHA1 | Date | |
|---|---|---|---|
| 7901470e63 | |||
| ca7c309463 | |||
| 8cfc1cae58 | |||
| a86d6d90a5 | |||
| 284fc9ca86 | |||
| 6a3374da18 | |||
| 5003e756e2 | |||
| 572bdd2840 | |||
| 3c06fd5591 | |||
| 89f6e64057 | |||
| 29d6986dd4 | |||
| 60b9bbd470 | |||
| 1e77dfcaa0 | |||
| 2a42686e8e | |||
| 11c2d5fe53 | |||
| c77b83fffc | |||
| c5a131c358 | |||
| 019a3a34b7 | |||
| 5e09be08af | |||
| 60309ef124 | |||
| 0bf276f8c9 | |||
| d463ac8512 | |||
| 7450ebc67a | |||
| 9dbfac9dfa | |||
| a682c6adf4 | |||
| f2c1b3afd5 | |||
| 86e59c16c5 | |||
| 262f844e2e | |||
| 6459fbca9a | |||
| 91dfac34d8 | |||
| d99503732d | |||
| 801bfc9a83 | |||
| b385ecc05e | |||
| d518fcb82a | |||
| 9574a9dc2e | |||
| 9a9b347b2b | |||
| f5fa20c581 | |||
| 693975ec92 | |||
| e1d96c509d | |||
| 1ebe7f0dde | |||
| d8306be3f2 | |||
| 4126909dfb | |||
| 8c54cfa748 | |||
| 04cf8ca848 | |||
| 75288bd12f | |||
| 5417f65b08 | |||
| dd1cbe1faa | |||
| 09384a637a | |||
| d3dc8cf901 | |||
| 223c22488f | |||
| 2bf5e74e61 | |||
| eb69c3bfb9 | |||
| 99b6de316b | |||
| 9034f67b0f | |||
| a4ef6c3454 | |||
| 1f757151ef | |||
| 07168357cc | |||
| 27d8d80a40 | |||
| 26a817c2f2 | |||
| ba67e055f7 |
@@ -1,100 +0,0 @@
|
||||
# WE ARE BACKLOGGING THIS ISSUE AND WILL REVIST IT AFTER WE FINISH THE OTHER ITEMS IN THE FINAL STRETCH
|
||||
|
||||
**Context:** post-cleanup `single_shot_inference.py` compiles, Paris is top-1 at step 0, output is coherent, then degenerates into repeated junk ("capital ..."). Defaults in effect: `temperature=0.6`, `repetition_penalty=1.1`, `--warmup-gsa` **off**, fused rmsnorm+quant **on**.
|
||||
|
||||
**Read this first — what the symptom rules in/out.** The model is *sampling* (temp 0.6) with a penalty (1.1) and still loops a near-constant token. That is not "greedy with no penalty." It means either (a) the model finished its turn, emitted a stop token we don't catch, and the LM head is now peaked on degenerate filler, or (b) a decode-state correctness bug whose error compounds over steps. (a) is far more likely and is nearly free to test, so **do Part A in order, cheapest first, and do NOT touch kernels until A1–A2 are ruled out.** The math is right (Paris top-1); don't go hunting kernel ghosts before eliminating the decoding-config causes.
|
||||
|
||||
Code is the source of truth. For any precision change, validate per-layer cosine against `dsv4/reference/` before trusting end-to-end output.
|
||||
|
||||
---
|
||||
|
||||
# PART A — Decode repetition (correctness). Do in order.
|
||||
|
||||
## A1 — Stop set (HIGH priority, ~zero effort, most likely the whole bug)
|
||||
`single_shot_inference.py:1571` stops only on `next_id == tokenizer.eos_token_id`. DSV4 is a reasoning/chat model; an assistant turn ends with a special token (`<|end_of_sentence|>`, and the turn structure also uses USER=128803 / ASSISTANT=128804 / `</think>`=128822). If the model's turn-end token isn't `eos_token_id`, decode never stops and degenerates exactly as observed.
|
||||
|
||||
**Diagnose (no code change):**
|
||||
1. Print `tokenizer.eos_token_id`, `tokenizer.eos_token`, and `tokenizer.special_tokens_map` once at startup.
|
||||
2. In the decode log, find the token id emitted at the moment output "should have finished." Decode it: `tokenizer.decode([id])`. If it's a special/end token not in the stop set, that's the bug.
|
||||
|
||||
**Fix:** build an explicit stop set and break on membership, e.g.:
|
||||
```python
|
||||
STOP_IDS = {tokenizer.eos_token_id}
|
||||
for t in ("<|end_of_sentence|>",): # add the real turn-end token name(s) for this checkpoint
|
||||
tid = tokenizer.convert_tokens_to_ids(t)
|
||||
if tid is not None and tid >= 0: STOP_IDS.add(tid)
|
||||
STOP_IDS.add(USER_TOKEN) # model trying to open a new user turn = it's done
|
||||
# ... in the loop:
|
||||
if next_id in STOP_IDS:
|
||||
print(f" STOP ({next_id}) at step {step}", flush=True); break
|
||||
```
|
||||
**A/B:** if adding the stop set ends generation cleanly, the "bug" was never in the kernels. Stop here.
|
||||
|
||||
## A2 — Sampler / penalty sanity (MEDIUM, cheap, diagnostic)
|
||||
A 1.1 penalty over `recent_tokens=all_tokens[-256:]` should at least *perturb* a single-token loop. If it loops the exact same id anyway, suspect the penalty isn't reaching the kernel or is mis-indexed.
|
||||
- **Test:** rerun with `--repetition-penalty 1.5`. If the loop is *unchanged*, the penalty path in `dsv4/model/sampler.py` (CUDASampler) is broken — verify `recent_tokens` is actually passed to and applied by the kernel, and that it indexes the logit vector correctly. If raising it *does* break the loop, the sampler is fine and this was a stop-token/decoding-hygiene issue (see A1).
|
||||
- Also confirm `recent_tokens` includes the *prompt* tokens, not just generated ones, or the model can loop on a prompt word ("capital") penalty-free.
|
||||
|
||||
## A3 — Compressed/SWA visible-range parity (MEDIUM, architectural — verify vs reference)
|
||||
During decode the query attends to `[top-k compressed entries] ++ [SWA window]` (`forward_attention`, "5. Gather KV"). Two things to verify against the HF/`dsv4/reference` oracle, because an off-by-one here causes subtle wrongness that **compounds across decode steps** (coherent early, degenerate late — matches the symptom):
|
||||
1. **Which compressed blocks are visible to a decode query.** Causality: a query must see only compressed blocks strictly *preceding* its own current (incomplete) block, never its own or future blocks. Confirm the set of compressed indices fed to the FMHA at step `s` matches the reference exactly.
|
||||
2. **SWA / compressed overlap.** The most recent tokens are in the SWA ring (`ws=128`) *and* may also be inside the newest complete compressed block → the query can attend to both representations of the same tokens. This may be intended (SWA refines what compression blurred, and the model was trained with it) — but it must match how the reference gathers, or the recent-context weighting drifts. Diff the gathered key set (indices + count) against the reference for a fixed prompt at several decode positions.
|
||||
|
||||
> Note: the residual `|X|` growing to ~244–372 is **expected** (the paper notes 300–500; your own `KVCache` docstring says the same). It is not by itself the bug. See B5 only if A1–A3 don't resolve it.
|
||||
|
||||
## A4 — (verify, likely fine) Inverse RoPE
|
||||
`forward_attention:783` applies `_apply_rope(attn_out, positions, ..., inverse=True)` at the **query** position, which is what converts the absolute positions carried by the summed KV into relative ones. This looked correct. Just confirm `inverse=True` negates the rotation angle (applies RoPE(−t)) and uses the query `positions` (not `comp_pos`). Only revisit if A1–A3 are clean and degeneration persists.
|
||||
|
||||
---
|
||||
|
||||
|
||||
When we were doing this we tried ALL OF THE ABOVE!!!!! and this is where we left off:
|
||||
|
||||
```
|
||||
Good, the greedy test (temp=0, 64 tokens) is prefilling. Let me wait for it to hit decode.Greedy with temp=0 and rep_penalty=1.0. The model is looping on "capital of France is" — just re-echoing the prompt. It never generates "Paris" or any real answer. This is NOT a sampler issue. The model's logits are fundamentally wrong.
|
||||
|
||||
The sampler and penalty are cleared. The degeneration is in the model/pipeline itself. Let me look at the step 0 output more carefully.Key finding: " Paris" (111) has logit 10.62, but the TOP-1 token at step 0 is "," with logit 17.4. That's a 7-point gap. The model should have " Paris" much higher — in a correctly working model, "Paris" after "The capital of France is" should be 25+ logit, not 10.
|
||||
|
||||
This is a pipeline correctness issue. The model is not producing the right logits. The step-0 degeneration confirms it's not about repetition or sampling — the logits themselves are wrong from the very first decode step.
|
||||
|
||||
Given the user's instruction to move to Part B if the sampler/prompt diagnostics don't resolve it, let me focus on the most impactful pipeline fix. The residual growing to 472-732 (|X|) is suspicious. Let me look at the P4 fused rmsnorm — specifically the `run_from_quantized` path vs the regular `run` path, since memory notes the P4 integration caused output quality regression.
|
||||
|
||||
The previous test with `--no-fused-rmsnorm` ALSO produced degenerate output, so P4 isn't the sole cause. But let me check the KEY difference: in the `run_from_quantized` path, `gsa` is passed per-row from the QuantizedActivation, while in the regular path, gsa is a scalar. The MEMORY.md notes this exact issue:
|
||||
|
||||
> **ROOT CAUSE HYPOTHESIS**: run_from_quantized passes per-row gsa to GEMM, but CuTeDSL NVFP4 GEMM expects scalar global_scale_aSo `run_from_quantized` passes per-row gsa (shape (M,)) as `global_scale_a`. Let me check what the GEMM expects.The `global_scale_a` is passed as a tensor to the CuTe kernel. For the single-expert case (Nvfp4Linear), `expert_offsets` has shape (1,) and `global_scale_a` would be the gsa. Let me check how the kernel uses it.There it is. The kernel loads `global_scale_a` using `global_scale_a.iterator + expert_idx`. This is a **per-expert** scalar — it indexes by expert_idx, not by row. For a single-expert case (Nvfp4Linear with expert_offsets of shape (1,)), expert_idx is always 0, so it reads `global_scale_a[0]` — a single scalar.
|
||||
|
||||
But `run_from_quantized` passes per-row gsa as `global_scale_a` with shape `(M,)`. The kernel only reads `global_scale_a[0]` (the first row's gsa), ignoring the rest. This means for decode (M=1), it works fine (one row, one gsa). But the issue is that the fused rmsnorm kernel computes gsa per row, and the GEMM only uses row 0's gsa.
|
||||
|
||||
Wait, for decode, M=1 always. So per-row vs scalar doesn't matter for the decode path. The issue would be in prefill where M>1. Let me re-examine...
|
||||
|
||||
Actually, the prefill is done one token at a time (the code explicitly says "one token at a time (decode-style; TODO: batched prefill)"). So M=1 throughout. The per-row gsa should be fine since there's only one row.
|
||||
|
||||
Let me look deeper at the actual numerical pipeline. The real question is: why are the logits wrong? Let me add diagnostic code to compare our pipeline output with a known reference at a specific layer.Let me run with --verbose 2 to get per-layer diagnostics.Still building MoE layers. Need to wait longer.**CRITICAL FINDING**: L1 FMHA cosine is 0.679. That's terrible. L2 FMHA cosine is 0.956 — also bad. The q_a projection is fine (0.999998), but the FMHA output is wrong. This is the root cause of the degeneration.
|
||||
|
||||
Let me get the L0 data too.L0 FMHA cos=0.955, L1 FMHA cos=0.679, L2 FMHA cos=0.956. L1 is particularly bad. This is the attention pipeline producing incorrect results.
|
||||
|
||||
The reference comparison is done against PyTorch SDPA on the same gathered KV. So the FMHA kernel itself is wrong for these attention patterns. Let me understand why L1 is 0.679 while L0 and L2 are ~0.95.
|
||||
|
||||
L0 and L2 are HCA layers (ratio=128), L1 is also HCA (the first two layers are HCA for Pro). Wait, let me check the layer schedule.So L0=HCA(128), L1=HCA(128), L2=CSA(4), L3=HCA(128), L4=CSA(4). L1 (HCA) has terrible FMHA cosine. The issue is in the FMHA when dealing with the compressed KV patterns.
|
||||
|
||||
Let me check what seq_len the FMHA sees at each layer. The verbose=2 diagnostic should show this in the forward_attention function.The cos=0.679 at L1 is devastating. Even though the magnitudes are similar (2.78 vs 2.94), the direction is wrong. This compounds through 61 layers.
|
||||
|
||||
Let me add more diagnostic output to the FMHA — specifically, print the shapes and seq_len being passed.Wait — `n_comp=0` is hardcoded! The production FMHA call always sets `n_comp=0`. But the attention pattern for HCA/CSA involves compressed entries + SWA window. The `n_comp` parameter is crucial for the sink bias and possibly for the FMHA dispatch logic.
|
||||
|
||||
Let me check what `n_comp` does in the production attention.The `n_comp` parameter is described as "reserved for future kernel integration" and it's always 0. The sink_bias is also "reserved for future kernel integration". So the production FMHA is running WITHOUT the D5c sink bias that the paper describes as essential for DSV4 attention.
|
||||
|
||||
But more critically, the cos=0.679 at L1 suggests the FMHA itself is producing wrong results even without sink bias. Let me look at the reference comparison code — it compares the production FMHA output against PyTorch SDPA on the SAME gathered KV. So the KV is the same, but the kernel produces a different result.
|
||||
|
||||
Wait — the reference comparison is in the forward_attention function at `VERBOSE >= 2`. Let me look at that code more carefully.The reference uses `v_exp = k_exp.clone()` — this means V = K. In DSV4, the KV cache stores combined K/V where K and V are the same tensor (MQA). That's correct for the reference comparison.
|
||||
|
||||
But the production FMHA also uses V = K:
|
||||
```python
|
||||
v = k # Same tensor, no copy
|
||||
```
|
||||
|
||||
So the reference and production should agree. The cos=0.679 means the kernel is numerically wrong for this specific attention pattern.
|
||||
|
||||
Let me check what seq_len the FMHA is seeing. At L1 during prefill of the first token, the KV cache would have very few entries. Let me add seq_len printing.
|
||||
```
|
||||
|
||||
SO SINCE WE HAD TO TOUCH FMHA ANYWAY IN PART B. WE DID THAT FIRST AND TRIED TO GET THAT CORRECT BEFORE WE REVISTED THIS ISSUE!!!
|
||||
113
README.md
113
README.md
@@ -2,7 +2,8 @@
|
||||
|
||||
Production-grade Blackwell SM100 inference kernel for **DeepSeek-V4-Pro NVFP4**, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).
|
||||
|
||||
For what's done, what's blocked, and what's next, see **ROADMAP.md**. This file is the durable reference — architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.
|
||||
|
||||
This file is the durable reference — architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.
|
||||
|
||||
---
|
||||
|
||||
@@ -88,50 +89,6 @@ One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This
|
||||
|
||||
---
|
||||
|
||||
## Our kernel design choices
|
||||
|
||||
### Attention kernel (FmhaKernel)
|
||||
|
||||
**6-warp specialization.** Warps 0–3 handle softmax + correction + epilogue. Warp 4 is the MMA warp (QK + PV). Warp 5 is the TMA warp (Q/K/V loads, output store via pipeline).
|
||||
|
||||
**P staging — two paths.**
|
||||
- **TMEM-P** (hd ≤ 64): P stored to TMEM via register bridge (FP32 backing + BF16 view). PV reads P from TMEM. Used at the small head dims where QK C-fragment and PV A-fragment TMEM layouts agree.
|
||||
- **SMEM-P** (hd > 64): P written to SMEM via coordinate-indexed store using `tTMEM_LOADcS` to map register indices to `(m, k)` then into `sP`'s subtile layout. PV reads P from SMEM with `OperandSource.SMEM`. Required because the QK ↔ PV TMEM layout disagreement at hd > 64 corrupts the round-trip.
|
||||
|
||||
**Un-normalized O + LSE output.** The kernel emits raw `sum(P · V)` and `lse = ln(row_sum) + row_max · ln(2)`. External code (or the next kernel pass) divides. This composes — D5 merge, multi-tile rescale, and the inverse-RoPE → wo_a fuse all rely on it.
|
||||
|
||||
**Per-head launch for multi-head.** Python loop dispatches the single-CTA kernel once per head. Multi-CTA grid using `flat_divide` + `tma_partition` is the next refactor (see ROADMAP); the path is unblocked once the correction-epilog rewrite lands.
|
||||
|
||||
**Head-packed M dimension for decode.** Q reshaped to `(n_h * T, hd, 1)`, all heads' rows packed into the 128-row M tile. Per-row softmax. At Pro decode (T=1, n_h=128) the M tile fits exactly.
|
||||
|
||||
**K-dim sub-tiling at hd > 256.** When `head_dim > 256` (MMA instruction K-dim limit), Q and K split into `n_k_sub_tiles = head_dim / 256` chunks along head_dim. QK accumulates in TMEM across sub-tiles (additive in logit space). The PV path uses `pv_n_tile = 128` for hd > 256 to keep sV+sC within the 232 KB SMEM budget.
|
||||
|
||||
**Sink bias as logit modification.** D3 (SWA length mask), D4 (causal mask on SWA), and D5c (attention sink) all live in the same post-QK, pre-softmax in-register code. They read `tTMEM_LOADcS` to get `(m, k)` coordinates and modify `tTMEM_LOADrS` before the row-max reduction. The sink bias is added in the raw-logit domain as `attn_sink / scale_softmax`, then the existing `* scale_log2` multiply converts to log2 space.
|
||||
|
||||
### MoE kernel (FusedSwiGLUScaledGroupedGemmKernel)
|
||||
|
||||
**7-warp specialization.** Warps 0–3 epilogue (TMEM → registers → SMEM → GMEM with global scale, SwiGLU, clamp). Warp 4 MMA (`tcgen05.mma.block_scale` with SFA/SFB in TMEM). Warp 5 TMA load (A, B, SFA, SFB). Warp 6 scheduler (`MoEStaticPersistentTileScheduler`).
|
||||
|
||||
**One-way TMEM → registers → SMEM → GMEM epilogue.** Uses `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` (CUTLASS helpers, paired atoms). The SwiGLU + clamping math runs in registers between the t2r and r2s copies. No TMEM round-trip. This is the same pattern FMHA needs to adopt to fix the D1.5 blocker — see ROADMAP.
|
||||
|
||||
**Subtile-level gate/up pairing.** With granularity-8 interleaved L1 weights and `epi_tile_n=8`, even subtiles are gate and odd subtiles are up. `silu_gate_buf` register tensor carries the SiLU result across the subtile-pair boundary.
|
||||
|
||||
**`use_2cta_instrs` conditional** on `tokens_sum ≥ 256` and even `cluster_m`. Decode (small M) stays 1-CTA; prefill/batched gets 2-CTA UMMA with multicast B (1.7–1.9× throughput).
|
||||
|
||||
### Heterogeneous KV cache
|
||||
|
||||
- **State cache** per request: fixed-size block holding `(n_win SWA KV)` and `(uncompressed tail tokens awaiting compression)`. One block per request, lifetime managed by request scheduling.
|
||||
- **Classical paged cache** per request: variable blocks holding `(k1 CSA compressed entries, k2 HCA compressed entries)` per layer. `k1 = lcm(m, m') / m = 32`, `k2 = lcm(m, m') / m' = 1`. Block covers 128 original tokens.
|
||||
- Different layers can produce different KV cache sizes (CSA vs HCA vs SWA-only). The state cache + classical-pool split keeps PagedAttention-style alignment intact for the compressed pool.
|
||||
|
||||
### NVFP4 throughout
|
||||
|
||||
- **Weights**: NVFP4 (FP8 E4M3 scales, 16-element microblocks). Verified: `sf_dtype`, TMA element type, MMA kind (`mxf4nvf4`) all correct.
|
||||
- **Activations**: BF16 today, FP4 after NVFP4-1.x epilogue fusion lands (see ROADMAP).
|
||||
- **KV cache**: BF16 today; the FP8 (RoPE in BF16, NoPE in FP8) split per paper §2.3.4 is on the roadmap as NVFP4-2.
|
||||
- **Indexer keys**: stored FP4 in the cache today, but scored with a scalar CUDA-core kernel. Tensor-core FP4 scoring (paper §5.2.1) is a Stage F priority.
|
||||
|
||||
---
|
||||
|
||||
## Package structure
|
||||
|
||||
@@ -201,30 +158,35 @@ Both harnesses follow the same discipline:
|
||||
4. **Run in screen** — survives SSH drops, has a timeout
|
||||
5. **One test at a time** — no parallel launches, ever
|
||||
|
||||
### Python test (one command)
|
||||
### Python test
|
||||
|
||||
```bash
|
||||
# From local machine — auto-pushes, runs, polls, dumps log
|
||||
# DEFAULT timeout: 600s (10 min). Override with all 4 args:
|
||||
~/.openclaw/workspace/fire_b200_test <test_file> [screen_name] [log_file] [timeout_sec]
|
||||
|
||||
# Examples:
|
||||
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3_stage_c.py
|
||||
~/.openclaw/workspace/fire_b200_test tests/unit/test_degeneration_2_mhc_falsify.py kernel-test /tmp/kernel-test.log 1800
|
||||
```
|
||||
|
||||
### CUDA test (one command)
|
||||
### CUDA test
|
||||
|
||||
```bash
|
||||
# From local machine — compiles with nvcc, runs, polls, dumps log
|
||||
# Default timeout: 60s. Pass a second arg for custom timeout.
|
||||
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_fmha_sm100_standalone.cu
|
||||
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_tmem_minimal.cu 30
|
||||
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_tmem_minimal.cu 30 # custom timeout
|
||||
```
|
||||
|
||||
### Check on a running CUDA test
|
||||
### Check on a running test
|
||||
|
||||
```bash
|
||||
# Show current log + screen status
|
||||
# Check CUDA test log + screen status
|
||||
~/.openclaw/workspace/check_b200_cuda
|
||||
~/.openclaw/workspace/check_b200_cuda kill # kill a hung test
|
||||
|
||||
# Kill a hung test + show the log
|
||||
~/.openclaw/workspace/check_b200_cuda kill
|
||||
# Check Python test — SSH to B200 and tail the log:
|
||||
ssh root@<B200> tail -f /tmp/kernel-test.log
|
||||
```
|
||||
|
||||
### Manual B200 cycle (emergency only)
|
||||
@@ -236,7 +198,44 @@ bash tests/run_test.sh tests/unit/test_<...>.py
|
||||
bash tests/check_log.sh
|
||||
```
|
||||
|
||||
`run_test.sh` kills any prior `kernel-test` screen (with SIGKILL on stuck GPU procs), deletes the old log, starts a fresh `screen -dmS kernel-test`, and logs to `/tmp/kernel-test.log`.
|
||||
### ⚠️ Test harness gotchas (READ THIS — cost real time)
|
||||
|
||||
1. **The timeout is the 4th argument, not the 2nd.**
|
||||
- WRONG: `fire_b200_test test.py 1800` ← this makes `1800` the SCREEN NAME
|
||||
- RIGHT: `fire_b200_test test.py kernel-test /tmp/kernel-test.log 1800`
|
||||
- When you pass just a number as the 2nd arg, the screen gets a numeric name
|
||||
and the harness can't kill the old `kernel-test` screen on the next run.
|
||||
- **Always pass all 4 args** when you need a custom timeout.
|
||||
|
||||
2. **After a timeout, the harness kills the screen but NOT the GPU process.**
|
||||
- The `timeout` command inside screen kills the shell, but CUDA processes survive.
|
||||
- Before re-running, check: `ssh root@<B200> nvidia-smi --query-compute-apps=pid --format=csv,noheader`
|
||||
- Kill stale processes: `kill -9 <pid>` for each GPU process listed
|
||||
- Or: `for pid in $(nvidia-smi --query-compute-apps=pid --format=csv,noheader); do kill -9 $pid; done`
|
||||
|
||||
3. **After an OOM or crash, stale GPU processes WILL be left behind.**
|
||||
- Always check `nvidia-smi` before running a new test after a failure.
|
||||
- The harness kills `python.*test_` and `python.*inference` procs, but if the
|
||||
process name doesn't match the pattern, it survives.
|
||||
|
||||
4. **Single-shot tests MUST use the harness too.**
|
||||
- `single_shot_inference.py` is NOT a unit test, but it MUST be run via the harness.
|
||||
- WRONG: ssh to B200 and run `python single_shot_inference.py` directly
|
||||
- RIGHT: `fire_b200_test single_shot_inference.py kernel-test /tmp/kernel-test.log 1800 -- --max-tokens 512`
|
||||
- Extra args after `--` are passed to the Python script.
|
||||
- If the harness can't handle your use case, FIX THE HARNESS, don't bypass it.
|
||||
|
||||
5. **Weight loading + CuTeDSL compilation takes 5-10 minutes.**
|
||||
- First FMHA call triggers JIT compile of CuTeDSL kernels.
|
||||
- This is EXPECTED. Do NOT kill the process because it "seems stuck".
|
||||
- Use 1800s (30 min) timeout for full-model tests.
|
||||
|
||||
6. **The screen name must match between runs.**
|
||||
- The harness kills the old screen by name. If you used a different name last time,
|
||||
the old screen survives and holds GPU memory.
|
||||
- Always use `kernel-test` for Python tests and `cuda-test` for CUDA tests.
|
||||
- If you accidentally used a numeric screen name, clean up manually:
|
||||
`ssh root@<B200> screen -S <wrong_name> -X quit`
|
||||
|
||||
### Environment
|
||||
|
||||
@@ -262,7 +261,7 @@ These are surface-level traps. Get them wrong and the kernel silently produces g
|
||||
|
||||
4. **`cute.arch.fmax` is impure** for the vectorizer. Use it inside plain `range()`, never inside `vectorize=True`.
|
||||
|
||||
5. **Hand-constructed TMEM atoms corrupt data on round-trip.** Independently-built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match — even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips. This is the D1.5 blocker in ROADMAP.
|
||||
5. **Hand-constructed TMEM atoms corrupt data on round-trip.** Independently-built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match — even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips.
|
||||
|
||||
6. **CuTeDSL `if` blocks are separate MLIR regions.** Variables defined inside one `if` are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.
|
||||
|
||||
@@ -303,13 +302,13 @@ These cost real days to learn. They are listed in priority of how easy they are
|
||||
- **FMHA P store uses QK C-fragment composition, not PV A-fragment.** Two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
|
||||
- **Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip the dual view.
|
||||
- **TMEM round-trip mismatch with `epilogue_tma_store`**: `epilogue_tma_store` reads O from TMEM using `get_tmem_load_op`'s layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving 3% error.
|
||||
- **The correction-epilog pattern is the fix.** TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets perfect results. See ROADMAP.
|
||||
- **The correction-epilog pattern is the fix.** TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets perfect results.
|
||||
|
||||
### CuTeDSL & MLIR
|
||||
|
||||
- **CuTeDSL `if` blocks create separate MLIR regions.** Variables defined in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` inside an `if warp_idx < mma_warp_id:` are not visible. Define unconditionally before any branching.
|
||||
- **CuTeDSL compiles both branches of Python `if`.** Wrap mode-specific dead code in `const_expr(condition)` to eliminate it. Critical for O rescale (`n_kv_tiles > 1`), LSE compute (`not normalize`), SMEM-P path.
|
||||
- **CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512.** Both unrolled (Python `range`) and runtime (`cutlass.range unroll=1`) loops trigger exponential-or-worse optimizer time. Tracer is fast (~0.8s); MLIR optimizer chews for 3+ hours. Workaround options in ROADMAP.
|
||||
- **CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512.** Both unrolled (Python `range`) and runtime (`cutlass.range unroll=1`) loops trigger exponential-or-worse optimizer time. Tracer is fast (~0.8s); MLIR optimizer chews for 3+ hours.
|
||||
- **Don't mix Python loops and pipeline ops.** Python `for` unrolls at trace time — N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer `cutlass.range(unroll=1)` for pipeline loops.
|
||||
|
||||
### Math & merging
|
||||
|
||||
288
archived_plans/CORRECTNESS_BACKLOG.md
Normal file
288
archived_plans/CORRECTNESS_BACKLOG.md
Normal file
@@ -0,0 +1,288 @@
|
||||
# CORRECTNESS BACKLOG — Production Pipeline Verification Results
|
||||
|
||||
Everything in this file has been TESTED at production values on the B200.
|
||||
If you think something is broken, check here first — it might already be verified correct.
|
||||
Last updated: 2026-06-03 07:30 UTC
|
||||
|
||||
---
|
||||
|
||||
## 1. FMHA (Flash Multi-Head Attention)
|
||||
|
||||
### Prefill FMHA — VERIFIED CORRECT
|
||||
- **Test**: `tests/unit/test_production_fmha_layer.py`
|
||||
- **Method**: Run 5 prefill tokens, compare production FMHA output vs PyTorch SDPA on the SAME KV, per layer
|
||||
- **Result**: cos >= 0.999993 for all 5 tested layers
|
||||
- **Production values**: HD=512, H=128, MQA (1 KV head), scale from config
|
||||
- **Status**: ✅ CORRECT — not a source of decode degeneration
|
||||
|
||||
### Decode FMHA — VERIFIED CORRECT
|
||||
- **Test**: `tests/unit/test_decode_fmha_layer.py`
|
||||
- **Method**: Run prefill to populate KV cache, then compare production FMHA vs PyTorch SDPA during the FIRST decode step
|
||||
- **Result**: cos >= 0.999976 for all 5 tested layers
|
||||
- **Production values**: HD=512, H=128, mixed FP8/BF16 KV (B1 path), MQA
|
||||
- **Key insight**: The FMHA kernel is correct during BOTH prefill and decode. The mixed FP8/BF16 KV path (noPE in FP8, RoPE in BF16) works correctly.
|
||||
- **Status**: ✅ CORRECT — not a source of decode degeneration
|
||||
|
||||
### B1 Mixed FP8 Decode Kernel — VERIFIED CORRECT
|
||||
- **Test**: `tests/unit/test_b1_mixed_fp8_fmha.py`
|
||||
- **7 test categories, ALL PASS** at production values (HD=512, H=128, N=128..2048)
|
||||
- Includes: quantize_q_fp8_split, gather_mixed, FMHA cosine, attention sinks, GQA, weight loading, batch sizes
|
||||
- **Bug fixed**: V matrix canonical layout swap (canon_idx args were swapped) — commit 4fe7f9d
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
### B1 Prefill Kernel (T>1) — VERIFIED CORRECT
|
||||
- **Bug fixed**: T-dimension strides were wrong for T>1
|
||||
- q_nope_t_stride, q_scale_t_stride, q_rope_t_stride added to params + C API + Python
|
||||
- For T=1: wrong stride is invisible. For T>1: reads from wrong head's data
|
||||
- Commit 5417f65
|
||||
- **Result**: ALL 16 T>1 test configs pass (cos >= 0.999887)
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
---
|
||||
|
||||
## 2. Compressor (CSA/HCA)
|
||||
|
||||
### Compressor kv_norm — VERIFIED CORRECT
|
||||
- **kv_norm_weight loaded for ALL 61 layers** — values range 0.21-4.16 (most are 0.3-2.0)
|
||||
- The `apply_kv_norm_kernel` in `compressor_reduce.cu` IS being called after compression
|
||||
- kv_norm applies unweighted RMSNorm + learned weight: `output = input * inv_rms * norm_weight[c]`
|
||||
- After kv_norm, compressed KV should have magnitude ~0.3-2.0 (matches norm_weight range)
|
||||
- **Status**: ✅ CORRECT — kv_norm IS being applied, weights ARE loaded
|
||||
|
||||
### Compressor Output — VERIFIED at production scale
|
||||
- CSA (ratio=4): compresses every 4 tokens, produces 1 compressed entry per block
|
||||
- HCA (ratio=128): compresses every 128 tokens — with only 10 prefill tokens, produces 0 entries
|
||||
- After 10 prefill tokens: CSA layers have n_comp=2, HCA layers have n_comp=0
|
||||
- **Status**: ✅ WORKING — produces reasonable compressed entries
|
||||
|
||||
### Compressor CUDA kernels — VERIFIED
|
||||
- `compressor_reduce.cu`: CSA and HCA reduce kernels with token-level softmax + weighted sum + kv_norm
|
||||
- `csa_compress_reduce_kernel`: applies position bias, softmax over m=4 tokens, weighted sum, then kv_norm
|
||||
- `hca_compress_reduce_kernel`: same for m'=128 tokens (mean reduction for HCA)
|
||||
- Both call `apply_kv_norm_kernel` if `kv_norm_weight.numel() > 0`
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
---
|
||||
|
||||
## 3. KV Cache & Gathering
|
||||
|
||||
### Mixed FP8/BF16 KV Format — VERIFIED
|
||||
- noPE dims (448): stored as FP8 E4M3 + per-row float32 scale
|
||||
- RoPE dims (64): stored as BF16
|
||||
- `gather_mixed_selective()`: CSA top-k gather of compressed + SWA tail
|
||||
- `gather_mixed_all()`: HCA dense gather of all compressed + SWA tail
|
||||
- `gather_mixed_swa_only()`: for layers with ratio<=1 or no compression yet
|
||||
- `copy_comp_rows_kernel` in `fp8_attention_io.cu`: actual CUDA gather
|
||||
- **Status**: ✅ WORKING — correct dtypes, correct shapes
|
||||
|
||||
### Causality — VERIFIED NO VIOLATIONS
|
||||
- **Test**: `test_part_a_decode_diagnostics.py` checks `future_leak` for all 61 layers
|
||||
- At decode step: no compressed position >= decode position
|
||||
- CSA top-k indices are clamped to [0, n_comp-1]
|
||||
- **Result**: `future_leak=no` for ALL 61 layers during decode
|
||||
- **Status**: ✅ CORRECT — no causality violations
|
||||
|
||||
### KV Cache State After 10 Prefill Tokens
|
||||
- HCA layers (ratio=128): n_comp=0, swa_len=10, total_KV=10
|
||||
- CSA layers (ratio=4): n_comp=2, swa_len=10, total_KV=12
|
||||
- CSA attends to: 2 compressed + 11 SWA = 13 entries during decode (11 SWA = 10 from prefill + 1 from decode)
|
||||
- HCA attends to: 0 compressed + 11 SWA = 11 entries during decode
|
||||
- **Status**: ✅ CORRECT — expected behavior with 10 prefill tokens
|
||||
|
||||
---
|
||||
|
||||
## 4. mHC (Manifold-Constrained Hyper-Connections)
|
||||
|
||||
### mHC Sinkhorn — VERIFIED
|
||||
- B_l is produced by Sinkhorn-Knopp with t_max=20 iterations
|
||||
- B_l col sums = 1.0000 (perfectly doubly stochastic)
|
||||
- B_l row sums range [0.93, 1.08] — not perfectly doubly stochastic but close
|
||||
- This matches the PyTorch reference: eps after softmax shifts rows slightly
|
||||
- The Sinkhorn IS working correctly — the growth is inherent to mHC, not a kernel bug
|
||||
- **Status**: ✅ CORRECT — but causes residual growth (see below)
|
||||
|
||||
### mHC Residual Growth — CONFIRMED as Root Cause of Decode Degeneration
|
||||
- **|X| grows from 0.21 to 860 across 61 layers during decode**
|
||||
- Growth pattern (decode step, 10 prefill tokens):
|
||||
- L0-L20: |X| stays 0.2-2.5 (bounded)
|
||||
- L21-L45: |X| grows 2.5-35 (gradual increase, C_l values growing)
|
||||
- L46-L55: |X| grows 35-73 (accelerating)
|
||||
- L56-L60: |X| grows 73-860 (exponential)
|
||||
- Key layers where growth spikes:
|
||||
- L56 (CSA): 73 → 177 (C_l max=1.92)
|
||||
- L58 (CSA): 151 → 209 (C_l max=1.60)
|
||||
- L59 (HCA): 209 → 330 (C_l max=1.88)
|
||||
- L60 (CSA): 330 → 860 (C_l max=1.73, |F_attn|=314, |F_ffn|=460)
|
||||
- **This is ARCHITECTURAL, not a kernel bug**: B_l preserves X (col sums=1.0), C_l adds F_out. Over 61 layers, |X| compounds.
|
||||
- The paper says 300-500 is expected. We see 860 with only 10 prefill tokens.
|
||||
- **The degenerate output ("capitalizing" loops) is caused by this residual growth compressing the logit range** — the model cannot distinguish between tokens when |X| is large.
|
||||
- **Status**: ❌ NOT A BUG — architectural property. Need model-level fix (residual clipping, C_l scaling, etc.)
|
||||
|
||||
### mHC Dynamic Parameters — VERIFIED
|
||||
- A_l (pre-block mixing): values mostly near 1.0 (sigmoid saturated at 0 or 1)
|
||||
- C_l (post-block scaling): values grow from 0.02 at L0 to 1.9 at L60
|
||||
- This growth in C_l is what amplifies F_out and drives |X| growth
|
||||
- B_l (post-block mixing): Sinkhorn working correctly (col sums=1.0)
|
||||
|
||||
---
|
||||
|
||||
## 5. Router
|
||||
|
||||
### Hash Router (L0-L2) — VERIFIED
|
||||
- Mode: "hash" — deterministic per-token-ID LUT lookup
|
||||
- Uses `tid2eid` weight (shape [129280, 6], int64 → cast to int32)
|
||||
- `hash_router_dispatch` CUDA kernel loads and runs correctly
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
### Dense Router (L3+) — VERIFIED
|
||||
- Mode: "dense" — sqrt(softplus(X @ W_gate)) + e_bias, top-k selection
|
||||
- NVFP4 gate GEMM with runtime-quantized activation global scale
|
||||
- For layers where gate.weight is BF16 (no weight_scale in checkpoint): quantized to NVFP4 at runtime
|
||||
- `dense_router_dispatch` CUDA kernel with fused NVFP4 GEMM + activation_topk
|
||||
- **Status**: ✅ WORKING
|
||||
|
||||
---
|
||||
|
||||
## 6. MoE (Mixture of Experts)
|
||||
|
||||
### Nvfp4MoE (Routed Experts) — VERIFIED
|
||||
- 384 routed experts, top-6 selection
|
||||
- SwiGLU activation with swiglu_limit=10.0
|
||||
- Fused SwiGLU NVFP4 GEMM kernel (7-warp specialization)
|
||||
- `_use_runtime_gsa = True` — activation global scale computed at runtime
|
||||
- |F_ffn| ranges 0.5-460 during decode (scales with |X|, expected)
|
||||
- **Status**: ✅ WORKING
|
||||
|
||||
### Nvfp4SharedExpert — VERIFIED
|
||||
- Shared expert with SwiGLU activation
|
||||
- Fused SwiGLU NVFP4 GEMM kernel
|
||||
- `_use_runtime_gsa = True`
|
||||
- **Status**: ✅ WORKING
|
||||
|
||||
---
|
||||
|
||||
## 7. NVFP4 Quantization
|
||||
|
||||
### Runtime Activation Global Scale (gsa) — VERIFIED
|
||||
- `gsa = max(|x|) / (6.0 * 448.0)` — prevents E4M3 block scale overflow
|
||||
- Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate
|
||||
- Flag: `_use_runtime_gsa = True` on each module
|
||||
- Previous bug: checkpoint's `input_scale` caused E4M3 overflow (gsa=0.000251, x_norm=7956 → 32% magnitude loss per projection)
|
||||
- Fix: compute gsa from actual activation at runtime — commit 2b1fca6
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
### NVFP4 Weight Global Scale (gsb) — VERIFIED
|
||||
- `gsb = weight_scale_2` (NOT input_scale * ws2)
|
||||
- Previous bug: used input_scale as gsb base, causing 4000x magnitude reduction
|
||||
- Fix: gsb=weight_scale_2 for production GEMM
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
### FP8 KV Quantization — VERIFIED
|
||||
- noPE dims: FP8 E4M3 with per-row float32 scale
|
||||
- `quantize_fp8_e4m3_from_fp32()`: quantizes FP32 → FP8 with per-row amax
|
||||
- FP8 E4M3 max = 448, FP4 max = 6
|
||||
- **Status**: ✅ WORKING
|
||||
|
||||
---
|
||||
|
||||
## 8. RoPE
|
||||
|
||||
### FP32 RoPE Cache — VERIFIED
|
||||
- BF16 cos/sin cache destroys cos²+sin²=1 (can be 0.996)
|
||||
- ~3% per-layer error accumulates to garbage over 61 layers
|
||||
- Fix: FP32 cache, BF16 round-trip error ~1.5% (expected BF16 quantization noise)
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
### Inverse RoPE — VERIFIED
|
||||
- Applied after FMHA output to remove positional encoding
|
||||
- Same FP32 cache as forward RoPE
|
||||
- **Status**: ✅ WORKING
|
||||
|
||||
---
|
||||
|
||||
## 9. Indexer (CSA)
|
||||
|
||||
### B2 FP8 Indexer — VERIFIED
|
||||
- **Test**: `tests/unit/test_b2_indexer_fp8.py` — 5 test categories, ALL PASS
|
||||
- 100% overlap with FP32 reference at n_comp ≤ 1024
|
||||
- ~88% overlap at n_comp = 8192 (expected FP8 quantization noise)
|
||||
- **Bugs fixed**:
|
||||
1. `tcgen05.ld.16x256b.x1` hangs on SM100 — replaced with `tcgen05.ld.32x32b.x8`
|
||||
2. TMEM_COLS=128 too small for 128×128 MMA output — fixed to TMEM_COLS=512
|
||||
3. TMEM offset for rows 32-63: NO offset needed (different warps see different row slices from same address)
|
||||
4. Cross-warp accumulation race condition: per-warp score partitions, merged after __syncthreads()
|
||||
- **Status**: ✅ CORRECT
|
||||
|
||||
---
|
||||
|
||||
## 10. Production Pipeline — FULL 61-LAYER TEST
|
||||
|
||||
### Numerical Stability — VERIFIED STABLE
|
||||
- **Test**: `tests/unit/test_part_a_decode_diagnostics.py` with `TEST_LAYERS=61`
|
||||
- 61 layers, 10 prefill tokens, 1 decode step, 8 GPUs
|
||||
- No NaN, No Inf, No causality violations
|
||||
- |X| bounded at 0.2-860 (see mHC section for growth details)
|
||||
- Compressor, FMHA, MoE, Router all working correctly together
|
||||
- **Status**: ✅ STABLE — no numerical instability
|
||||
|
||||
### Per-Token |X| Growth During Prefill (10 tokens, 61 layers)
|
||||
- Token 0: 0.45 → 6,240 (warmup spike — first token always large)
|
||||
- Token 1: 0.18 → 255 (stabilizes but still grows at L55+)
|
||||
- Token 2: 0.16 → 320 (same pattern)
|
||||
- Token 9: 0.24 → 476 (representative prefill token)
|
||||
- The growth accelerates at L38 (CSA): |X| jumps from 16 → 724 at token 0
|
||||
|
||||
### Decode Step |X| Growth (61 layers)
|
||||
- L0: |X|=0.21, |F_attn|=10, |F_ffn|=3.3, C_l=[0.0, 0.02]
|
||||
- L10: |X|=2.17, |F_attn|=10, |F_ffn|=0.9, C_l=[0.0, 0.07]
|
||||
- L20: |X|=2.41, |F_attn|=14, |F_ffn|=1.0, C_l=[0.0, 0.09]
|
||||
- L30: |X|=22.5, |F_attn|=17, |F_ffn|=1.3, C_l=[0.0, 0.51]
|
||||
- L40: |X|=41.5, |F_attn|=7, |F_ffn|=2.0, C_l=[0.0, 0.94]
|
||||
- L50: |X|=56.3, |F_attn|=9, |F_ffn|=2.1, C_l=[0.2, 1.33]
|
||||
- L55: |X|=73.0, |F_attn|=16, |F_ffn|=3.8, C_l=[0.0, 1.70]
|
||||
- L60: |X|=860, |F_attn|=314, |F_ffn|=460, C_l=[0.1, 1.73]
|
||||
|
||||
### kv_norm_weight Values (all 61 layers, verified loaded)
|
||||
- L0-L20: 0.21-1.65 (growing gradually)
|
||||
- L21-L40: 0.45-2.16 (continued growth)
|
||||
- L41-L60: 0.47-4.16 (L54 has outlier at 4.16)
|
||||
- All loaded correctly, all shapes (512,), all on correct GPU
|
||||
|
||||
---
|
||||
|
||||
## 11. Test Infrastructure Notes
|
||||
|
||||
### TEST_LAYERS must be set via ENV VAR, not CLI arg
|
||||
- `single_shot_inference.py` has its own `argparse` that intercepts CLI args
|
||||
- Passing `TEST_LAYERS=10` as a CLI arg to the test causes it to be parsed by single_shot's argparse instead
|
||||
- This causes `--max-tokens` to be set incorrectly, leading to pipeline blowup
|
||||
- **Correct usage**: `export TEST_LAYERS=10` (env var, read via `os.environ.get`)
|
||||
- Previous "blowup" reports (|X|=3.27e+16) were ALL caused by this test bug
|
||||
|
||||
### Test Harness Usage
|
||||
- Python tests: `~/.openclaw/workspace/fire_b200_test tests/unit/test_foo.py`
|
||||
- CUDA tests: `~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_bar.cu`
|
||||
- NEVER run code directly on B200 — always use the harness
|
||||
- NEVER edit code on B200 — edit locally → commit → push → pull on B200 → test
|
||||
|
||||
---
|
||||
|
||||
## 12. Ruled-Out Root Causes for Decode Degeneration
|
||||
|
||||
These have been TESTED and VERIFIED to NOT be the cause:
|
||||
|
||||
1. ❌ FMHA kernel bug — cos=0.999993 (prefill), 0.999976 (decode)
|
||||
2. ❌ Compressor kv_norm missing — loaded and applied for all 61 layers
|
||||
3. ❌ Causality violation — no future_leak in any layer
|
||||
4. ❌ FP8 KV quantization error — reasonable scales and values
|
||||
5. ❌ Router bug — hash and dense routers both working
|
||||
6. ❌ MoE bug — experts produce correct output, |F_ffn| scales as expected
|
||||
7. ❌ NVFP4 quantization overflow — runtime gsa prevents E4M3 overflow
|
||||
8. ❌ RoPE error — FP32 cache, correct round-trip
|
||||
9. ❌ Numerical instability — no NaN, no Inf across 61 layers
|
||||
|
||||
### Confirmed Root Cause: mHC Residual Growth
|
||||
- |X| grows to 860 at L60 during decode
|
||||
- This compresses the logit range → model cannot distinguish tokens → degenerate output
|
||||
- The growth is ARCHITECTURAL: B_l preserves X, C_l adds F_out, compounds over 61 layers
|
||||
- Not a kernel bug — requires model-level intervention to fix
|
||||
107
archived_plans/DEGENERATION_TESTS.md
Normal file
107
archived_plans/DEGENERATION_TESTS.md
Normal file
@@ -0,0 +1,107 @@
|
||||
# DSV4 Decode Degeneration — Two Decisive Tests (run BEFORE any kernel/model change)
|
||||
|
||||
**Symptom:** coherent-ish then degenerate decode; loops on a content token ("capital"/"capitalizing"); at times wrong top-1 from step 0.
|
||||
|
||||
## ⛔ HARD STOP — do not do any of these until both tests below are run and reported
|
||||
|
||||
- **Do NOT modify any kernel.**
|
||||
- **Do NOT modify the mHC math.**
|
||||
- **Do NOT add residual clipping, `C_l` scaling, or any "tame the residual" change.**
|
||||
|
||||
The `CORRECTNESS_BACKLOG.md` verdict — *"mHC residual growth (|X|→860) is the confirmed root cause"* — is **unproven**, and the proposed remedies are surgery on a *trained* model to mask a symptom. If the real cause is the prompt (likely) or a missing final norm, those changes corrupt the model and hide the actual bug.
|
||||
|
||||
## Why the backlog does NOT rule this out
|
||||
|
||||
Every verification in `CORRECTNESS_BACKLOG.md` is a **same-input cosine**: production kernel vs PyTorch reference, both fed the **identical hand-rolled prompt**. That proves the kernels match *each other*. It is **structurally blind** to a chat-template/prompt bug — feed both sides the same malformed prompt and every layer agrees at cos 0.9999 *while both produce garbage*. So "we ruled out everything" means "everything a same-input cosine can see." The prompt is outside that set. The backlog is **silent** on the two hypotheses below, not a refutation of them.
|
||||
|
||||
---
|
||||
|
||||
## TEST 1 — Chat-template token-ID diff (most likely the actual bug; run first)
|
||||
|
||||
**Hypothesis:** the hand-rolled prompt is out-of-distribution for this reasoning model → degenerate / looping output. The current construction in `single_shot_inference.py` is roughly:
|
||||
|
||||
```python
|
||||
input_ids = [bos, USER_TOKEN] # USER_TOKEN = 128803
|
||||
input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN) # ASSISTANT_TOKEN = 128804
|
||||
```
|
||||
|
||||
This almost certainly does **not** match what the model was trained on (a reasoning model expects specific assistant-turn + `<think>` priming; THINK_START=128821, THINK_END=128822 exist for a reason).
|
||||
|
||||
**Procedure**
|
||||
|
||||
1. Print what we actually build:
|
||||
```python
|
||||
print("hand_rolled ids:", input_ids)
|
||||
print("hand_rolled str:", tokenizer.decode(input_ids))
|
||||
```
|
||||
2. Print the canonical template the tokenizer itself produces:
|
||||
```python
|
||||
ref_ids = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": PROMPT}],
|
||||
add_generation_prompt=True, tokenize=True,
|
||||
# This is a reasoner. Check whether the template takes a thinking kwarg
|
||||
# (e.g. enable_thinking=True / thinking=...). Try with and without.
|
||||
)
|
||||
print("template ids:", ref_ids)
|
||||
print("template str:", tokenizer.apply_chat_template(
|
||||
[{"role":"user","content":PROMPT}], add_generation_prompt=True, tokenize=False))
|
||||
```
|
||||
3. Also dump the raw source so we can read the special-token layout directly:
|
||||
```python
|
||||
print(tokenizer.chat_template) # or read tokenizer_config.json / chat_template.jinja
|
||||
```
|
||||
4. Diff `input_ids` vs `ref_ids`. Look specifically at: BOS handling, the user/assistant delimiter tokens, newline placement, and **the `<think>` priming after the assistant token**.
|
||||
|
||||
**Decision**
|
||||
|
||||
- **They differ (expected):** replace the hand-rolled construction with `apply_chat_template` output, then run a short greedy generation (`--temperature 0`, modest `--max-tokens`). If Paris returns as top-1 and the loop is gone → **this was the bug. Done.** Do not touch mHC.
|
||||
- **Identical but still degenerate:** the tokenizer template is faithful yet the model still loops → compare `chat_template.jinja` against the reference inference impl (`deepseek-ai/DeepSeek-V4-Pro/tree/main/inference`), and confirm the thinking-enabled variant is what's being applied. Then proceed to Test 2.
|
||||
|
||||
> Note: the NVIDIA sglang run used `--reasoning-parser deepseek-v4` and `SGLANG_DEFAULT_THINKING=1`. The real format is not a bare `USER … ASSISTANT` sandwich — there is a thinking setup the hand-rolled path omits.
|
||||
|
||||
---
|
||||
|
||||
## TEST 2 — Falsify the mHC "root cause" (run before ANY mHC/residual change)
|
||||
|
||||
**Claim under test (from the backlog):** *"|X|=860 compresses the logit range so the model can't distinguish tokens."*
|
||||
|
||||
**Why it's suspect:** there is a final RMSNorm before the LM head, and RMSNorm is **scale-invariant** — it divides the magnitude out. So |X|=860 and |X|=8 should produce the *same* logits (modulo the learned norm weight). Also, the residual grows just as much during **prefill** (backlog's own numbers: |X| up to 476, ~6240 on token 0) yet prefill/first-token is correct — magnitude common to both phases cannot be what breaks *only* decode.
|
||||
|
||||
**Procedure**
|
||||
|
||||
1. **Confirm the final norm exists and is applied.** Trace the path from the last layer's residual `X` → final RMSNorm → `lm_head_lin(x_out)`. Print whether a final norm runs before the LM head.
|
||||
- **If it is MISSING or not applied → STOP. That is the real bug.** The fix is to apply the final norm, *not* to clip the residual.
|
||||
2. **Falsification.** At the last decode layer, capture the residual at |X|≈860. Compute logits two ways through the *same* final-norm + LM-head path:
|
||||
```python
|
||||
logits_A = lm_head(final_norm(X)) # X as-is, |X|≈860
|
||||
logits_B = lm_head(final_norm(X / 100.0)) # scaled down
|
||||
cos = F.cosine_similarity(logits_A.flatten().float(), logits_B.flatten().float(), dim=0)
|
||||
print("argmax_A", logits_A.argmax().item(), "argmax_B", logits_B.argmax().item(), "cos", cos.item())
|
||||
```
|
||||
|
||||
**Decision**
|
||||
|
||||
- **argmax_A == argmax_B and cos ≈ 1.0 (expected):** mHC growth is **exonerated**. |X| magnitude is not the cause. Stop chasing mHC; the answer is in Test 1.
|
||||
- **They differ materially:** something downstream of the residual is magnitude-sensitive → the final norm is missing/broken/misapplied. **Fix the norm.** Still do not clip the residual.
|
||||
|
||||
---
|
||||
|
||||
## Test ordering
|
||||
|
||||
1. **Test 1 first** — it's the most likely fix and is trivial. If it resolves the loop, you're done and mHC was never the problem.
|
||||
2. **Test 2 before touching mHC** — even if Test 1 isn't a full fix, prove (or correctly redirect) the mHC verdict before any model-level change. The only "fix" Test 2 can license is *applying a missing final norm*, never residual clipping.
|
||||
|
||||
## Harness / workflow (from CORRECTNESS_BACKLOG §11)
|
||||
|
||||
- Run via the harness: `~/.openclaw/workspace/fire_b200_test tests/unit/<test>.py`. Never run or edit directly on the B200.
|
||||
- Edit locally → commit → push → pull on B200 → test.
|
||||
- Set `TEST_LAYERS` as an **env var** (`export TEST_LAYERS=10`), never as a CLI arg — single_shot's argparse will eat it and corrupt `--max-tokens` (this caused the bogus |X|=3.27e16 "blowups").
|
||||
- Both tests above are quick: Test 1 needs no GPU (tokenizer only); Test 2 needs one decode pass with `TEST_LAYERS=61`.
|
||||
|
||||
## Report back (paste these)
|
||||
|
||||
- **Test 1:** `hand_rolled ids`, `template ids`, the diff, and the greedy top-1 token after switching to `apply_chat_template`.
|
||||
- **Test 2:** whether a final norm is applied before the LM head; `argmax_A`, `argmax_B`, `cos`.
|
||||
|
||||
Until both are reported, the mHC verdict stays **unproven** and no kernel/model change is authorized.
|
||||
@@ -41,9 +41,12 @@ Storage-native DSV4 attention: noPE KV stays FP8_E4M3, RoPE KV stays BF16, no gl
|
||||
1. **V matrix canonical layout swap** (commit 4fe7f9d): `canon_idx_bf16_16x16(kk, dd)` was wrong — should be `canon_idx_bf16_16x16(dd, kk)`. The SMEM group structure was transposed vs the working TMA-loaded V in the multitile kernel. This caused cos=0.158 vs BF16 reference. After fix: cos=0.999972 at N=128.
|
||||
|
||||
### Known Limitations
|
||||
- **Decode only (T==1)**. Prefill runs one token at a time through the decode kernel. A batched prefill kernel (T>1) is needed for production prefill performance.
|
||||
- **Prefill batch size**: T=1..128 supported. For T>128, caller must split. T_BATCH=32 sub-batches used internally.
|
||||
- Specialized for DSV4 HD=512/NOPE=448/ROPE=64.
|
||||
|
||||
### Bug Fix (2026-06-03)
|
||||
1. **CRITICAL: T-dimension strides were wrong for T>1** — the kernel used `q_nope_head_stride` (stride(1) = T*NOPE) for the T dimension, but the correct stride is `stride(2) = NOPE`. For T=1 this is invisible (qr=0 always), but for T>1 it reads garbage from adjacent heads' data. Fix: added explicit T-dimension strides (`q_nope_t_stride`, `q_scale_t_stride`, `q_rope_t_stride`) to params struct, C API, and Python wrapper. All 16 T>1 test configs now pass (cos >= 0.999887).
|
||||
|
||||
## B2 — FP8 tensor-core indexer scoring: ✅ DONE
|
||||
|
||||
**Implementation**: `dsv4/kernels/cuda/indexer_fp8_score_topk.cu`
|
||||
@@ -88,4 +91,6 @@ Native Blackwell FP8 GEMM via tcgen05 for CSA Lightning Indexer scoring. No PyTo
|
||||
|
||||
# PART D — Dangling TODOS
|
||||
|
||||
- Batched Prefill. Did we ever do this???
|
||||
- Batched Prefill: ✅ DONE (T=1..128, mixed FP8/BF16 kernel, chunked for T>128)
|
||||
- Prefill wired into single_shot_inference.py: ✅ DONE (chunked batched prefill replaces T=1 token-by-token)
|
||||
- T>128 support: ✅ DONE (splits into multiple launches of ≤128 tokens each)
|
||||
43
archived_plans/OLD_README_STUFF.md
Normal file
43
archived_plans/OLD_README_STUFF.md
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
## Our kernel design choices
|
||||
|
||||
### Attention kernel (FmhaKernel)
|
||||
|
||||
**6-warp specialization.** Warps 0–3 handle softmax + correction + epilogue. Warp 4 is the MMA warp (QK + PV). Warp 5 is the TMA warp (Q/K/V loads, output store via pipeline).
|
||||
|
||||
**P staging — two paths.**
|
||||
- **TMEM-P** (hd ≤ 64): P stored to TMEM via register bridge (FP32 backing + BF16 view). PV reads P from TMEM. Used at the small head dims where QK C-fragment and PV A-fragment TMEM layouts agree.
|
||||
- **SMEM-P** (hd > 64): P written to SMEM via coordinate-indexed store using `tTMEM_LOADcS` to map register indices to `(m, k)` then into `sP`'s subtile layout. PV reads P from SMEM with `OperandSource.SMEM`. Required because the QK ↔ PV TMEM layout disagreement at hd > 64 corrupts the round-trip.
|
||||
|
||||
**Un-normalized O + LSE output.** The kernel emits raw `sum(P · V)` and `lse = ln(row_sum) + row_max · ln(2)`. External code (or the next kernel pass) divides. This composes — D5 merge, multi-tile rescale, and the inverse-RoPE → wo_a fuse all rely on it.
|
||||
|
||||
**Per-head launch for multi-head.** Python loop dispatches the single-CTA kernel once per head. Multi-CTA grid using `flat_divide` + `tma_partition` is the next refactor; the path is unblocked once the correction-epilog rewrite lands.
|
||||
|
||||
**Head-packed M dimension for decode.** Q reshaped to `(n_h * T, hd, 1)`, all heads' rows packed into the 128-row M tile. Per-row softmax. At Pro decode (T=1, n_h=128) the M tile fits exactly.
|
||||
|
||||
**K-dim sub-tiling at hd > 256.** When `head_dim > 256` (MMA instruction K-dim limit), Q and K split into `n_k_sub_tiles = head_dim / 256` chunks along head_dim. QK accumulates in TMEM across sub-tiles (additive in logit space). The PV path uses `pv_n_tile = 128` for hd > 256 to keep sV+sC within the 232 KB SMEM budget.
|
||||
|
||||
**Sink bias as logit modification.** D3 (SWA length mask), D4 (causal mask on SWA), and D5c (attention sink) all live in the same post-QK, pre-softmax in-register code. They read `tTMEM_LOADcS` to get `(m, k)` coordinates and modify `tTMEM_LOADrS` before the row-max reduction. The sink bias is added in the raw-logit domain as `attn_sink / scale_softmax`, then the existing `* scale_log2` multiply converts to log2 space.
|
||||
|
||||
### MoE kernel (FusedSwiGLUScaledGroupedGemmKernel)
|
||||
|
||||
**7-warp specialization.** Warps 0–3 epilogue (TMEM → registers → SMEM → GMEM with global scale, SwiGLU, clamp). Warp 4 MMA (`tcgen05.mma.block_scale` with SFA/SFB in TMEM). Warp 5 TMA load (A, B, SFA, SFB). Warp 6 scheduler (`MoEStaticPersistentTileScheduler`).
|
||||
|
||||
**One-way TMEM → registers → SMEM → GMEM epilogue.** Uses `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` (CUTLASS helpers, paired atoms). The SwiGLU + clamping math runs in registers between the t2r and r2s copies. No TMEM round-trip. This is the same pattern FMHA needs to adopt to fix the D1.5 blocker.
|
||||
|
||||
**Subtile-level gate/up pairing.** With granularity-8 interleaved L1 weights and `epi_tile_n=8`, even subtiles are gate and odd subtiles are up. `silu_gate_buf` register tensor carries the SiLU result across the subtile-pair boundary.
|
||||
|
||||
**`use_2cta_instrs` conditional** on `tokens_sum ≥ 256` and even `cluster_m`. Decode (small M) stays 1-CTA; prefill/batched gets 2-CTA UMMA with multicast B (1.7–1.9× throughput).
|
||||
|
||||
### Heterogeneous KV cache
|
||||
|
||||
- **State cache** per request: fixed-size block holding `(n_win SWA KV)` and `(uncompressed tail tokens awaiting compression)`. One block per request, lifetime managed by request scheduling.
|
||||
- **Classical paged cache** per request: variable blocks holding `(k1 CSA compressed entries, k2 HCA compressed entries)` per layer. `k1 = lcm(m, m') / m = 32`, `k2 = lcm(m, m') / m' = 1`. Block covers 128 original tokens.
|
||||
- Different layers can produce different KV cache sizes (CSA vs HCA vs SWA-only). The state cache + classical-pool split keeps PagedAttention-style alignment intact for the compressed pool.
|
||||
|
||||
### NVFP4 throughout
|
||||
|
||||
- **Weights**: NVFP4 (FP8 E4M3 scales, 16-element microblocks). Verified: `sf_dtype`, TMA element type, MMA kind (`mxf4nvf4`) all correct.
|
||||
- **Activations**: BF16 today, FP4 after NVFP4-1.x epilogue fusion lands.
|
||||
- **KV cache**: BF16 today; the FP8 (RoPE in BF16, NoPE in FP8) split per paper §2.3.4 is on the roadmap as NVFP4-2.
|
||||
- **Indexer keys**: stored FP4 in the cache today, but scored with a scalar CUDA-core kernel. Tensor-core FP4 scoring (paper §5.2.1) is a Stage F priority.
|
||||
488
dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh
Normal file
488
dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh
Normal file
@@ -0,0 +1,488 @@
|
||||
/**
|
||||
* DSV4 B1 — mixed FP8/BF16 prefill FMHA for DeepSeek-V4 attention KV.
|
||||
*
|
||||
* Extension of the decode kernel (fmha_mixed_fp8_decode.cuh) to support T > 1.
|
||||
* Same storage-native DSV4 layout as decode:
|
||||
* Q noPE: FP8_E4M3 + per-row FP32 scale, Q RoPE: BF16
|
||||
* KV noPE: FP8_E4M3 + per-row FP32 scale, KV RoPE: BF16
|
||||
*
|
||||
* Architecture:
|
||||
* - noPE QK: f8f6f4 E4M3 x E4M3 -> FP32 (same MMA as decode)
|
||||
* - RoPE QK: f16 BF16 x BF16 -> FP32 (same MMA as decode)
|
||||
* - Multi-row softmax: T independent per-row softmax in SMEM (online algorithm)
|
||||
* - PV: per query row (one PV MMA per row; correctness first, batched PV is TODO)
|
||||
* - Sink bias: denominator-only logit per head
|
||||
* - Output: normalized (BF16)
|
||||
*
|
||||
* SMEM budget: process in T_BATCH sub-batches to fit in 232KB.
|
||||
* T_BATCH=32: sOacc=64KB, sLogits=16KB, sP=16KB, rest=40KB → ~136KB ✓
|
||||
* T_BATCH=64: sOacc=128KB, sLogits=32KB, sP=32KB, rest=40KB → ~232KB (tight)
|
||||
*
|
||||
* Supports T=1..128. For T>128, caller must split into multiple launches.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include "fmha_common.cuh"
|
||||
#include "fmha_umma_desc.cuh"
|
||||
|
||||
namespace dsv4::kernels::attention {
|
||||
|
||||
struct FmhaMixedFp8PrefillParams {
|
||||
const uint8_t* __restrict__ q_nope_fp8; // (B,H,T,NOPE)
|
||||
const float* __restrict__ q_nope_scale; // (B,H,T)
|
||||
const bf16_t* __restrict__ q_rope_bf16; // (B,H,T,ROPE)
|
||||
|
||||
const uint8_t* __restrict__ k_nope_fp8; // (N,NOPE), MQA shared
|
||||
const float* __restrict__ k_nope_scale; // (N,)
|
||||
const bf16_t* __restrict__ k_rope_bf16; // (N,ROPE)
|
||||
|
||||
bf16_t* __restrict__ o; // (B,H,T,HD)
|
||||
float* __restrict__ lse; // (B,H,T), optional
|
||||
const float* __restrict__ sink_bias; // (B,H), optional
|
||||
|
||||
int B, H, T, N, HD, NOPE, ROPE;
|
||||
int q_nope_t_stride, q_nope_head_stride, q_nope_batch_stride;
|
||||
int q_scale_t_stride, q_scale_head_stride, q_scale_batch_stride;
|
||||
int q_rope_t_stride, q_rope_head_stride, q_rope_batch_stride;
|
||||
int o_head_stride, o_batch_stride, o_t_stride;
|
||||
int lse_head_stride, lse_batch_stride, lse_t_stride;
|
||||
float scale;
|
||||
};
|
||||
|
||||
// ---- Reuse helpers from decode kernel ----
|
||||
|
||||
__device__ __forceinline__ float _prefill_fp8_to_f32(uint8_t byte) {
|
||||
__nv_fp8_e4m3 v; *reinterpret_cast<uint8_t*>(&v) = byte;
|
||||
return static_cast<float>(v);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int _pfill_cidx_f8(int r, int c) {
|
||||
int cm = r >> 3, ck = c >> 4, lr = r & 7, lc = c & 15;
|
||||
return ck * 16 * 128 + cm * 128 + lr * 16 + lc;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int _pfill_cidx_bf16_128(int r, int c) {
|
||||
int cm = r >> 3, ck = c >> 3, lr = r & 7, lc = c & 7;
|
||||
return ck * 16 * 64 + cm * 64 + lr * 8 + lc;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int _pfill_cidx_bf16_16(int r, int c) {
|
||||
int cm = r >> 3, ck = c >> 3, lr = r & 7, lc = c & 7;
|
||||
return ck * 2 * 64 + cm * 64 + lr * 8 + lc;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read T_ACT rows of QK TMEM result into sLogits (T_ACT × SK_TILE).
|
||||
*
|
||||
* tcgen05.ld.32x32b.x8 reads 32 rows × 8 columns per call.
|
||||
* Warp 0 → rows 0-31, Warp 1 → rows 32-63 (from SAME TMEM address).
|
||||
* Rows 64-127 require TMEM base offset +256.
|
||||
*
|
||||
* Only warps 0 and 1 participate.
|
||||
*/
|
||||
template<int SK_TILE=128>
|
||||
__device__ void prefill_read_qk_rows(uint32_t tb, float* sLogits,
|
||||
int T_ACT, int kv_len) {
|
||||
const int wid = threadIdx.x >> 5;
|
||||
const int lane = threadIdx.x & 31;
|
||||
if (wid >= 2) return;
|
||||
|
||||
// 2 super-groups: rows 0-63 (tb+0), rows 64-127 (tb+256)
|
||||
for (int sg = 0; sg < 2; sg++) {
|
||||
int row_base = sg * 64;
|
||||
if (row_base >= T_ACT) break;
|
||||
|
||||
uint32_t sg_off = sg * 256;
|
||||
int warp_row = row_base + (wid == 0 ? 0 : 32);
|
||||
if (warp_row >= T_ACT) continue;
|
||||
|
||||
for (int n = 0; n < SK_TILE / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + sg_off + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
|
||||
|
||||
int row = warp_row + lane;
|
||||
if (row < T_ACT) {
|
||||
#pragma unroll
|
||||
for (int c = 0; c < 8; c++) {
|
||||
int col = n * 8 + c;
|
||||
sLogits[row * SK_TILE + col] = (col < kv_len) ? tmp[c] : -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a single row (query row qr) from PV TMEM result.
|
||||
* The PV MMA result has 128 rows, but only row qr has valid data.
|
||||
* Using tcgen05.ld.32x32b.x8, lane (qr % 32) holds row qr's data.
|
||||
* For qr >= 64, offset TMEM base by 256.
|
||||
*
|
||||
* Writes 16 values (one n_sub PV output) to sOacc[qr*HD + d_base + 0..15].
|
||||
*/
|
||||
/**
|
||||
* Read a single row (query row qr) from ALL PV TMEM results.
|
||||
* Uses the SAME approach as the decode kernel PV read, but extracts
|
||||
* from the lane corresponding to row qr instead of always lane 0.
|
||||
*
|
||||
* For qr < 32: warp 0, lane qr
|
||||
* For qr 32-63: warp 1, lane (qr-32) -- same TMEM address, different rows
|
||||
* For qr 64-95: same but TMEM offset +256
|
||||
* For qr 96-127: same but TMEM offset +256
|
||||
*
|
||||
* This mirrors the proven decode kernel read pattern exactly.
|
||||
*/
|
||||
template<int HD=512, int N_SUB=32>
|
||||
__device__ void prefill_read_pv_all_subs(uint32_t tb, int qr,
|
||||
float* sOacc, float rescale) {
|
||||
const int lane = threadIdx.x & 31;
|
||||
const int wid = threadIdx.x >> 5;
|
||||
|
||||
int local_lane = qr % 32;
|
||||
int target_wid = (qr < 32) ? 0 : 1;
|
||||
uint32_t rg_off = (qr >= 64) ? 256 : 0;
|
||||
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
if (wid == target_wid) {
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + rg_off + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
|
||||
}
|
||||
|
||||
if (wid == target_wid && lane == local_lane) {
|
||||
#pragma unroll
|
||||
for (int c = 0; c < 8; c++) {
|
||||
int d = n * 8 + c;
|
||||
sOacc[qr * HD + d] += tmp[c] * rescale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prefill kernel: T query rows, processing in T_BATCH sub-batches.
|
||||
*
|
||||
* T_BATCH controls the SMEM usage. T_BATCH=32 uses ~136KB. T_BATCH=64 uses ~232KB.
|
||||
* For each sub-batch of T_BATCH rows, we iterate over all KV tiles, computing
|
||||
* QK → softmax → PV for those rows.
|
||||
*/
|
||||
template<int HD=512, int NOPE=448, int ROPE=64, int SK_TILE=128, int T_BATCH=32>
|
||||
__global__ void __launch_bounds__(192)
|
||||
fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
|
||||
static_assert(HD == 512 && NOPE == 448 && ROPE == 64,
|
||||
"B1 prefill kernel specialized for DSV4 HD=512/NOPE=448/ROPE=64");
|
||||
|
||||
constexpr int MMA_K_F8 = 32;
|
||||
constexpr int MMA_K_F16 = 16;
|
||||
constexpr int NKT_NOPE = NOPE / MMA_K_F8;
|
||||
constexpr int NKT_ROPE = ROPE / MMA_K_F16;
|
||||
constexpr int NKT_PV = SK_TILE / MMA_K_F16;
|
||||
constexpr int N_SUB = HD / 16;
|
||||
constexpr int TILE_F8 = 128 * MMA_K_F8;
|
||||
constexpr int TILE_F16 = 128 * MMA_K_F16;
|
||||
constexpr int V_SUB_SZ = 16 * MMA_K_F16;
|
||||
constexpr int TMEM_COLS = 512;
|
||||
|
||||
const int head_idx = blockIdx.y;
|
||||
const int batch_idx = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid >> 5;
|
||||
const int lane = tid & 31;
|
||||
const bool is_mma_warp = (wid == 4);
|
||||
const int n_kv_tiles = (p.N + SK_TILE - 1) / SK_TILE;
|
||||
|
||||
const uint8_t* q8 = p.q_nope_fp8 + batch_idx * p.q_nope_batch_stride + head_idx * p.q_nope_head_stride;
|
||||
const float* q8_scale = p.q_nope_scale + batch_idx * p.q_scale_batch_stride + head_idx * p.q_scale_head_stride;
|
||||
const bf16_t* qrope = p.q_rope_bf16 + batch_idx * p.q_rope_batch_stride + head_idx * p.q_rope_head_stride;
|
||||
|
||||
// SMEM layout — sized for T_BATCH rows
|
||||
extern __shared__ __align__(128) char sbuf[];
|
||||
size_t off = 0;
|
||||
uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sQ16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sK16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
// Per-sub-batch SMEM
|
||||
float* sLogits = (float*)(sbuf + off); off += T_BATCH * SK_TILE * sizeof(float);
|
||||
float* sP = (float*)(sbuf + off); off += T_BATCH * SK_TILE * sizeof(float);
|
||||
float* sOacc = (float*)(sbuf + off); off += T_BATCH * HD * sizeof(float);
|
||||
float* sRunningMax = (float*)(sbuf + off); off += T_BATCH * sizeof(float);
|
||||
float* sRunningSum = (float*)(sbuf + off); off += T_BATCH * sizeof(float);
|
||||
bf16_t* sOepi = (bf16_t*)(sbuf + off); off += T_BATCH * HD * sizeof(bf16_t);
|
||||
|
||||
// TMEM alloc
|
||||
if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS);
|
||||
asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
const uint32_t idesc_f8_qk = make_idesc_f8_e4m3(128, 128);
|
||||
const uint32_t idesc_f16_qk = make_idesc(128, 128);
|
||||
const uint32_t idesc_pv = make_idesc(128, 16);
|
||||
|
||||
// ================================================================
|
||||
// Outer loop: process T_BATCH rows at a time
|
||||
// ================================================================
|
||||
for (int t_start = 0; t_start < p.T; t_start += T_BATCH) {
|
||||
int T_ACT = min(T_BATCH, p.T - t_start);
|
||||
|
||||
// Initialize accumulators for this sub-batch
|
||||
for (int i = tid; i < T_ACT * HD; i += blockDim.x) sOacc[i] = 0.0f;
|
||||
for (int t = tid; t < T_ACT; t += blockDim.x) {
|
||||
sRunningMax[t] = -INFINITY;
|
||||
sRunningSum[t] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ============================================================
|
||||
// KV-tile loop (shared across all sub-batch rows)
|
||||
// ============================================================
|
||||
for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) {
|
||||
const int kv_start = kv_tile * SK_TILE;
|
||||
const int kv_len = min(SK_TILE, p.N - kv_start);
|
||||
|
||||
// --------------------------------------------------------
|
||||
// QK noPE: FP8 tensor cores
|
||||
// Write T_ACT rows of Q (not just row 0)
|
||||
// --------------------------------------------------------
|
||||
for (int kt = 0; kt < NKT_NOPE; kt++) {
|
||||
for (int i = tid; i < TILE_F8; i += blockDim.x) { sQ8[i] = 0; sK8[i] = 0; }
|
||||
__syncthreads();
|
||||
// T_ACT rows of Q
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
int qr = t_start + r;
|
||||
for (int c = 0; c < MMA_K_F8; c++) {
|
||||
int d = kt * MMA_K_F8 + c;
|
||||
sQ8[_pfill_cidx_f8(r, c)] = q8[qr * p.q_nope_t_stride + d];
|
||||
}
|
||||
}
|
||||
// K: same as decode
|
||||
for (int i = tid; i < kv_len * MMA_K_F8; i += blockDim.x) {
|
||||
int r = i / MMA_K_F8, c = i % MMA_K_F8;
|
||||
int d = kt * MMA_K_F8 + c;
|
||||
sK8[_pfill_cidx_f8(r, c)] = p.k_nope_fp8[(int64_t)(kv_start + r) * NOPE + d];
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128);
|
||||
umma_ss_f8f6f4(tb, dq, dk, idesc_f8_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Read all T_ACT rows of QK noPE result
|
||||
prefill_read_qk_rows<SK_TILE>(tb, sLogits, T_ACT, kv_len);
|
||||
__syncthreads();
|
||||
|
||||
// Apply Q and K scales
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
int qr = t_start + r;
|
||||
float q_s = q8_scale[qr * p.q_scale_t_stride];
|
||||
for (int c = 0; c < kv_len; c++) {
|
||||
float ks = p.k_nope_scale[kv_start + c];
|
||||
sLogits[r * SK_TILE + c] *= q_s * ks;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// --------------------------------------------------------
|
||||
// QK RoPE: BF16 tensor cores
|
||||
// --------------------------------------------------------
|
||||
for (int kt = 0; kt < NKT_ROPE; kt++) {
|
||||
for (int i = tid; i < TILE_F16; i += blockDim.x) { sQ16[i] = 0; sK16[i] = 0; }
|
||||
__syncthreads();
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
int qr = t_start + r;
|
||||
for (int c = 0; c < MMA_K_F16; c++) {
|
||||
int d = kt * MMA_K_F16 + c;
|
||||
sQ16[_pfill_cidx_bf16_128(r, c)] = qrope[qr * p.q_rope_t_stride + d];
|
||||
}
|
||||
}
|
||||
for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) {
|
||||
int r = i / MMA_K_F16, c = i % MMA_K_F16;
|
||||
int d = kt * MMA_K_F16 + c;
|
||||
sK16[_pfill_cidx_bf16_128(r, c)] = p.k_rope_bf16[(int64_t)(kv_start + r) * ROPE + d];
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ16), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK16), 128);
|
||||
umma_ss_f16(tb, dq, dk, idesc_f16_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Add RoPE logits to noPE logits (reuse sP as temp buffer)
|
||||
prefill_read_qk_rows<SK_TILE>(tb, sP, T_ACT, kv_len);
|
||||
__syncthreads();
|
||||
for (int i = tid; i < T_ACT * kv_len; i += blockDim.x) {
|
||||
sLogits[i] += sP[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Per-row softmax (online algorithm)
|
||||
// Each thread handles a few rows
|
||||
// --------------------------------------------------------
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
float tile_max = -INFINITY;
|
||||
for (int c = 0; c < kv_len; c++)
|
||||
tile_max = fmaxf(tile_max, sLogits[r * SK_TILE + c] * p.scale);
|
||||
|
||||
float tile_sum = 0.0f;
|
||||
for (int c = 0; c < kv_len; c++) {
|
||||
float pv = expf(sLogits[r * SK_TILE + c] * p.scale - tile_max);
|
||||
sP[r * SK_TILE + c] = pv;
|
||||
tile_sum += pv;
|
||||
}
|
||||
for (int c = kv_len; c < SK_TILE; c++) sP[r * SK_TILE + c] = 0.0f;
|
||||
|
||||
float old_max = sRunningMax[r];
|
||||
float new_max = fmaxf(old_max, tile_max);
|
||||
float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;
|
||||
for (int d = 0; d < HD; d++) sOacc[r * HD + d] *= rescale_old;
|
||||
float rescale_new = expf(tile_max - new_max);
|
||||
sRunningSum[r] = sRunningSum[r] * rescale_old + tile_sum * rescale_new;
|
||||
sRunningMax[r] = new_max;
|
||||
|
||||
// Store rescale_new for PV (reuse sLogits first column)
|
||||
sLogits[r * SK_TILE] = rescale_new;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// --------------------------------------------------------
|
||||
// PV: per query row (one PV MMA per row)
|
||||
// TODO: batch all T_ACT rows into one PV MMA for performance
|
||||
// --------------------------------------------------------
|
||||
for (int qr = 0; qr < T_ACT; qr++) {
|
||||
float p_rescale = sLogits[qr * SK_TILE];
|
||||
|
||||
for (int n_sub = 0; n_sub < N_SUB; n_sub++) {
|
||||
int d_base = n_sub * 16;
|
||||
for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
|
||||
const int col_start = pv_kt * MMA_K_F16;
|
||||
for (int i = tid; i < TILE_F16; i += blockDim.x) sPk[i] = 0;
|
||||
for (int i = tid; i < V_SUB_SZ; i += blockDim.x) sV[i] = 0;
|
||||
__syncthreads();
|
||||
|
||||
// P matrix: only row qr is active
|
||||
for (int c = tid; c < MMA_K_F16; c += blockDim.x) {
|
||||
int gc = col_start + c;
|
||||
sPk[_pfill_cidx_bf16_128(qr, c)] = f32_to_bf16(sP[qr * SK_TILE + gc]);
|
||||
}
|
||||
|
||||
// V matrix (same as decode)
|
||||
for (int i = tid; i < 16 * MMA_K_F16; i += blockDim.x) {
|
||||
int dd = i / MMA_K_F16, kk = i % MMA_K_F16;
|
||||
int row = col_start + kk;
|
||||
int g_row = kv_start + row;
|
||||
int d = d_base + dd;
|
||||
bf16_t vbits = 0;
|
||||
if (row < kv_len) {
|
||||
if (d < NOPE) {
|
||||
uint8_t b = p.k_nope_fp8[(int64_t)g_row * NOPE + d];
|
||||
float v = _prefill_fp8_to_f32(b) * p.k_nope_scale[g_row];
|
||||
vbits = f32_to_bf16(v);
|
||||
} else {
|
||||
vbits = p.k_rope_bf16[(int64_t)g_row * ROPE + (d - NOPE)];
|
||||
}
|
||||
}
|
||||
sV[_pfill_cidx_bf16_16(dd, kk)] = vbits;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
bool first = (pv_kt == 0); // Fresh for each query row's PV
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128);
|
||||
uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16);
|
||||
umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, !first);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
} // pv_kt
|
||||
} // n_sub
|
||||
|
||||
// Read PV result for row qr from TMEM
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
prefill_read_pv_all_subs<HD, N_SUB>(tb, qr, sOacc, p_rescale);
|
||||
__syncthreads();
|
||||
} // qr
|
||||
} // kv_tile
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Attention sink
|
||||
// --------------------------------------------------------
|
||||
if (p.sink_bias != nullptr) {
|
||||
float sb = p.sink_bias[batch_idx * p.H + head_idx];
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
float old_max = sRunningMax[r];
|
||||
float new_max = fmaxf(old_max, sb);
|
||||
float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;
|
||||
for (int d = 0; d < HD; d++) sOacc[r * HD + d] *= rescale_old;
|
||||
sRunningSum[r] = sRunningSum[r] * rescale_old + expf(sb - new_max);
|
||||
sRunningMax[r] = new_max;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Normalize and write output
|
||||
// --------------------------------------------------------
|
||||
bf16_t* out = p.o + batch_idx * p.o_batch_stride + head_idx * p.o_head_stride;
|
||||
float* lse = p.lse ? p.lse + batch_idx * p.lse_batch_stride + head_idx * p.lse_head_stride : nullptr;
|
||||
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
float inv_sum = 1.0f / sRunningSum[r];
|
||||
int qr = t_start + r;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
bf16_t val = f32_to_bf16(sOacc[r * HD + d] * inv_sum);
|
||||
sOepi[r * HD + d] = val;
|
||||
}
|
||||
if (lse) lse[qr * p.lse_t_stride] = logf(sRunningSum[r]) + sRunningMax[r];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write to GMEM
|
||||
for (int r = 0; r < T_ACT; r++) {
|
||||
int qr = t_start + r;
|
||||
bf16_t* out_row = out + qr * p.o_t_stride;
|
||||
for (int d = tid; d < HD; d += blockDim.x) out_row[d] = sOepi[r * HD + d];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
} // t_start sub-batch loop
|
||||
|
||||
if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS);
|
||||
}
|
||||
|
||||
} // namespace dsv4::kernels::attention
|
||||
95
dsv4/kernels/attention/fmha_mixed_fp8_prefill_capi.cu
Normal file
95
dsv4/kernels/attention/fmha_mixed_fp8_prefill_capi.cu
Normal file
@@ -0,0 +1,95 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
#include "fmha_common.cuh"
|
||||
#include "fmha_umma_desc.cuh"
|
||||
#include "fmha_mixed_fp8_prefill.cuh"
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
extern "C" {
|
||||
|
||||
int fmha_mixed_fp8_prefill_launch(
|
||||
const void* q_nope_fp8,
|
||||
const float* q_nope_scale,
|
||||
const void* q_rope_bf16,
|
||||
const void* k_nope_fp8,
|
||||
const float* k_nope_scale,
|
||||
const void* k_rope_bf16,
|
||||
void* o_ptr,
|
||||
void* lse_ptr,
|
||||
const float* sink_bias_ptr,
|
||||
int B, int H, int T, int N, int HD, int NOPE, int ROPE,
|
||||
int q_nope_t_stride, int q_nope_head_stride, int q_nope_batch_stride,
|
||||
int q_scale_t_stride, int q_scale_head_stride, int q_scale_batch_stride,
|
||||
int q_rope_t_stride, int q_rope_head_stride, int q_rope_batch_stride,
|
||||
int o_head_stride, int o_batch_stride, int o_t_stride,
|
||||
int lse_head_stride, int lse_batch_stride, int lse_t_stride,
|
||||
float scale
|
||||
) {
|
||||
if (HD != 512 || NOPE != 448 || ROPE != 64) return -2;
|
||||
if (T < 1 || T > 128) return -3;
|
||||
|
||||
FmhaMixedFp8PrefillParams p;
|
||||
p.q_nope_fp8 = (const uint8_t*)q_nope_fp8;
|
||||
p.q_nope_scale = q_nope_scale;
|
||||
p.q_rope_bf16 = (const bf16_t*)q_rope_bf16;
|
||||
p.k_nope_fp8 = (const uint8_t*)k_nope_fp8;
|
||||
p.k_nope_scale = k_nope_scale;
|
||||
p.k_rope_bf16 = (const bf16_t*)k_rope_bf16;
|
||||
p.o = (bf16_t*)o_ptr;
|
||||
p.lse = (float*)lse_ptr;
|
||||
p.sink_bias = sink_bias_ptr;
|
||||
p.B = B; p.H = H; p.T = T; p.N = N;
|
||||
p.HD = HD; p.NOPE = NOPE; p.ROPE = ROPE;
|
||||
p.q_nope_t_stride = q_nope_t_stride;
|
||||
p.q_nope_head_stride = q_nope_head_stride;
|
||||
p.q_nope_batch_stride = q_nope_batch_stride;
|
||||
p.q_scale_t_stride = q_scale_t_stride;
|
||||
p.q_scale_head_stride = q_scale_head_stride;
|
||||
p.q_scale_batch_stride = q_scale_batch_stride;
|
||||
p.q_rope_t_stride = q_rope_t_stride;
|
||||
p.q_rope_head_stride = q_rope_head_stride;
|
||||
p.q_rope_batch_stride = q_rope_batch_stride;
|
||||
p.o_head_stride = o_head_stride;
|
||||
p.o_batch_stride = o_batch_stride;
|
||||
p.o_t_stride = o_t_stride;
|
||||
p.lse_head_stride = lse_head_stride;
|
||||
p.lse_batch_stride = lse_batch_stride;
|
||||
p.lse_t_stride = lse_t_stride;
|
||||
p.scale = scale;
|
||||
|
||||
// SMEM size for T_BATCH=32
|
||||
constexpr int T_BATCH = 32;
|
||||
constexpr int SK_TILE = 128;
|
||||
constexpr int TILE_F8 = 128 * 32;
|
||||
constexpr int TILE_F16 = 128 * 16;
|
||||
constexpr int V_SUB_SZ = 16 * 16;
|
||||
int smem = 0;
|
||||
smem += 4; smem = (smem + 127) & ~127;
|
||||
smem += TILE_F8; smem = (smem + 127) & ~127; // sQ8
|
||||
smem += TILE_F8; smem = (smem + 127) & ~127; // sK8
|
||||
smem += TILE_F16 * 2; smem = (smem + 127) & ~127; // sQ16
|
||||
smem += TILE_F16 * 2; smem = (smem + 127) & ~127; // sK16
|
||||
smem += TILE_F16 * 2; smem = (smem + 127) & ~127; // sPk
|
||||
smem += V_SUB_SZ * 2; smem = (smem + 127) & ~127; // sV
|
||||
smem += T_BATCH * SK_TILE * 4; // sLogits
|
||||
smem += T_BATCH * SK_TILE * 4; // sP
|
||||
smem += T_BATCH * 512 * 4; // sOacc
|
||||
smem += T_BATCH * 4; // sRunningMax
|
||||
smem += T_BATCH * 4; // sRunningSum
|
||||
smem += T_BATCH * 512 * 2; // sOepi
|
||||
smem = (smem + 127) & ~127;
|
||||
|
||||
cudaFuncSetAttribute(
|
||||
fmha_mixed_fp8_prefill_kernel<512,448,64,128,32>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
|
||||
dim3 grid(1, H, B);
|
||||
dim3 block(192);
|
||||
fmha_mixed_fp8_prefill_kernel<512,448,64,128,32>
|
||||
<<<grid, block, smem>>>(p);
|
||||
cudaError_t err = cudaGetLastError();
|
||||
return err == cudaSuccess ? 0 : (int)err;
|
||||
}
|
||||
|
||||
} // extern C
|
||||
149
dsv4/kernels/attention/fmha_mixed_fp8_prefill_op.py
Normal file
149
dsv4/kernels/attention/fmha_mixed_fp8_prefill_op.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""DSV4 B1 mixed FP8/BF16 prefill FMHA loader.
|
||||
|
||||
Supports T > 1 for batched prefill. Same storage-native format as the
|
||||
decode kernel: FP8_E4M3 for noPE KV, BF16 for RoPE KV.
|
||||
"""
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
REPO_ROOT = os.path.normpath(os.path.join(KERNEL_DIR, "..", ".."))
|
||||
SOURCE = os.path.join(KERNEL_DIR, "fmha_mixed_fp8_prefill_capi.cu")
|
||||
BUILD_DIR = os.path.join(REPO_ROOT, "build", "fmha_mixed_fp8_prefill")
|
||||
SO_NAME = "libfmha_mixed_fp8_prefill.so"
|
||||
|
||||
_lib = None
|
||||
_lib_lock = False
|
||||
|
||||
|
||||
def _find_nvcc():
|
||||
import shutil
|
||||
for c in ["/usr/local/cuda-13.2/bin/nvcc", "/usr/local/cuda/bin/nvcc"]:
|
||||
if os.path.isfile(c):
|
||||
return c
|
||||
nvcc = shutil.which("nvcc")
|
||||
if nvcc:
|
||||
return nvcc
|
||||
raise RuntimeError("nvcc not found")
|
||||
|
||||
|
||||
def _ensure_built():
|
||||
global _lib, _lib_lock
|
||||
if _lib is not None:
|
||||
return _lib
|
||||
if _lib_lock:
|
||||
raise RuntimeError("Recursive mixed-FP8 prefill FMHA build")
|
||||
_lib_lock = True
|
||||
try:
|
||||
so_path = os.path.join(BUILD_DIR, SO_NAME)
|
||||
deps = [
|
||||
SOURCE,
|
||||
os.path.join(KERNEL_DIR, "fmha_common.cuh"),
|
||||
os.path.join(KERNEL_DIR, "fmha_umma_desc.cuh"),
|
||||
os.path.join(KERNEL_DIR, "fmha_mixed_fp8_prefill.cuh"),
|
||||
]
|
||||
src_mtime = max(os.path.getmtime(p) for p in deps if os.path.exists(p))
|
||||
need_build = not os.path.isfile(so_path) or src_mtime > os.path.getmtime(so_path)
|
||||
if not need_build:
|
||||
_lib = ctypes.CDLL(so_path)
|
||||
return _lib
|
||||
|
||||
os.makedirs(BUILD_DIR, exist_ok=True)
|
||||
nvcc = _find_nvcc()
|
||||
cmd = [
|
||||
nvcc, "-std=c++20", "-shared", "-Xcompiler", "-fPIC",
|
||||
"-gencode=arch=compute_100a,code=sm_100a",
|
||||
"-gencode=arch=compute_100a,code=compute_100a",
|
||||
f"-I{KERNEL_DIR}", f"-I{REPO_ROOT}",
|
||||
"-O3", "--use_fast_math", "--expt-relaxed-constexpr",
|
||||
SOURCE, "-o", so_path, "-lcudart", "-lcuda",
|
||||
]
|
||||
logger.info("Building libfmha_mixed_fp8_prefill.so (sm_100a)...")
|
||||
res = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if res.returncode != 0:
|
||||
raise RuntimeError(f"mixed FP8 prefill FMHA nvcc failed:\n{res.stderr}")
|
||||
_lib = ctypes.CDLL(so_path)
|
||||
return _lib
|
||||
finally:
|
||||
_lib_lock = False
|
||||
|
||||
|
||||
def _quantize_q_split(q: torch.Tensor, rope_dim: int):
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("fp8_attention_io", ["fp8_attention_io.cu"],
|
||||
extra_cuda_cflags=[
|
||||
"-gencode=arch=compute_100a,code=sm_100a",
|
||||
"-O3", "--use_fast_math", "--expt-relaxed-constexpr",
|
||||
])
|
||||
return mod.quantize_q_fp8_split(q, rope_dim)
|
||||
|
||||
|
||||
def fmha_mixed_fp8_prefill_raw(
|
||||
q: torch.Tensor, # (B,H,T,HD) BF16
|
||||
k_nope_fp8: torch.Tensor, # (N,NOPE) uint8/float8_e4m3fn
|
||||
k_nope_scale: torch.Tensor, # (N,) FP32
|
||||
k_rope_bf16: torch.Tensor, # (N,ROPE) BF16
|
||||
scale: float,
|
||||
attn_sink: Optional[torch.Tensor] = None,
|
||||
rope_dim: int = 64,
|
||||
):
|
||||
"""Mixed FP8/BF16 prefill FMHA. Supports T = 1..128."""
|
||||
if q.dim() != 4:
|
||||
raise RuntimeError("q must be (B,H,T,HD)")
|
||||
B, H, T, HD = q.shape
|
||||
if T < 1 or T > 128:
|
||||
raise RuntimeError(f"mixed FP8 prefill FMHA supports 1 ≤ T ≤ 128, got T={T}")
|
||||
NOPE = HD - rope_dim
|
||||
if HD != 512 or NOPE != 448 or rope_dim != 64:
|
||||
raise RuntimeError(f"First pass supports HD=512/NOPE=448/ROPE=64, got {HD}/{NOPE}/{rope_dim}")
|
||||
|
||||
q = q.contiguous()
|
||||
k_nope_fp8 = k_nope_fp8.contiguous()
|
||||
k_nope_scale = k_nope_scale.contiguous()
|
||||
k_rope_bf16 = k_rope_bf16.contiguous()
|
||||
q_nope_fp8, q_nope_scale, q_rope = _quantize_q_split(q, rope_dim)
|
||||
|
||||
N = k_nope_fp8.shape[0]
|
||||
o = torch.empty((B, H, T, HD), dtype=torch.bfloat16, device=q.device)
|
||||
lse = torch.empty((B, H, T), dtype=torch.float32, device=q.device)
|
||||
|
||||
sink_ptr = ctypes.c_void_p(0)
|
||||
sb = None
|
||||
if attn_sink is not None:
|
||||
sb = attn_sink.float().contiguous()
|
||||
if sb.dim() == 1:
|
||||
sb = sb.unsqueeze(0).expand(B, -1).contiguous()
|
||||
if tuple(sb.shape) != (B, H):
|
||||
raise RuntimeError(f"sink bias shape {tuple(sb.shape)} != {(B,H)}")
|
||||
sink_ptr = ctypes.c_void_p(sb.data_ptr())
|
||||
|
||||
lib = _ensure_built()
|
||||
ret = lib.fmha_mixed_fp8_prefill_launch(
|
||||
ctypes.c_void_p(q_nope_fp8.data_ptr()),
|
||||
ctypes.c_void_p(q_nope_scale.data_ptr()),
|
||||
ctypes.c_void_p(q_rope.data_ptr()),
|
||||
ctypes.c_void_p(k_nope_fp8.data_ptr()),
|
||||
ctypes.c_void_p(k_nope_scale.data_ptr()),
|
||||
ctypes.c_void_p(k_rope_bf16.data_ptr()),
|
||||
ctypes.c_void_p(o.data_ptr()),
|
||||
ctypes.c_void_p(lse.data_ptr()),
|
||||
sink_ptr,
|
||||
ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(T), ctypes.c_int(N),
|
||||
ctypes.c_int(HD), ctypes.c_int(NOPE), ctypes.c_int(rope_dim),
|
||||
ctypes.c_int(q_nope_fp8.stride(2)), ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)),
|
||||
ctypes.c_int(q_nope_scale.stride(2)), ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)),
|
||||
ctypes.c_int(q_rope.stride(2)), ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)),
|
||||
ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)), ctypes.c_int(o.stride(2)),
|
||||
ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)), ctypes.c_int(lse.stride(2)),
|
||||
ctypes.c_float(scale),
|
||||
)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"mixed FP8 prefill FMHA launch failed: return code {ret}")
|
||||
return o, lse
|
||||
@@ -233,3 +233,40 @@ def dsv4_attention_mixed_fp8_decode(
|
||||
scale, attn_sink=sink_bias, rope_dim=rope_dim,
|
||||
)
|
||||
return o4 if has_batch else o4.squeeze(0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# B1: mixed FP8/BF16 DeepSeek-V4 PREFILL attention (T > 1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def dsv4_attention_mixed_fp8_prefill(
|
||||
q: torch.Tensor, # (n_q_heads,T,HD) or (B,n_q_heads,T,HD) BF16
|
||||
k_nope_fp8: torch.Tensor, # (N,NOPE) uint8/float8_e4m3fn
|
||||
k_nope_scale: torch.Tensor, # (N,) FP32
|
||||
k_rope_bf16: torch.Tensor, # (N,ROPE) BF16
|
||||
scale: Optional[float] = None,
|
||||
sink_bias: Optional[torch.Tensor] = None,
|
||||
rope_dim: int = 64,
|
||||
) -> torch.Tensor:
|
||||
"""B1 production path: storage-native FP8/BF16 KV prefill FMHA.
|
||||
|
||||
Supports T = 1..128. For T > 128, caller must split into multiple launches.
|
||||
Uses the same mixed FP8/BF16 KV format as the decode path.
|
||||
"""
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw
|
||||
|
||||
has_batch = q.dim() == 4
|
||||
if q.dim() == 3:
|
||||
q4 = q.unsqueeze(0).contiguous() # (1, H, T, HD)
|
||||
elif q.dim() == 4:
|
||||
q4 = q.contiguous()
|
||||
else:
|
||||
raise RuntimeError("q must be (H,T,HD) or (B,H,T,HD)")
|
||||
|
||||
hd = q4.shape[-1]
|
||||
scale = scale or (1.0 / math.sqrt(hd))
|
||||
o4, _lse = fmha_mixed_fp8_prefill_raw(
|
||||
q4, k_nope_fp8, k_nope_scale, k_rope_bf16,
|
||||
scale, attn_sink=sink_bias, rope_dim=rope_dim,
|
||||
)
|
||||
return o4 if has_batch else o4.squeeze(0)
|
||||
|
||||
@@ -337,6 +337,10 @@ class Nvfp4SharedExpert:
|
||||
|
||||
def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor:
|
||||
"""L2 GEMM: intermediate × down_weight → BF16."""
|
||||
# The intermediate from fused SwiGLU deinterleave is a column slice
|
||||
# (non-contiguous). quantize_nvfp4_gpu_fused requires contiguous input.
|
||||
if not intermediate.is_contiguous():
|
||||
intermediate = intermediate.contiguous()
|
||||
num_tokens = intermediate.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
|
||||
@@ -315,6 +315,9 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0):
|
||||
x_sf: (M, N//16) float8_e4m3fn
|
||||
gsa: (M,) float32 GPU tensor — per-row global scale for GEMM
|
||||
"""
|
||||
# CUDA kernels require contiguous input — column slices from deinterleave are non-contiguous
|
||||
if not x_bf16.is_contiguous():
|
||||
x_bf16 = x_bf16.contiguous()
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
amax_mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"])
|
||||
gsa_gpu = amax_mod.compute_amax_gsa(x_bf16, divisor) # scalar GPU tensor
|
||||
|
||||
1
encoding/__init__.py
Normal file
1
encoding/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# encoding package
|
||||
757
encoding/deepseek_v4_encoding.py
Normal file
757
encoding/deepseek_v4_encoding.py
Normal file
@@ -0,0 +1,757 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
# fmt: off
|
||||
|
||||
"""
|
||||
DeepSeek-V4 Encoding
|
||||
|
||||
A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
|
||||
with tool calling, thinking mode, and quick instruction task support.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional, Tuple
|
||||
import copy
|
||||
import json
|
||||
|
||||
import regex as re
|
||||
|
||||
# ============================================================
|
||||
# Special Tokens
|
||||
# ============================================================
|
||||
|
||||
bos_token: str = "<|begin▁of▁sentence|>"
|
||||
eos_token: str = "<|end▁of▁sentence|>"
|
||||
thinking_start_token: str = "<think>"
|
||||
thinking_end_token: str = "</think>"
|
||||
dsml_token: str = "|DSML|"
|
||||
|
||||
USER_SP_TOKEN = "<|User|>"
|
||||
ASSISTANT_SP_TOKEN = "<|Assistant|>"
|
||||
LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
|
||||
|
||||
# Task special tokens for internal classification tasks
|
||||
DS_TASK_SP_TOKENS = {
|
||||
"action": "<|action|>",
|
||||
"query": "<|query|>",
|
||||
"authority": "<|authority|>",
|
||||
"domain": "<|domain|>",
|
||||
"title": "<|title|>",
|
||||
"read_url": "<|read_url|>",
|
||||
}
|
||||
VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
|
||||
|
||||
# ============================================================
|
||||
# Templates
|
||||
# ============================================================
|
||||
|
||||
system_msg_template: str = "{content}"
|
||||
user_msg_template: str = "{content}"
|
||||
latest_reminder_msg_template: str = "{content}"
|
||||
assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
|
||||
assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
|
||||
thinking_template: str = "{reasoning}"
|
||||
|
||||
response_format_template: str = (
|
||||
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
|
||||
)
|
||||
tool_call_template: str = (
|
||||
"<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
|
||||
)
|
||||
tool_calls_template = (
|
||||
"<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
|
||||
)
|
||||
tool_calls_block_name: str = "tool_calls"
|
||||
|
||||
tool_output_template: str = (
|
||||
"<tool_result>{content}</tool_result>"
|
||||
)
|
||||
|
||||
REASONING_EFFORT_MAX = (
|
||||
"Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
|
||||
"You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
|
||||
"Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
|
||||
)
|
||||
|
||||
TOOLS_TEMPLATE = """## Tools
|
||||
|
||||
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
|
||||
|
||||
<{dsml_token}tool_calls>
|
||||
<{dsml_token}invoke name="$TOOL_NAME">
|
||||
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
|
||||
...
|
||||
</{dsml_token}invoke>
|
||||
<{dsml_token}invoke name="$TOOL_NAME2">
|
||||
...
|
||||
</{dsml_token}invoke>
|
||||
</{dsml_token}tool_calls>
|
||||
|
||||
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
||||
|
||||
If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
|
||||
|
||||
Otherwise, output directly after {thinking_end_token} with tool calls or final response.
|
||||
|
||||
### Available Tool Schemas
|
||||
|
||||
{tool_schemas}
|
||||
|
||||
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
||||
"""
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def to_json(value: Any) -> str:
|
||||
"""Serialize a value to JSON string."""
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
except Exception:
|
||||
return json.dumps(value, ensure_ascii=True)
|
||||
|
||||
|
||||
def tools_from_openai_format(tools):
|
||||
"""Extract function definitions from OpenAI-format tool list."""
|
||||
return [tool["function"] for tool in tools]
|
||||
|
||||
|
||||
def tool_calls_from_openai_format(tool_calls):
|
||||
"""Convert OpenAI-format tool calls to internal format."""
|
||||
return [
|
||||
{
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": tool_call["function"]["arguments"],
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def tool_calls_to_openai_format(tool_calls):
|
||||
"""Convert internal tool calls to OpenAI format."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": tool_call["arguments"],
|
||||
}
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Encode tool call arguments into DSML parameter format.
|
||||
|
||||
Args:
|
||||
tool_call: Dict with "name" and "arguments" keys.
|
||||
|
||||
Returns:
|
||||
DSML-formatted parameter string.
|
||||
"""
|
||||
p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
|
||||
P_dsml_strs = []
|
||||
|
||||
if isinstance(tool_call["arguments"], str):
|
||||
arguments = json.loads(tool_call["arguments"])
|
||||
else:
|
||||
arguments = tool_call["arguments"]
|
||||
|
||||
for k, v in arguments.items():
|
||||
p_dsml_str = p_dsml_template.format(
|
||||
dsml_token=dsml_token,
|
||||
key=k,
|
||||
is_str="true" if isinstance(v, str) else "false",
|
||||
value=v if isinstance(v, str) else to_json(v),
|
||||
)
|
||||
P_dsml_strs.append(p_dsml_str)
|
||||
|
||||
return "\n".join(P_dsml_strs)
|
||||
|
||||
|
||||
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
|
||||
"""
|
||||
Decode DSML parameters back to a tool call dict.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool.
|
||||
tool_args: Dict mapping param_name -> (value, is_string_flag).
|
||||
|
||||
Returns:
|
||||
Dict with "name" and "arguments" (JSON string) keys.
|
||||
"""
|
||||
def _decode_value(key: str, value: str, string: str):
|
||||
if string == "true":
|
||||
value = to_json(value)
|
||||
return f"{to_json(key)}: {value}"
|
||||
|
||||
tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
|
||||
return dict(name=tool_name, arguments=tool_args_json)
|
||||
|
||||
|
||||
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
|
||||
"""
|
||||
Render tool schemas into the system prompt format.
|
||||
|
||||
Args:
|
||||
tools: List of tool schema dicts (each with name, description, parameters).
|
||||
|
||||
Returns:
|
||||
Formatted tools section string.
|
||||
"""
|
||||
tools_json = [to_json(t) for t in tools]
|
||||
|
||||
return TOOLS_TEMPLATE.format(
|
||||
tool_schemas="\n".join(tools_json),
|
||||
dsml_token=dsml_token,
|
||||
thinking_start_token=thinking_start_token,
|
||||
thinking_end_token=thinking_end_token,
|
||||
)
|
||||
|
||||
|
||||
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
|
||||
"""Find the index of the last user/developer message."""
|
||||
last_user_index = -1
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
if messages[idx].get("role") in ["user", "developer"]:
|
||||
last_user_index = idx
|
||||
break
|
||||
return last_user_index
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Message Rendering
|
||||
# ============================================================
|
||||
|
||||
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
|
||||
"""
|
||||
Render a single message at the given index into its encoded string form.
|
||||
|
||||
This is the core function that converts each message in the conversation
|
||||
into the DeepSeek-V4 format.
|
||||
|
||||
Args:
|
||||
index: Index of the message to render.
|
||||
messages: Full list of messages in the conversation.
|
||||
thinking_mode: Either "chat" or "thinking".
|
||||
drop_thinking: Whether to drop reasoning content from earlier turns.
|
||||
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
||||
|
||||
Returns:
|
||||
Encoded string for this message.
|
||||
"""
|
||||
assert 0 <= index < len(messages)
|
||||
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
|
||||
|
||||
prompt = ""
|
||||
msg = messages[index]
|
||||
last_user_idx = find_last_user_index(messages)
|
||||
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
tools = msg.get("tools")
|
||||
response_format = msg.get("response_format")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
reasoning = msg.get("reasoning")
|
||||
wo_eos = msg.get("wo_eos", False)
|
||||
|
||||
if tools:
|
||||
tools = tools_from_openai_format(tools)
|
||||
if tool_calls:
|
||||
tool_calls = tool_calls_from_openai_format(tool_calls)
|
||||
|
||||
# Reasoning effort prefix (only at index 0 in thinking mode with max effort)
|
||||
assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
|
||||
if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
|
||||
prompt += REASONING_EFFORT_MAX
|
||||
|
||||
if role == "system":
|
||||
prompt += system_msg_template.format(content=content or "")
|
||||
if tools:
|
||||
prompt += "\n\n" + render_tools(tools)
|
||||
if response_format:
|
||||
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
||||
|
||||
elif role == "developer":
|
||||
assert content, f"Invalid message for role `{role}`: {msg}"
|
||||
|
||||
content_developer = USER_SP_TOKEN
|
||||
content_developer += content
|
||||
|
||||
if tools:
|
||||
content_developer += "\n\n" + render_tools(tools)
|
||||
if response_format:
|
||||
content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
||||
|
||||
prompt += user_msg_template.format(content=content_developer)
|
||||
|
||||
elif role == "user":
|
||||
prompt += USER_SP_TOKEN
|
||||
|
||||
# Handle content blocks (tool results mixed with text)
|
||||
content_blocks = msg.get("content_blocks")
|
||||
if content_blocks:
|
||||
parts = []
|
||||
for block in content_blocks:
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif block_type == "tool_result":
|
||||
tool_content = block.get("content", "")
|
||||
if isinstance(tool_content, list):
|
||||
text_parts = []
|
||||
for b in tool_content:
|
||||
if b.get("type") == "text":
|
||||
text_parts.append(b.get("text", ""))
|
||||
else:
|
||||
text_parts.append(f"[Unsupported {b.get('type')}]")
|
||||
tool_content = "\n\n".join(text_parts)
|
||||
parts.append(tool_output_template.format(content=tool_content))
|
||||
else:
|
||||
parts.append(f"[Unsupported {block_type}]")
|
||||
prompt += "\n\n".join(parts)
|
||||
else:
|
||||
prompt += content or ""
|
||||
|
||||
elif role == "latest_reminder":
|
||||
prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
|
||||
|
||||
elif role == "tool":
|
||||
raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
|
||||
|
||||
elif role == "assistant":
|
||||
thinking_part = ""
|
||||
tc_content = ""
|
||||
|
||||
if tool_calls:
|
||||
tc_list = [
|
||||
tool_call_template.format(
|
||||
dsml_token=dsml_token,
|
||||
name=tc.get("name"),
|
||||
arguments=encode_arguments_to_dsml(tc)
|
||||
)
|
||||
for tc in tool_calls
|
||||
]
|
||||
tc_content += '\n\n' + tool_calls_template.format(
|
||||
dsml_token=dsml_token,
|
||||
tool_calls="\n".join(tc_list),
|
||||
tc_block_name=tool_calls_block_name,
|
||||
)
|
||||
|
||||
summary_content = content or ""
|
||||
reasoning = reasoning or ""
|
||||
|
||||
# Check if previous message has a task - if so, this is a task output (no thinking)
|
||||
prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
|
||||
|
||||
if thinking_mode == "thinking" and not prev_has_task:
|
||||
if not drop_thinking or index > last_user_idx:
|
||||
thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token
|
||||
else:
|
||||
thinking_part = ""
|
||||
|
||||
if wo_eos:
|
||||
prompt += assistant_msg_wo_eos_template.format(
|
||||
reasoning=thinking_part,
|
||||
content=summary_content,
|
||||
tool_calls=tc_content,
|
||||
)
|
||||
else:
|
||||
prompt += assistant_msg_template.format(
|
||||
reasoning=thinking_part,
|
||||
content=summary_content,
|
||||
tool_calls=tc_content,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown role: {role}")
|
||||
|
||||
# Append transition tokens based on what follows
|
||||
if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
|
||||
return prompt
|
||||
|
||||
task = messages[index].get("task")
|
||||
if task is not None:
|
||||
# Task special token for internal classification tasks
|
||||
assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
|
||||
task_sp_token = DS_TASK_SP_TOKENS[task]
|
||||
|
||||
if task != "action":
|
||||
# Non-action tasks: append task sp token directly after the message
|
||||
prompt += task_sp_token
|
||||
else:
|
||||
# Action task: append Assistant + thinking token + action sp token
|
||||
prompt += ASSISTANT_SP_TOKEN
|
||||
prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
|
||||
prompt += task_sp_token
|
||||
|
||||
elif messages[index].get("role") in ["user", "developer"]:
|
||||
# Normal generation: append Assistant + thinking token
|
||||
prompt += ASSISTANT_SP_TOKEN
|
||||
if not drop_thinking and thinking_mode == "thinking":
|
||||
prompt += thinking_start_token
|
||||
elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
|
||||
prompt += thinking_start_token
|
||||
else:
|
||||
prompt += thinking_end_token
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Preprocessing
|
||||
# ============================================================
|
||||
|
||||
def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge tool messages into the preceding user message using content_blocks format.
|
||||
|
||||
DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
|
||||
are encoded as <tool_result> blocks within user messages.
|
||||
|
||||
This function converts a standard OpenAI-format conversation (with separate
|
||||
"tool" role messages) into V4 format where tool results are merged into
|
||||
user messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts in OpenAI format.
|
||||
|
||||
Returns:
|
||||
Processed message list with tool messages merged into user messages.
|
||||
"""
|
||||
merged: List[Dict[str, Any]] = []
|
||||
|
||||
for msg in messages:
|
||||
msg = copy.deepcopy(msg)
|
||||
role = msg.get("role")
|
||||
|
||||
if role == "tool":
|
||||
# Convert tool message to a user message with tool_result block
|
||||
tool_block = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id", ""),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
# Merge into previous message if it's already a user (merged tool)
|
||||
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
|
||||
merged[-1]["content_blocks"].append(tool_block)
|
||||
else:
|
||||
merged.append({
|
||||
"role": "user",
|
||||
"content_blocks": [tool_block],
|
||||
})
|
||||
elif role == "user":
|
||||
text_block = {"type": "text", "text": msg.get("content", "")}
|
||||
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
|
||||
merged[-1]["content_blocks"].append(text_block)
|
||||
else:
|
||||
new_msg = {
|
||||
"role": "user",
|
||||
"content": msg.get("content", ""),
|
||||
"content_blocks": [text_block],
|
||||
}
|
||||
# Preserve extra fields (task, wo_eos, mask, etc.)
|
||||
for key in ("task", "wo_eos", "mask"):
|
||||
if key in msg:
|
||||
new_msg[key] = msg[key]
|
||||
merged.append(new_msg)
|
||||
else:
|
||||
merged.append(msg)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Sort tool_result blocks within user messages by the order of tool_calls
|
||||
in the preceding assistant message.
|
||||
|
||||
Args:
|
||||
messages: Preprocessed message list (after merge_tool_messages).
|
||||
|
||||
Returns:
|
||||
Message list with sorted tool result blocks.
|
||||
"""
|
||||
last_tool_call_order: Dict[str, int] = {}
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
if role == "assistant" and msg.get("tool_calls"):
|
||||
last_tool_call_order = {}
|
||||
for idx, tc in enumerate(msg["tool_calls"]):
|
||||
tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
|
||||
if tc_id:
|
||||
last_tool_call_order[tc_id] = idx
|
||||
|
||||
elif role == "user" and msg.get("content_blocks"):
|
||||
tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
|
||||
if len(tool_blocks) > 1 and last_tool_call_order:
|
||||
sorted_blocks = sorted(
|
||||
tool_blocks,
|
||||
key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
|
||||
)
|
||||
sorted_idx = 0
|
||||
new_blocks = []
|
||||
for block in msg["content_blocks"]:
|
||||
if block.get("type") == "tool_result":
|
||||
new_blocks.append(sorted_blocks[sorted_idx])
|
||||
sorted_idx += 1
|
||||
else:
|
||||
new_blocks.append(block)
|
||||
msg["content_blocks"] = new_blocks
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Encoding Function
|
||||
# ============================================================
|
||||
|
||||
def encode_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
thinking_mode: str,
|
||||
context: Optional[List[Dict[str, Any]]] = None,
|
||||
drop_thinking: bool = True,
|
||||
add_default_bos_token: bool = True,
|
||||
reasoning_effort: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Encode a list of messages into the DeepSeek-V4 prompt format.
|
||||
|
||||
This is the main entry point for encoding conversations. It handles:
|
||||
- BOS token insertion
|
||||
- Thinking mode with optional reasoning content dropping
|
||||
- Tool message merging into user messages
|
||||
- Multi-turn conversation context
|
||||
|
||||
Args:
|
||||
messages: List of message dicts to encode.
|
||||
thinking_mode: Either "chat" or "thinking".
|
||||
context: Optional preceding context messages (already encoded prefix).
|
||||
drop_thinking: If True, drop reasoning from earlier assistant turns
|
||||
(only keep reasoning for messages after the last user message).
|
||||
add_default_bos_token: Whether to prepend BOS token at conversation start.
|
||||
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
||||
|
||||
Returns:
|
||||
The encoded prompt string.
|
||||
"""
|
||||
context = context if context else []
|
||||
|
||||
# Preprocess: merge tool messages and sort tool results
|
||||
messages = merge_tool_messages(messages)
|
||||
messages = sort_tool_results_by_call_order(context + messages)[len(context):]
|
||||
if context:
|
||||
context = merge_tool_messages(context)
|
||||
context = sort_tool_results_by_call_order(context)
|
||||
|
||||
full_messages = context + messages
|
||||
|
||||
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
|
||||
|
||||
# Resolve drop_thinking: if any message has tools defined, don't drop thinking
|
||||
effective_drop_thinking = drop_thinking
|
||||
if any(m.get("tools") for m in full_messages):
|
||||
effective_drop_thinking = False
|
||||
|
||||
if thinking_mode == "thinking" and effective_drop_thinking:
|
||||
full_messages = _drop_thinking_messages(full_messages)
|
||||
# After dropping, recalculate how many messages to render
|
||||
# (context may have shrunk too)
|
||||
num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
|
||||
context_len = len(full_messages) - num_to_render
|
||||
else:
|
||||
num_to_render = len(messages)
|
||||
context_len = len(context)
|
||||
|
||||
for idx in range(num_to_render):
|
||||
prompt += render_message(
|
||||
idx + context_len,
|
||||
full_messages,
|
||||
thinking_mode=thinking_mode,
|
||||
drop_thinking=effective_drop_thinking,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Drop reasoning and non-essential messages before the last user message.
|
||||
|
||||
Behavior:
|
||||
- Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
|
||||
- Messages at or after the last user index are always kept.
|
||||
- Assistant messages before the last user get reasoning removed.
|
||||
- Developer messages before the last user are dropped entirely.
|
||||
"""
|
||||
last_user_idx = find_last_user_index(messages)
|
||||
result = []
|
||||
keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role in keep_roles or idx >= last_user_idx:
|
||||
result.append(msg)
|
||||
elif role == "assistant":
|
||||
msg = copy.copy(msg)
|
||||
msg.pop("reasoning", None)
|
||||
result.append(msg)
|
||||
# developer and other roles before last_user_idx are dropped
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Parsing (Decoding model output)
|
||||
# ============================================================
|
||||
|
||||
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
|
||||
"""
|
||||
Read text from index until one of the stop strings is found.
|
||||
|
||||
Returns:
|
||||
Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
|
||||
"""
|
||||
min_pos = len(text)
|
||||
matched_stop = None
|
||||
|
||||
for s in stop:
|
||||
pos = text.find(s, index)
|
||||
if pos != -1 and pos < min_pos:
|
||||
min_pos = pos
|
||||
matched_stop = s
|
||||
|
||||
if matched_stop:
|
||||
content = text[index:min_pos]
|
||||
return min_pos + len(matched_stop), content, matched_stop
|
||||
else:
|
||||
content = text[index:]
|
||||
return len(text), content, None
|
||||
|
||||
|
||||
def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
|
||||
"""
|
||||
Parse DSML tool calls from text starting at the given index.
|
||||
|
||||
Args:
|
||||
index: Starting position in text.
|
||||
text: The full text to parse.
|
||||
|
||||
Returns:
|
||||
Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
|
||||
Each tool call dict has "name" and "arguments" keys.
|
||||
"""
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
stop_token = None
|
||||
tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
|
||||
|
||||
while index < len(text):
|
||||
index, content_before, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
|
||||
if content_before != ">\n":
|
||||
raise ValueError(f"Tool call format error: expected '>\\n' but got '{content_before}'")
|
||||
|
||||
if stop_token == tool_calls_end_token:
|
||||
break
|
||||
|
||||
if stop_token is None:
|
||||
raise ValueError("Missing special token in tool calls")
|
||||
|
||||
index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
||||
|
||||
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
|
||||
if len(p_tool_name) != 1:
|
||||
raise ValueError(f"Tool name format error: '{tool_name_content}'")
|
||||
tool_name = p_tool_name[0]
|
||||
|
||||
tool_args: Dict[str, Tuple[str, str]] = {}
|
||||
while stop_token == f"<{dsml_token}parameter":
|
||||
index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
|
||||
|
||||
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
|
||||
if len(param_kv) != 1:
|
||||
raise ValueError(f"Parameter format error: '{param_content}'")
|
||||
param_name, string, param_value = param_kv[0]
|
||||
|
||||
if param_name in tool_args:
|
||||
raise ValueError(f"Duplicate parameter name: '{param_name}'")
|
||||
tool_args[param_name] = (param_value, string)
|
||||
|
||||
index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
||||
if content != ">\n":
|
||||
raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
|
||||
|
||||
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return index, stop_token, tool_calls
|
||||
|
||||
|
||||
def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse a model completion text into a structured assistant message.
|
||||
|
||||
This function takes the raw text output from the model (a single assistant turn)
|
||||
and extracts:
|
||||
- reasoning (thinking block)
|
||||
- content (summary/response)
|
||||
- tool_calls (if any)
|
||||
|
||||
NOTE: This function is designed to parse only correctly formatted strings and
|
||||
will raise ValueError for malformed output.
|
||||
|
||||
Args:
|
||||
text: The raw completion text (including EOS token).
|
||||
thinking_mode: Either "chat" or "thinking".
|
||||
|
||||
Returns:
|
||||
Dict with keys: "role", "content", "reasoning", "tool_calls".
|
||||
tool_calls are in OpenAI format.
|
||||
"""
|
||||
summary_content, reasoning = "", ""
|
||||
tool_calls: List[Dict[str, str]] = []
|
||||
index, stop_token = 0, None
|
||||
tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
|
||||
|
||||
is_thinking = thinking_mode == "thinking"
|
||||
is_tool_calling = False
|
||||
|
||||
if is_thinking:
|
||||
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
|
||||
reasoning = content_delta
|
||||
if stop_token != thinking_end_token:
|
||||
raise ValueError("Invalid thinking format: missing </think>")
|
||||
|
||||
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
|
||||
summary_content = content_delta
|
||||
if stop_token == tool_calls_start_token:
|
||||
is_tool_calling = True
|
||||
else:
|
||||
if stop_token != eos_token:
|
||||
raise ValueError("Invalid format: missing EOS token")
|
||||
|
||||
if is_tool_calling:
|
||||
index, stop_token, tool_calls = parse_tool_calls(index, text)
|
||||
|
||||
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
|
||||
if tool_ends_text:
|
||||
raise ValueError("Unexpected content after tool calls")
|
||||
|
||||
if len(text) != index or stop_token not in [eos_token, None]:
|
||||
raise ValueError("Unexpected content at end")
|
||||
|
||||
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
|
||||
if sp_token in summary_content or sp_token in reasoning:
|
||||
raise ValueError(f"Unexpected special token '{sp_token}' in content")
|
||||
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": summary_content,
|
||||
"reasoning": reasoning,
|
||||
"tool_calls": tool_calls_to_openai_format(tool_calls)
|
||||
}
|
||||
|
||||
# fmt: on
|
||||
49
reference/README.md
Normal file
49
reference/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Reference Implementations
|
||||
|
||||
This directory contains **read-only** reference implementations from official sources.
|
||||
Do not modify these files — they exist to cross-check our production pipeline.
|
||||
|
||||
## Directory Layout
|
||||
|
||||
```
|
||||
reference/
|
||||
├── vllm/ # vLLM project reference (Apache-2.0)
|
||||
│ ├── tokenizers/
|
||||
│ │ ├── deepseek_v4.py # Tokenizer wrapper — apply_chat_template for DSV4
|
||||
│ │ └── deepseek_v4_encoding.py # Official prompt encoder (canonical source)
|
||||
│ ├── reasoning/
|
||||
│ │ ├── deepseek_v3_reasoning_parser.py # Thinking-mode dispatcher
|
||||
│ │ └── deepseek_r1_reasoning_parser.py # )/) reasoning token parser
|
||||
│ └── tool_parsers/
|
||||
│ ├── deepseekv4_tool_parser.py # DSML tool call parser (V4)
|
||||
│ └── deepseekv32_tool_parser.py # DSML tool call parser (V3.2 base)
|
||||
│
|
||||
└── official_inference/ # Original weight's reference inference code
|
||||
├── generate.py # Official generate loop + encode_messages usage
|
||||
├── model.py # BF16/FP8 model implementation
|
||||
├── kernel.py # Reference CUDA kernels
|
||||
├── convert.py # Weight conversion
|
||||
└── config.json # Model config (small variant)
|
||||
```
|
||||
|
||||
## Key Files for Our Pipeline
|
||||
|
||||
1. **`vllm/tokenizers/deepseek_v4_encoding.py`** — Canonical prompt encoder.
|
||||
Already copied to `encoding/deepseek_v4_encoding.py` in the repo root (our live import).
|
||||
If vLLM updates this file, diff and sync.
|
||||
|
||||
2. **`vllm/tokenizers/deepseek_v4.py`** — Shows how vLLM wraps the tokenizer
|
||||
to add `apply_chat_template` support. Key insight: it calls
|
||||
`encode_messages(messages, thinking_mode=..., ...)` then
|
||||
`tokenizer.encode(prompt_str, add_special_tokens=False)`.
|
||||
This is exactly what our single_shot does.
|
||||
|
||||
3. **`official_inference/generate.py`** — The original weight's inference entry point.
|
||||
Uses `tokenizer.encode(encode_messages(messages, thinking_mode="chat"))`
|
||||
(default `add_special_tokens=True`) and `parse_message_from_completion_text()`
|
||||
for output parsing.
|
||||
|
||||
4. **`vllm/reasoning/`** — How vLLM detects thinking mode boundaries
|
||||
(`)、` start, `)/)` end). Useful when we integrate streaming.
|
||||
|
||||
5. **`vllm/tool_parsers/`** — DSML tool call parsing for future tool-use support.
|
||||
1
reference/official_inference/README.md
Normal file
1
reference/official_inference/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# THIS WAS FROM https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/tree/main/inference IT WAS USED TO REFERENCE HOW THE PARSERS, TOKENIZERS, AND TEMPLATING ARE HOOKED UP. IGNORE THE KERNEL AS OUR VERSION OF DSV4 IS OUR OWN NVFP4 QUANT
|
||||
1
reference/official_inference/__init__.py
Normal file
1
reference/official_inference/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Official inference reference — read only, do not modify
|
||||
35
reference/official_inference/config.json
Normal file
35
reference/official_inference/config.json
Normal file
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"vocab_size": 129280,
|
||||
"dim": 7168,
|
||||
"moe_inter_dim": 3072,
|
||||
"n_layers": 61,
|
||||
"n_hash_layers": 3,
|
||||
"n_heads": 128,
|
||||
"n_routed_experts": 384,
|
||||
"n_shared_experts": 1,
|
||||
"n_activated_experts": 6,
|
||||
"score_func": "sqrtsoftplus",
|
||||
"route_scale": 2.5,
|
||||
"swiglu_limit": 10.0,
|
||||
"q_lora_rank": 1536,
|
||||
"head_dim": 512,
|
||||
"rope_head_dim": 64,
|
||||
"o_groups": 16,
|
||||
"o_lora_rank": 1024,
|
||||
"window_size": 128,
|
||||
"original_seq_len": 65536,
|
||||
"rope_theta": 10000,
|
||||
"rope_factor": 16,
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"index_n_heads": 64,
|
||||
"index_head_dim": 128,
|
||||
"index_topk": 1024,
|
||||
"hc_mult": 4,
|
||||
"hc_sinkhorn_iters": 20,
|
||||
"dtype": "fp8",
|
||||
"scale_fmt": "ue8m0",
|
||||
"expert_dtype": "fp4",
|
||||
"compress_rope_theta": 160000,
|
||||
"compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
|
||||
}
|
||||
168
reference/official_inference/convert.py
Normal file
168
reference/official_inference/convert.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import os
|
||||
import shutil
|
||||
from argparse import ArgumentParser
|
||||
from glob import glob
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import torch
|
||||
from safetensors.torch import safe_open, save_file
|
||||
|
||||
|
||||
FP4_TABLE = torch.tensor([
|
||||
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
||||
0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
|
||||
], dtype=torch.float32)
|
||||
|
||||
|
||||
def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Casts a tensor from e2m1fn to e4m3fn losslessly.
|
||||
"""
|
||||
assert x.dtype == torch.int8
|
||||
assert x.ndim == 2
|
||||
out_dim, in_dim = x.size()
|
||||
in_dim *= 2
|
||||
fp8_block_size = 128
|
||||
fp4_block_size = 32
|
||||
assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
|
||||
assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
|
||||
|
||||
x = x.view(torch.uint8)
|
||||
low = x & 0x0F
|
||||
high = (x >> 4) & 0x0F
|
||||
x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
|
||||
|
||||
# max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
|
||||
# 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
|
||||
MAX_OFFSET_BITS = 6
|
||||
|
||||
bOut = out_dim // fp8_block_size
|
||||
bIn = in_dim // fp8_block_size
|
||||
# bOut, bIn, 128, 128
|
||||
x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
|
||||
# bOut, bIn, 128*4
|
||||
scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
|
||||
## bOut, bIn, 1
|
||||
scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
|
||||
# bOut, bIn, 128*4
|
||||
offset = scale / scale_max_offset_bits
|
||||
# bOut, bIn, 128, 128
|
||||
offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
|
||||
x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
|
||||
return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
|
||||
|
||||
|
||||
mapping = {
|
||||
"embed_tokens": ("embed", 0),
|
||||
"input_layernorm": ("attn_norm", None),
|
||||
"post_attention_layernorm": ("ffn_norm", None),
|
||||
"q_proj": ("wq", 0),
|
||||
"q_a_proj": ("wq_a", None),
|
||||
"q_a_layernorm": ("q_norm", None),
|
||||
"q_b_proj": ("wq_b", 0),
|
||||
"kv_a_proj_with_mqa": ("wkv_a", None),
|
||||
"kv_a_layernorm": ("kv_norm", None),
|
||||
"kv_b_proj": ("wkv_b", 0),
|
||||
"o_proj": ("wo", 1),
|
||||
"gate_proj": ("w1", 0),
|
||||
"down_proj": ("w2", 1),
|
||||
"up_proj": ("w3", 0),
|
||||
"lm_head": ("head", 0),
|
||||
|
||||
"embed": ("embed", 0),
|
||||
"wq_b": ("wq_b", 0),
|
||||
"wo_a": ("wo_a", 0),
|
||||
"wo_b": ("wo_b", 1),
|
||||
"head": ("head", 0),
|
||||
"attn_sink": ("attn_sink", 0),
|
||||
"weights_proj": ("weights_proj", 0),
|
||||
}
|
||||
|
||||
|
||||
def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
|
||||
"""
|
||||
Converts and saves model checkpoint files into a specified format.
|
||||
|
||||
Args:
|
||||
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
|
||||
save_path (str): Path to the directory where the converted checkpoint files will be saved.
|
||||
n_experts (int): Total number of experts in the model.
|
||||
mp (int): Model parallelism factor.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
torch.set_num_threads(8)
|
||||
n_local_experts = n_experts // mp
|
||||
state_dicts = [{} for _ in range(mp)]
|
||||
|
||||
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for name in f.keys():
|
||||
param: torch.Tensor = f.get_tensor(name)
|
||||
if name.startswith("model."):
|
||||
name = name[len("model."):]
|
||||
if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
|
||||
continue
|
||||
name = name.replace("self_attn", "attn")
|
||||
name = name.replace("mlp", "ffn")
|
||||
name = name.replace("weight_scale_inv", "scale")
|
||||
name = name.replace("e_score_correction_bias", "bias")
|
||||
if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
|
||||
key = name.split(".")[-1]
|
||||
else:
|
||||
key = name.split(".")[-2]
|
||||
if key in mapping:
|
||||
new_key, dim = mapping[key]
|
||||
else:
|
||||
new_key, dim = key, None
|
||||
name = name.replace(key, new_key)
|
||||
for i in range(mp):
|
||||
new_param = param
|
||||
if "experts" in name and "shared_experts" not in name:
|
||||
idx = int(name.split(".")[-3])
|
||||
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
|
||||
continue
|
||||
elif dim is not None:
|
||||
assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
|
||||
shard_size = param.size(dim) // mp
|
||||
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
|
||||
state_dicts[i][name] = new_param
|
||||
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
for i in trange(mp):
|
||||
names = list(state_dicts[i].keys())
|
||||
for name in names:
|
||||
if name.endswith("wo_a.weight"):
|
||||
weight = state_dicts[i][name]
|
||||
scale = state_dicts[i].pop(name.replace("weight", "scale"))
|
||||
weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
|
||||
state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
|
||||
elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
|
||||
if expert_dtype == "fp8":
|
||||
scale_name = name.replace("weight", "scale")
|
||||
weight = state_dicts[i].pop(name)
|
||||
scale = state_dicts[i].pop(scale_name)
|
||||
state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
|
||||
else:
|
||||
state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
|
||||
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
|
||||
|
||||
for file in ["tokenizer.json", "tokenizer_config.json"]:
|
||||
old_file_path = os.path.join(hf_ckpt_path, file)
|
||||
new_file_path = os.path.join(save_path, file)
|
||||
if os.path.exists(old_file_path):
|
||||
shutil.copyfile(old_file_path, new_file_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--hf-ckpt-path", type=str, required=True)
|
||||
parser.add_argument("--save-path", type=str, required=True)
|
||||
parser.add_argument("--n-experts", type=int, required=True)
|
||||
parser.add_argument("--model-parallel", type=int, required=True)
|
||||
parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
|
||||
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
|
||||
155
reference/official_inference/generate.py
Normal file
155
reference/official_inference/generate.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import AutoTokenizer
|
||||
from safetensors.torch import load_model
|
||||
|
||||
from model import Transformer, ModelArgs
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
encoding_dir = os.path.join(current_dir, '../encoding')
|
||||
sys.path.insert(0, os.path.abspath(encoding_dir))
|
||||
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
|
||||
|
||||
|
||||
def sample(logits, temperature: float = 1.0):
|
||||
"""Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
|
||||
since it avoids the GPU-to-CPU sync in torch.multinomial."""
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
model: Transformer,
|
||||
prompt_tokens: List[List[int]],
|
||||
max_new_tokens: int,
|
||||
eos_id: int,
|
||||
temperature: float = 1.0
|
||||
) -> List[List[int]]:
|
||||
"""Batch generation with left-padded prompts.
|
||||
|
||||
The first forward pass processes [min_prompt_len:] tokens (prefill phase).
|
||||
Subsequent passes generate one token at a time (decode phase). For positions
|
||||
still within a prompt, the ground-truth token overrides the model's prediction.
|
||||
"""
|
||||
prompt_lens = [len(t) for t in prompt_tokens]
|
||||
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
|
||||
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
|
||||
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
|
||||
for i, t in enumerate(prompt_tokens):
|
||||
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
|
||||
prev_pos = 0
|
||||
finished = torch.tensor([False] * len(prompt_tokens))
|
||||
prompt_mask = tokens != -1
|
||||
for cur_pos in range(min(prompt_lens), total_len):
|
||||
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
if temperature > 0:
|
||||
next_token = sample(logits, temperature)
|
||||
else:
|
||||
next_token = logits.argmax(dim=-1)
|
||||
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
|
||||
prev_pos = cur_pos
|
||||
if finished.all():
|
||||
break
|
||||
completion_tokens = []
|
||||
for i, toks in enumerate(tokens.tolist()):
|
||||
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
|
||||
if eos_id in toks:
|
||||
toks = toks[:toks.index(eos_id)]
|
||||
toks.append(eos_id)
|
||||
completion_tokens.append(toks)
|
||||
return completion_tokens
|
||||
|
||||
|
||||
def main(
|
||||
ckpt_path: str,
|
||||
config: str,
|
||||
input_file: str = "",
|
||||
interactive: bool = True,
|
||||
max_new_tokens: int = 100,
|
||||
temperature: float = 1.0,
|
||||
) -> None:
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
rank = int(os.getenv("RANK", "0"))
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
if world_size > 1:
|
||||
dist.init_process_group("nccl")
|
||||
global print
|
||||
if rank != 0:
|
||||
print = lambda *_, **__: None
|
||||
torch.cuda.set_device(local_rank)
|
||||
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_num_threads(8)
|
||||
torch.manual_seed(33377335)
|
||||
with open(config) as f:
|
||||
args = ModelArgs(**json.load(f))
|
||||
if interactive:
|
||||
args.max_batch_size = 1
|
||||
print(args)
|
||||
with torch.device("cuda"):
|
||||
model = Transformer(args)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
|
||||
print("load model")
|
||||
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
|
||||
torch.set_default_device("cuda")
|
||||
print("I'm DeepSeek 👋")
|
||||
|
||||
if interactive:
|
||||
messages = []
|
||||
while True:
|
||||
if world_size == 1:
|
||||
prompt = input(">>> ")
|
||||
elif rank == 0:
|
||||
prompt = input(">>> ")
|
||||
objects = [prompt]
|
||||
dist.broadcast_object_list(objects, 0)
|
||||
else:
|
||||
objects = [None]
|
||||
dist.broadcast_object_list(objects, 0)
|
||||
prompt = objects[0]
|
||||
if prompt == "/exit":
|
||||
break
|
||||
elif prompt == "/clear":
|
||||
messages.clear()
|
||||
continue
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
|
||||
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
|
||||
completion = tokenizer.decode(completion_tokens[0])
|
||||
print(completion)
|
||||
messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
|
||||
else:
|
||||
with open(input_file) as f:
|
||||
prompts = f.read().split("\n\n")
|
||||
prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
|
||||
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
|
||||
completions = tokenizer.batch_decode(completion_tokens)
|
||||
for prompt, completion in zip(prompts, completions):
|
||||
print("Prompt:", prompt)
|
||||
print("Completion:", completion)
|
||||
print()
|
||||
|
||||
if world_size > 1:
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--ckpt-path", type=str, required=True)
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
parser.add_argument("--input-file", type=str, default="")
|
||||
parser.add_argument("--interactive", action="store_true")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=300)
|
||||
parser.add_argument("--temperature", type=float, default=0.6)
|
||||
args = parser.parse_args()
|
||||
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
|
||||
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
|
||||
536
reference/official_inference/kernel.py
Normal file
536
reference/official_inference/kernel.py
Normal file
@@ -0,0 +1,536 @@
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
from typing import Tuple, Optional
|
||||
|
||||
|
||||
tilelang.set_log_level("WARNING")
|
||||
|
||||
pass_configs = {
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
}
|
||||
|
||||
FP8 = "float8_e4m3"
|
||||
FP4 = "float4_e2m1fn"
|
||||
FE8M0 = "float8_e8m0fnu"
|
||||
BF16 = "bfloat16"
|
||||
FP32 = "float32"
|
||||
INT32 = "int32"
|
||||
|
||||
|
||||
def fast_log2_ceil(x):
|
||||
"""Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
|
||||
bits_x = T.reinterpret("uint32", x)
|
||||
exp_x = (bits_x >> 23) & 0xFF
|
||||
man_bits = bits_x & ((1 << 23) - 1)
|
||||
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
|
||||
|
||||
|
||||
def fast_pow2(x):
|
||||
"""Compute 2^x for integer x via IEEE 754 bit manipulation."""
|
||||
bits_x = (x + 127) << 23
|
||||
return T.reinterpret("float32", bits_x)
|
||||
|
||||
|
||||
def fast_round_scale(amax, fp8_max_inv):
|
||||
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
|
||||
|
||||
|
||||
@tilelang.jit(pass_configs=pass_configs)
|
||||
def act_quant_kernel(
|
||||
N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
|
||||
round_scale=False, inplace=False
|
||||
):
|
||||
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
|
||||
M = T.symbolic("M")
|
||||
fp8_min = -448.0
|
||||
fp8_max = 448.0
|
||||
fp8_max_inv = 1 / fp8_max
|
||||
num_stages = 0 if round_scale or inplace else 2
|
||||
blk_m = 32
|
||||
group_size = block_size
|
||||
# Internal computation in FP32; scale_dtype controls output storage format.
|
||||
compute_dtype = FP32
|
||||
out_dtype = in_dtype if inplace else out_dtype
|
||||
|
||||
@T.prim_func
|
||||
def act_quant_kernel_(
|
||||
X: T.Tensor[(M, N), in_dtype],
|
||||
Y: T.Tensor[(M, N), out_dtype],
|
||||
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
||||
):
|
||||
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
||||
pid_m,
|
||||
pid_n,
|
||||
):
|
||||
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
||||
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
||||
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
|
||||
s_local = T.alloc_fragment((blk_m,), compute_dtype)
|
||||
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
||||
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
||||
|
||||
for _ in T.Pipelined(1, num_stages=num_stages):
|
||||
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
||||
T.copy(x_shared, x_local)
|
||||
T.reduce_absmax(x_local, amax_local, dim=1)
|
||||
for i in T.Parallel(blk_m):
|
||||
amax_local[i] = T.max(amax_local[i], 1e-4)
|
||||
if round_scale:
|
||||
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
|
||||
else:
|
||||
s_local[i] = amax_local[i] * fp8_max_inv
|
||||
if inplace:
|
||||
for i, j in T.Parallel(blk_m, group_size):
|
||||
y_local[i, j] = T.Cast(
|
||||
out_dtype,
|
||||
T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
|
||||
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
||||
))) * s_local[i],
|
||||
)
|
||||
else:
|
||||
for i, j in T.Parallel(blk_m, group_size):
|
||||
y_local[i, j] = T.clamp(
|
||||
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
||||
)
|
||||
for i in T.Parallel(blk_m):
|
||||
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
|
||||
T.copy(y_local, y_shared)
|
||||
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
||||
|
||||
return act_quant_kernel_
|
||||
|
||||
|
||||
def act_quant(
|
||||
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
|
||||
scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
|
||||
When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
|
||||
N = x.size(-1)
|
||||
assert N % block_size == 0
|
||||
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
|
||||
z = x.contiguous()
|
||||
y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
|
||||
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
|
||||
kernel = act_quant_kernel(
|
||||
N, block_size, scale_dtype=tl_dtype,
|
||||
round_scale=scale_fmt is not None, inplace=inplace,
|
||||
)
|
||||
kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
|
||||
if inplace:
|
||||
x.copy_(y)
|
||||
return x
|
||||
return y, s
|
||||
|
||||
|
||||
@tilelang.jit(pass_configs=pass_configs)
|
||||
def fp4_quant_kernel(
|
||||
N, block_size=32, in_dtype=BF16, scale_dtype=FE8M0, inplace=False
|
||||
):
|
||||
"""Block-wise FP4 quantization. Power-of-2 scale via bit ops. inplace=True does fused quant+dequant."""
|
||||
M = T.symbolic("M")
|
||||
fp4_max = 6.0
|
||||
fp4_max_inv = 1.0 / fp4_max
|
||||
blk_m = 32
|
||||
group_size = block_size
|
||||
compute_dtype = FP32
|
||||
out_dtype = in_dtype if inplace else FP4
|
||||
|
||||
@T.prim_func
|
||||
def fp4_quant_kernel_(
|
||||
X: T.Tensor[(M, N), in_dtype],
|
||||
Y: T.Tensor[(M, N), out_dtype],
|
||||
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
||||
):
|
||||
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
||||
pid_m,
|
||||
pid_n,
|
||||
):
|
||||
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
||||
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
||||
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
|
||||
s_local = T.alloc_fragment((blk_m,), compute_dtype)
|
||||
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
||||
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
||||
|
||||
for _ in T.Pipelined(1, num_stages=2):
|
||||
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
||||
T.copy(x_shared, x_local)
|
||||
T.reduce_absmax(x_local, amax_local, dim=1)
|
||||
for i in T.Parallel(blk_m):
|
||||
amax_local[i] = T.max(amax_local[i], 6 * (2**-126))
|
||||
s_local[i] = fast_round_scale(amax_local[i], fp4_max_inv)
|
||||
if inplace:
|
||||
for i, j in T.Parallel(blk_m, group_size):
|
||||
y_local[i, j] = T.Cast(
|
||||
out_dtype,
|
||||
T.Cast(compute_dtype, T.Cast(FP4, T.clamp(
|
||||
x_local[i, j] / s_local[i], -fp4_max, fp4_max
|
||||
))) * s_local[i],
|
||||
)
|
||||
else:
|
||||
for i, j in T.Parallel(blk_m, group_size):
|
||||
y_local[i, j] = T.clamp(
|
||||
x_local[i, j] / s_local[i], -fp4_max, fp4_max
|
||||
)
|
||||
for i in T.Parallel(blk_m):
|
||||
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
|
||||
T.copy(y_local, y_shared)
|
||||
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
||||
|
||||
return fp4_quant_kernel_
|
||||
|
||||
|
||||
def fp4_act_quant(
|
||||
x: torch.Tensor, block_size: int = 32, inplace: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Block-wise FP4 quantization. inplace=True does fused quant+dequant back to BF16."""
|
||||
N = x.size(-1)
|
||||
assert N % block_size == 0
|
||||
z = x.contiguous()
|
||||
y = torch.empty_like(z) if inplace else z.new_empty(*z.shape[:-1], N // 2, dtype=torch.float4_e2m1fn_x2)
|
||||
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float8_e8m0fnu)
|
||||
kernel = fp4_quant_kernel(N, block_size, inplace=inplace)
|
||||
kernel(z.view(-1, N), y.view(-1, y.size(-1)), s.view(-1, N // block_size))
|
||||
if inplace:
|
||||
x.copy_(y)
|
||||
return x
|
||||
return y, s
|
||||
|
||||
|
||||
@tilelang.jit(pass_configs=pass_configs)
|
||||
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
|
||||
assert out_dtype in [BF16, FP32]
|
||||
|
||||
M = T.symbolic("M")
|
||||
group_size = 128
|
||||
block_M = 32
|
||||
block_N = 128
|
||||
block_K = 128
|
||||
|
||||
@T.prim_func
|
||||
def fp8_gemm_kernel_(
|
||||
A: T.Tensor[(M, K), FP8],
|
||||
B: T.Tensor[(N, K), FP8],
|
||||
C: T.Tensor[(M, N), out_dtype],
|
||||
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
|
||||
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
|
||||
):
|
||||
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
||||
bx,
|
||||
by,
|
||||
):
|
||||
A_shared = T.alloc_shared((block_M, block_K), FP8)
|
||||
B_shared = T.alloc_shared((block_N, block_K), FP8)
|
||||
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
|
||||
Scale_C_shared = T.alloc_shared((block_M), FP32)
|
||||
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
|
||||
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
|
||||
|
||||
# Improve L2 Cache
|
||||
T.use_swizzle(panel_size=10)
|
||||
T.clear(C_local)
|
||||
T.clear(C_local_accum)
|
||||
|
||||
K_iters = T.ceildiv(K, block_K)
|
||||
for k in T.Pipelined(K_iters, num_stages=4):
|
||||
T.copy(A[by * block_M, k * block_K], A_shared)
|
||||
T.copy(B[bx * block_N, k * block_K], B_shared)
|
||||
# Cast scales to FP32 for computation; scales_b has one value per block_N group
|
||||
Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
|
||||
for i in T.Parallel(block_M):
|
||||
Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
|
||||
|
||||
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
||||
# Separate accumulator for scale-corrected results (2x accumulation precision)
|
||||
for i, j in T.Parallel(block_M, block_N):
|
||||
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
|
||||
T.clear(C_local)
|
||||
T.copy(C_local_accum, C_shared)
|
||||
T.copy(C_shared, C[by * block_M, bx * block_N])
|
||||
|
||||
return fp8_gemm_kernel_
|
||||
|
||||
|
||||
def fp8_gemm(
|
||||
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
|
||||
scale_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
|
||||
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
|
||||
assert a_s.is_contiguous() and b_s.is_contiguous(), (
|
||||
"Scaling factor tensors must be contiguous"
|
||||
)
|
||||
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
|
||||
K = a.size(-1)
|
||||
M = a.numel() // K
|
||||
N = b.size(0)
|
||||
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
||||
kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
|
||||
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
|
||||
return c
|
||||
|
||||
|
||||
@tilelang.jit(pass_configs=pass_configs)
|
||||
def sparse_attn_kernel(h: int, d: int, scale=None):
|
||||
"""Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
|
||||
For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
|
||||
with numerically stable running max/sum, and includes a learnable attn_sink bias."""
|
||||
b = T.symbolic("b")
|
||||
m = T.symbolic("m")
|
||||
n = T.symbolic("n")
|
||||
topk = T.symbolic("topk")
|
||||
if scale is None:
|
||||
scale = (1.0 / d) ** 0.5
|
||||
|
||||
num_stages = 2
|
||||
threads = 256
|
||||
block = 64
|
||||
num_blocks = tilelang.cdiv(topk, block)
|
||||
|
||||
@T.prim_func
|
||||
def sparse_attn_kernel_(
|
||||
q: T.Tensor[(b, m, h, d), BF16],
|
||||
kv: T.Tensor[(b, n, d), BF16],
|
||||
o: T.Tensor[(b, m, h, d), BF16],
|
||||
attn_sink: T.Tensor[(h,), FP32],
|
||||
topk_idxs: T.Tensor[(b, m, topk), INT32],
|
||||
):
|
||||
with T.Kernel(m, b, threads=threads) as (bx, by):
|
||||
q_shared = T.alloc_shared((h, d), BF16)
|
||||
kv_shared = T.alloc_shared((block, d), BF16)
|
||||
o_shared = T.alloc_shared((h, d), BF16)
|
||||
acc_s_cast = T.alloc_shared((h, block), BF16)
|
||||
|
||||
idxs = T.alloc_fragment(block, INT32)
|
||||
acc_s = T.alloc_fragment((h, block), FP32)
|
||||
acc_o = T.alloc_fragment((h, d), FP32)
|
||||
scores_max = T.alloc_fragment(h, FP32)
|
||||
scores_max_prev = T.alloc_fragment(h, FP32)
|
||||
scores_scale = T.alloc_fragment(h, FP32)
|
||||
scores_sum = T.alloc_fragment(h, FP32)
|
||||
sum_exp = T.alloc_fragment(h, FP32)
|
||||
|
||||
T.clear(acc_o)
|
||||
T.clear(sum_exp)
|
||||
T.fill(scores_max, -T.infinity(FP32))
|
||||
T.copy(q[by, bx, :, :], q_shared)
|
||||
|
||||
for t in T.Pipelined(num_blocks, num_stages=num_stages):
|
||||
for i in T.Parallel(block):
|
||||
idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, bx, t * block + i], -1)
|
||||
for i, j in T.Parallel(block, d):
|
||||
kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
|
||||
for i, j in T.Parallel(h, block):
|
||||
acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
|
||||
T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
|
||||
for i, j in T.Parallel(h, block):
|
||||
acc_s[i, j] *= scale
|
||||
T.copy(scores_max, scores_max_prev)
|
||||
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
|
||||
for i in T.Parallel(h):
|
||||
scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
|
||||
for i, j in T.Parallel(h, block):
|
||||
acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
|
||||
T.reduce_sum(acc_s, scores_sum, dim=1)
|
||||
for i in T.Parallel(h):
|
||||
sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
|
||||
T.copy(acc_s, acc_s_cast)
|
||||
for i, j in T.Parallel(h, d):
|
||||
acc_o[i, j] *= scores_scale[i]
|
||||
T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
|
||||
|
||||
for i in T.Parallel(h):
|
||||
sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
|
||||
for i, j in T.Parallel(h, d):
|
||||
acc_o[i, j] /= sum_exp[i]
|
||||
T.copy(acc_o, o_shared)
|
||||
T.copy(o_shared, o[by, bx, :, :])
|
||||
|
||||
return sparse_attn_kernel_
|
||||
|
||||
|
||||
def sparse_attn(
|
||||
q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
|
||||
) -> torch.Tensor:
|
||||
b, s, h, d = q.size()
|
||||
# Pad heads to 16 for kernel efficiency (stripped after)
|
||||
if h < 16:
|
||||
q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
|
||||
attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
|
||||
o = torch.empty_like(q)
|
||||
kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
|
||||
kernel(q, kv, o, attn_sink, topk_idxs)
|
||||
if h < 16:
|
||||
o = o.narrow(2, 0, h).contiguous()
|
||||
return o
|
||||
|
||||
|
||||
@tilelang.jit(pass_configs=pass_configs)
|
||||
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
|
||||
n = T.symbolic("n")
|
||||
mix_hc = (2 + hc) * hc
|
||||
threads = 64
|
||||
|
||||
@T.prim_func
|
||||
def hc_split_sinkhorn_kernel_(
|
||||
mixes: T.Tensor[(n, mix_hc), FP32],
|
||||
hc_scale: T.Tensor[(3,), FP32],
|
||||
hc_base: T.Tensor[(mix_hc,), FP32],
|
||||
pre: T.Tensor[(n, hc), FP32],
|
||||
post: T.Tensor[(n, hc), FP32],
|
||||
comb: T.Tensor[(n, hc, hc), FP32],
|
||||
):
|
||||
with T.Kernel(n, threads=threads) as i:
|
||||
mixes_shared = T.alloc_shared(mix_hc, FP32)
|
||||
comb_frag = T.alloc_fragment((hc, hc), FP32)
|
||||
T.copy(mixes[i, :], mixes_shared)
|
||||
|
||||
for j in T.Parallel(hc):
|
||||
pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
|
||||
for j in T.Parallel(hc):
|
||||
post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
|
||||
for j, k in T.Parallel(hc, hc):
|
||||
comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
|
||||
|
||||
row_sum = T.alloc_fragment(hc, FP32)
|
||||
col_sum = T.alloc_fragment(hc, FP32)
|
||||
|
||||
# comb = comb.softmax(-1) + eps
|
||||
row_max = T.alloc_fragment(hc, FP32)
|
||||
T.reduce_max(comb_frag, row_max, dim=1)
|
||||
for j, k in T.Parallel(hc, hc):
|
||||
comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
|
||||
T.reduce_sum(comb_frag, row_sum, dim=1)
|
||||
for j, k in T.Parallel(hc, hc):
|
||||
comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
|
||||
|
||||
# comb = comb / (comb.sum(-2) + eps)
|
||||
T.reduce_sum(comb_frag, col_sum, dim=0)
|
||||
for j, k in T.Parallel(hc, hc):
|
||||
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
|
||||
|
||||
for _ in T.serial(sinkhorn_iters - 1):
|
||||
# comb = comb / (comb.sum(-1) + eps)
|
||||
T.reduce_sum(comb_frag, row_sum, dim=1)
|
||||
for j, k in T.Parallel(hc, hc):
|
||||
comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
|
||||
# comb = comb / (comb.sum(-2) + eps)
|
||||
T.reduce_sum(comb_frag, col_sum, dim=0)
|
||||
for j, k in T.Parallel(hc, hc):
|
||||
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
|
||||
|
||||
T.copy(comb_frag, comb[i, :, :])
|
||||
|
||||
return hc_split_sinkhorn_kernel_
|
||||
|
||||
|
||||
def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
|
||||
b, s, _ = mixes.size()
|
||||
pre = mixes.new_empty(b, s, hc_mult)
|
||||
post = mixes.new_empty(b, s, hc_mult)
|
||||
comb = mixes.new_empty(b, s, hc_mult, hc_mult)
|
||||
kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
|
||||
kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
|
||||
pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
|
||||
return pre, post, comb
|
||||
|
||||
|
||||
@tilelang.jit(pass_configs=pass_configs)
|
||||
def fp4_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
|
||||
"""FP8 act x FP4 weight GEMM kernel.
|
||||
|
||||
C[M, N] = A_fp8[M, K] @ B_fp4[N, K]^T
|
||||
|
||||
Act: 1x128 quant on K (reduce dim), FP8 with configurable scale dtype
|
||||
Weight: 1x32 quant on K (reduce dim), FP4 with E8M0 scale
|
||||
|
||||
B is stored as [N, K//2] in float4_e2m1fn_x2, logical [N, K] in fp4.
|
||||
The FP4 values are packed along the K (last) dimension.
|
||||
|
||||
Strategy: load FP4 sub-blocks of size [block_N, sub_K] (sub_K=32),
|
||||
cast FP4 to FP8 via float, then do FP8xFP8 GEMM.
|
||||
Apply act scale (per 128 on K) and weight scale (per 32 on K) to the accumulator.
|
||||
"""
|
||||
M = T.symbolic("M")
|
||||
act_group_size = 128
|
||||
weight_group_size = 32
|
||||
block_M = 32
|
||||
block_N = 128
|
||||
block_K = 32 # matches weight_group_size for simple scale handling
|
||||
n_sub = act_group_size // block_K # 4 sub-blocks per act scale group
|
||||
|
||||
@T.prim_func
|
||||
def fp4_gemm_kernel_(
|
||||
A: T.Tensor[(M, K), FP8],
|
||||
B: T.Tensor[(N, K), FP4],
|
||||
C: T.Tensor[(M, N), out_dtype],
|
||||
scales_a: T.Tensor[(M, T.ceildiv(K, act_group_size)), scale_dtype],
|
||||
scales_b: T.Tensor[(N, T.ceildiv(K, weight_group_size)), scale_dtype],
|
||||
):
|
||||
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
||||
bx,
|
||||
by,
|
||||
):
|
||||
A_shared = T.alloc_shared((block_M, block_K), FP8)
|
||||
B_fp4_shared = T.alloc_shared((block_N, block_K), FP4)
|
||||
B_shared = T.alloc_shared((block_N, block_K), FP8)
|
||||
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
|
||||
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
|
||||
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
|
||||
scale_a_frag = T.alloc_fragment((block_M,), FP32)
|
||||
scale_b_frag = T.alloc_fragment((block_N,), FP32)
|
||||
|
||||
T.use_swizzle(panel_size=10)
|
||||
T.clear(C_local)
|
||||
T.clear(C_local_accum)
|
||||
|
||||
K_iters = T.ceildiv(K, block_K)
|
||||
for k in T.Pipelined(K_iters, num_stages=2):
|
||||
T.copy(A[by * block_M, k * block_K], A_shared)
|
||||
T.copy(B[bx * block_N, k * block_K], B_fp4_shared)
|
||||
# FP4->FP8 cast must go through FP32 to avoid ambiguous C++ overload
|
||||
for i, j in T.Parallel(block_N, block_K):
|
||||
B_shared[i, j] = T.Cast(FP8, T.Cast(FP32, B_fp4_shared[i, j]))
|
||||
|
||||
# Weight scale: per 32 on K, indexed by k (each k is one block_K=32)
|
||||
for i in T.Parallel(block_N):
|
||||
scale_b_frag[i] = T.Cast(FP32, scales_b[bx * block_N + i, k])
|
||||
|
||||
# Act scale: per 128 on K, indexed by k // 4
|
||||
for i in T.Parallel(block_M):
|
||||
scale_a_frag[i] = T.Cast(FP32, scales_a[by * block_M + i, k // n_sub])
|
||||
|
||||
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
||||
|
||||
for i, j in T.Parallel(block_M, block_N):
|
||||
C_local_accum[i, j] += C_local[i, j] * scale_a_frag[i] * scale_b_frag[j]
|
||||
T.clear(C_local)
|
||||
|
||||
T.copy(C_local_accum, C_shared)
|
||||
T.copy(C_shared, C[by * block_M, bx * block_N])
|
||||
|
||||
return fp4_gemm_kernel_
|
||||
|
||||
|
||||
def fp4_gemm(
|
||||
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
|
||||
scale_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T.
|
||||
A has per-128 act scale; B has per-32 E8M0 weight scale.
|
||||
B is stored as [N, K//2] in float4_e2m1fn_x2 (2 FP4 values per byte, packed along K)."""
|
||||
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
|
||||
assert a_s.is_contiguous() and b_s.is_contiguous(), (
|
||||
"Scaling factor tensors must be contiguous"
|
||||
)
|
||||
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
|
||||
K = a.size(-1)
|
||||
M = a.numel() // K
|
||||
N = b.size(0)
|
||||
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
||||
kernel = fp4_gemm_kernel(N, K, scale_dtype=tl_dtype)
|
||||
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
|
||||
return c
|
||||
827
reference/official_inference/model.py
Normal file
827
reference/official_inference/model.py
Normal file
@@ -0,0 +1,827 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Optional, Literal
|
||||
from functools import lru_cache
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
|
||||
|
||||
|
||||
world_size = 1
|
||||
rank = 0
|
||||
block_size = 128
|
||||
fp4_block_size = 32
|
||||
default_dtype = torch.bfloat16
|
||||
scale_fmt = None
|
||||
scale_dtype = torch.float32
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_dtype(dtype):
|
||||
"""Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
|
||||
prev = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.set_default_dtype(prev)
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
"""Model hyperparameters. Field names match the config JSON keys."""
|
||||
max_batch_size: int = 4
|
||||
max_seq_len: int = 4096
|
||||
dtype: Literal["bf16", "fp8"] = "fp8"
|
||||
scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
|
||||
expert_dtype: Literal[None, "fp4"] = None
|
||||
scale_dtype: Literal["fp32", "fp8"] = "fp8"
|
||||
vocab_size: int = 129280
|
||||
dim: int = 4096
|
||||
moe_inter_dim: int = 4096
|
||||
n_layers: int = 7
|
||||
n_hash_layers: int = 0
|
||||
n_mtp_layers: int = 1
|
||||
n_heads: int = 64
|
||||
# moe
|
||||
n_routed_experts: int = 8
|
||||
n_shared_experts: int = 1
|
||||
n_activated_experts: int = 2
|
||||
score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
|
||||
route_scale: float = 1.
|
||||
swiglu_limit: float = 0.
|
||||
# mqa
|
||||
q_lora_rank: int = 1024
|
||||
head_dim: int = 512
|
||||
rope_head_dim: int = 64
|
||||
norm_eps: float = 1e-6
|
||||
o_groups: int = 8
|
||||
o_lora_rank: int = 1024
|
||||
window_size: int = 128
|
||||
compress_ratios: Tuple[int] = (0, 0, 4, 128, 4, 128, 4, 0)
|
||||
# yarn
|
||||
compress_rope_theta: float = 40000.0
|
||||
original_seq_len: int = 0
|
||||
rope_theta: float = 10000.0
|
||||
rope_factor: float = 40
|
||||
beta_fast: int = 32
|
||||
beta_slow: int = 1
|
||||
# index
|
||||
index_n_heads: int = 64
|
||||
index_head_dim: int = 128
|
||||
index_topk: int = 512
|
||||
# hc
|
||||
hc_mult: int = 4
|
||||
hc_sinkhorn_iters: int = 20
|
||||
hc_eps: float = 1e-6
|
||||
|
||||
|
||||
class ParallelEmbedding(nn.Module):
|
||||
"""Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
|
||||
Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
|
||||
def __init__(self, vocab_size: int, dim: int):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim = dim
|
||||
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
|
||||
self.part_vocab_size = (vocab_size // world_size)
|
||||
self.vocab_start_idx = rank * self.part_vocab_size
|
||||
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
|
||||
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if world_size > 1:
|
||||
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
||||
x = x - self.vocab_start_idx
|
||||
x[mask] = 0
|
||||
y = F.embedding(x, self.weight)
|
||||
if world_size > 1:
|
||||
y[mask] = 0
|
||||
dist.all_reduce(y)
|
||||
return y
|
||||
|
||||
|
||||
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
|
||||
For quantized weights, x is first quantized to FP8 via act_quant."""
|
||||
assert bias is None
|
||||
|
||||
if weight.dtype == torch.float4_e2m1fn_x2:
|
||||
x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
|
||||
return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
|
||||
elif weight.dtype == torch.float8_e4m3fn:
|
||||
x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
|
||||
return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
|
||||
else:
|
||||
return F.linear(x, weight)
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
"""Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
dtype = dtype or default_dtype
|
||||
if dtype == torch.float4_e2m1fn_x2:
|
||||
# FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
|
||||
# Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
|
||||
self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
|
||||
scale_out_features = out_features
|
||||
scale_in_features = in_features // fp4_block_size
|
||||
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
|
||||
scale_out_features = (out_features + block_size - 1) // block_size
|
||||
scale_in_features = (in_features + block_size - 1) // block_size
|
||||
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
|
||||
self.register_parameter("scale", None)
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(out_features))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class ColumnParallelLinear(Linear):
|
||||
"""Shards output dim across TP ranks. No all-reduce needed on output."""
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
||||
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
|
||||
self.part_out_features = out_features // world_size
|
||||
super().__init__(in_features, self.part_out_features, bias, dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class RowParallelLinear(Linear):
|
||||
"""Shards input dim across TP ranks. All-reduce on output to sum partial results."""
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
||||
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
|
||||
self.part_in_features = in_features // world_size
|
||||
super().__init__(self.part_in_features, out_features, bias, dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = linear(x, self.weight, None)
|
||||
if world_size > 1:
|
||||
y = y.float()
|
||||
dist.all_reduce(y)
|
||||
if self.bias is not None:
|
||||
y += self.bias
|
||||
return y.type_as(x)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
# rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
var = x.square().mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(var + self.eps)
|
||||
return (self.weight * x).to(dtype)
|
||||
|
||||
|
||||
@lru_cache(2)
|
||||
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
|
||||
"""Precomputes complex exponentials for rotary embeddings with YaRN scaling.
|
||||
When original_seq_len > 0, applies frequency interpolation with a smooth
|
||||
linear ramp between beta_fast and beta_slow correction ranges."""
|
||||
|
||||
def find_correction_dim(num_rotations, dim, base, max_seq_len):
|
||||
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
|
||||
|
||||
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
|
||||
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
|
||||
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
|
||||
return max(low, 0), min(high, dim-1)
|
||||
|
||||
def linear_ramp_factor(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
if original_seq_len > 0:
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
|
||||
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
|
||||
freqs = freqs / factor * (1 - smooth) + freqs * smooth
|
||||
|
||||
t = torch.arange(seqlen)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
|
||||
"""Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation)."""
|
||||
y = x
|
||||
x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
|
||||
if inverse:
|
||||
freqs_cis = freqs_cis.conj()
|
||||
if x.ndim == 3:
|
||||
freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
|
||||
else:
|
||||
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
|
||||
x = torch.view_as_real(x * freqs_cis).flatten(-2)
|
||||
y.copy_(x)
|
||||
return y
|
||||
|
||||
|
||||
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
|
||||
assert x.dtype == torch.bfloat16
|
||||
from fast_hadamard_transform import hadamard_transform
|
||||
return hadamard_transform(x, scale=x.size(-1) ** -0.5)
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
|
||||
if start_pos >= window_size - 1:
|
||||
start_pos %= window_size
|
||||
matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
|
||||
elif start_pos > 0:
|
||||
matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
|
||||
else:
|
||||
base = torch.arange(seqlen).unsqueeze(1)
|
||||
matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
|
||||
matrix = torch.where(matrix > base, -1, matrix)
|
||||
return matrix.unsqueeze(0).expand(bsz, -1, -1)
|
||||
|
||||
|
||||
@lru_cache(2)
|
||||
def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
|
||||
if start_pos > 0:
|
||||
matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
|
||||
else:
|
||||
matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
|
||||
mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
|
||||
matrix = torch.where(mask, -1, matrix + offset)
|
||||
return matrix.unsqueeze(0).expand(bsz, -1, -1)
|
||||
|
||||
|
||||
class Compressor(nn.Module):
|
||||
"""Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
|
||||
When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
|
||||
|
||||
def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
|
||||
super().__init__()
|
||||
self.dim = args.dim
|
||||
self.head_dim = head_dim
|
||||
self.rope_head_dim = args.rope_head_dim
|
||||
self.nope_head_dim = head_dim - args.rope_head_dim
|
||||
self.compress_ratio = compress_ratio
|
||||
self.overlap = compress_ratio == 4
|
||||
self.rotate = rotate
|
||||
coff = 1 + self.overlap
|
||||
|
||||
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
|
||||
# wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
|
||||
# When overlap, the first half of dims is for overlapping compression, second half for normal.
|
||||
self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
|
||||
self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
|
||||
self.norm = RMSNorm(self.head_dim, args.norm_eps)
|
||||
self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache
|
||||
# State buffers for decode-phase incremental compression.
|
||||
# With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
|
||||
self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
|
||||
self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
|
||||
self.freqs_cis: torch.Tensor = None
|
||||
|
||||
def overlap_transform(self, tensor: torch.Tensor, value=0):
|
||||
# tensor: [b,s,r,2d]
|
||||
b, s, _, _ = tensor.size()
|
||||
ratio, d = self.compress_ratio, self.head_dim
|
||||
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
|
||||
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
|
||||
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
|
||||
return new_tensor
|
||||
|
||||
def forward(self, x: torch.Tensor, start_pos: int):
|
||||
assert self.kv_cache is not None
|
||||
bsz, seqlen, _ = x.size()
|
||||
ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
|
||||
dtype = x.dtype
|
||||
# compression need fp32
|
||||
x = x.float()
|
||||
kv = self.wkv(x)
|
||||
score = self.wgate(x)
|
||||
if start_pos == 0:
|
||||
should_compress = seqlen >= ratio
|
||||
remainder = seqlen % ratio
|
||||
cutoff = seqlen - remainder
|
||||
offset = ratio if overlap else 0
|
||||
if overlap and cutoff >= ratio:
|
||||
self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
|
||||
self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
|
||||
if remainder > 0:
|
||||
kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
|
||||
self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
|
||||
score = score[:, :cutoff]
|
||||
kv = kv.unflatten(1, (-1, ratio))
|
||||
score = score.unflatten(1, (-1, ratio)) + self.ape
|
||||
if overlap:
|
||||
kv = self.overlap_transform(kv, 0)
|
||||
score = self.overlap_transform(score, float("-inf"))
|
||||
kv = (kv * score.softmax(dim=2)).sum(dim=2)
|
||||
else:
|
||||
should_compress = (start_pos + 1) % self.compress_ratio == 0
|
||||
score += self.ape[start_pos % ratio]
|
||||
if overlap:
|
||||
self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
|
||||
self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
|
||||
if should_compress:
|
||||
kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
|
||||
score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
|
||||
kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
|
||||
self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
|
||||
self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
|
||||
else:
|
||||
self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
|
||||
self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
|
||||
if should_compress:
|
||||
kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
|
||||
if not should_compress:
|
||||
return
|
||||
kv = self.norm(kv.to(dtype))
|
||||
if start_pos == 0:
|
||||
freqs_cis = self.freqs_cis[:cutoff:ratio]
|
||||
else:
|
||||
freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
|
||||
apply_rotary_emb(kv[..., -rd:], freqs_cis)
|
||||
if self.rotate:
|
||||
kv = rotate_activation(kv)
|
||||
fp4_act_quant(kv, fp4_block_size, True)
|
||||
else:
|
||||
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
|
||||
if start_pos == 0:
|
||||
self.kv_cache[:bsz, :seqlen // ratio] = kv
|
||||
else:
|
||||
self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
|
||||
return kv
|
||||
|
||||
|
||||
class Indexer(torch.nn.Module):
|
||||
"""Selects top-k compressed KV positions for sparse attention via learned scoring.
|
||||
Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
|
||||
|
||||
def __init__(self, args: ModelArgs, compress_ratio: int = 4):
|
||||
super().__init__()
|
||||
self.dim = args.dim
|
||||
self.n_heads = args.index_n_heads
|
||||
self.n_local_heads = args.index_n_heads // world_size
|
||||
self.head_dim = args.index_head_dim
|
||||
self.rope_head_dim = args.rope_head_dim
|
||||
self.index_topk = args.index_topk
|
||||
self.q_lora_rank = args.q_lora_rank
|
||||
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
|
||||
self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
|
||||
self.softmax_scale = self.head_dim ** -0.5
|
||||
self.compress_ratio = compress_ratio
|
||||
|
||||
self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
|
||||
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
|
||||
self.freqs_cis = None
|
||||
|
||||
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
|
||||
bsz, seqlen, _ = x.size()
|
||||
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
||||
ratio = self.compress_ratio
|
||||
rd = self.rope_head_dim
|
||||
end_pos = start_pos + seqlen
|
||||
if self.compressor.kv_cache is None:
|
||||
self.compressor.kv_cache = self.kv_cache
|
||||
self.compressor.freqs_cis = self.freqs_cis
|
||||
q = self.wq_b(qr)
|
||||
q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
|
||||
apply_rotary_emb(q[..., -rd:], freqs_cis)
|
||||
q = rotate_activation(q)
|
||||
# use fp4 simulation for q and kv in indexer
|
||||
fp4_act_quant(q, fp4_block_size, True)
|
||||
self.compressor(x, start_pos)
|
||||
weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
|
||||
# We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
|
||||
index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
|
||||
index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
|
||||
if world_size > 1:
|
||||
dist.all_reduce(index_score)
|
||||
if start_pos == 0:
|
||||
mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
|
||||
index_score += torch.where(mask, float("-inf"), 0)
|
||||
topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
|
||||
if start_pos == 0:
|
||||
mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
|
||||
topk_idxs = torch.where(mask, -1, topk_idxs + offset)
|
||||
else:
|
||||
topk_idxs += offset
|
||||
return topk_idxs
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
|
||||
Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.dim = args.dim
|
||||
self.n_heads = args.n_heads
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.q_lora_rank = args.q_lora_rank
|
||||
self.o_lora_rank = args.o_lora_rank
|
||||
self.head_dim = args.head_dim
|
||||
self.rope_head_dim = args.rope_head_dim
|
||||
self.nope_head_dim = args.head_dim - args.rope_head_dim
|
||||
self.n_groups = args.o_groups
|
||||
self.n_local_groups = self.n_groups // world_size
|
||||
self.window_size = args.window_size
|
||||
self.compress_ratio = args.compress_ratios[layer_id]
|
||||
self.eps = args.norm_eps
|
||||
|
||||
self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
|
||||
self.wq_a = Linear(self.dim, self.q_lora_rank)
|
||||
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
||||
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
|
||||
self.wkv = Linear(self.dim, self.head_dim)
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
|
||||
self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
|
||||
self.softmax_scale = self.head_dim ** -0.5
|
||||
|
||||
if self.compress_ratio:
|
||||
self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
|
||||
if self.compress_ratio == 4:
|
||||
self.indexer = Indexer(args, self.compress_ratio)
|
||||
else:
|
||||
self.indexer = None
|
||||
|
||||
kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
|
||||
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
|
||||
if self.compress_ratio:
|
||||
original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
|
||||
else:
|
||||
# disable YaRN and use base rope_theta in pure sliding-window attention
|
||||
original_seq_len, rope_theta = 0, args.rope_theta
|
||||
freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
|
||||
rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
|
||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, start_pos: int):
|
||||
bsz, seqlen, _ = x.size()
|
||||
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
||||
win = self.window_size
|
||||
ratio = self.compress_ratio
|
||||
rd = self.rope_head_dim
|
||||
if self.compress_ratio and self.compressor.kv_cache is None:
|
||||
self.compressor.kv_cache = self.kv_cache[:, win:]
|
||||
self.compressor.freqs_cis = self.freqs_cis
|
||||
if self.indexer is not None:
|
||||
self.indexer.freqs_cis = self.freqs_cis
|
||||
# q
|
||||
qr = q = self.q_norm(self.wq_a(x))
|
||||
q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
|
||||
q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
|
||||
apply_rotary_emb(q[..., -rd:], freqs_cis)
|
||||
|
||||
# win kv & topk_idxs
|
||||
kv = self.wkv(x)
|
||||
kv = self.kv_norm(kv)
|
||||
apply_rotary_emb(kv[..., -rd:], freqs_cis)
|
||||
# FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
|
||||
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
|
||||
topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
|
||||
if self.compress_ratio:
|
||||
offset = kv.size(1) if start_pos == 0 else win
|
||||
if self.indexer is not None:
|
||||
compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
|
||||
else:
|
||||
compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
|
||||
topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
|
||||
topk_idxs = topk_idxs.int()
|
||||
|
||||
# compress kv & attn
|
||||
if start_pos == 0:
|
||||
if seqlen <= win:
|
||||
self.kv_cache[:bsz, :seqlen] = kv
|
||||
else:
|
||||
cutoff = seqlen % win
|
||||
self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
|
||||
if self.compress_ratio:
|
||||
if (kv_compress := self.compressor(x, start_pos)) is not None:
|
||||
kv = torch.cat([kv, kv_compress], dim=1)
|
||||
# We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
|
||||
o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
|
||||
else:
|
||||
self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
|
||||
if self.compress_ratio:
|
||||
self.compressor(x, start_pos)
|
||||
o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
|
||||
apply_rotary_emb(o[..., -rd:], freqs_cis, True)
|
||||
|
||||
# o
|
||||
o = o.view(bsz, seqlen, self.n_local_groups, -1)
|
||||
wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
|
||||
# NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf,
|
||||
# but using BF16 for simplicity.
|
||||
o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
|
||||
x = self.wo_b(o.flatten(2))
|
||||
return x
|
||||
|
||||
|
||||
class Gate(nn.Module):
|
||||
"""MoE gating: computes expert routing scores and selects top-k experts.
|
||||
Supports hash-based routing (first n_hash_layers) where expert indices are
|
||||
predetermined per token ID, and score-based routing (remaining layers)."""
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.dim = args.dim
|
||||
self.topk = args.n_activated_experts
|
||||
self.score_func = args.score_func
|
||||
self.route_scale = args.route_scale
|
||||
self.hash = layer_id < args.n_hash_layers
|
||||
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
|
||||
if self.hash:
|
||||
self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
|
||||
self.bias = None
|
||||
else:
|
||||
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))
|
||||
|
||||
def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scores = linear(x.float(), self.weight.float())
|
||||
if self.score_func == "softmax":
|
||||
scores = scores.softmax(dim=-1)
|
||||
elif self.score_func == "sigmoid":
|
||||
scores = scores.sigmoid()
|
||||
else:
|
||||
scores = F.softplus(scores).sqrt()
|
||||
original_scores = scores
|
||||
# Bias shifts scores for expert selection (topk) but does not affect routing weights.
|
||||
if self.bias is not None:
|
||||
scores = scores + self.bias
|
||||
if self.hash:
|
||||
indices = self.tid2eid[input_ids]
|
||||
else:
|
||||
indices = scores.topk(self.topk, dim=-1)[1]
|
||||
weights = original_scores.gather(1, indices)
|
||||
if self.score_func != "softmax":
|
||||
weights /= weights.sum(dim=-1, keepdim=True)
|
||||
weights *= self.route_scale
|
||||
return weights, indices
|
||||
|
||||
|
||||
class Expert(nn.Module):
|
||||
"""Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
|
||||
def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
|
||||
super().__init__()
|
||||
self.w1 = Linear(dim, inter_dim, dtype=dtype)
|
||||
self.w2 = Linear(inter_dim, dim, dtype=dtype)
|
||||
self.w3 = Linear(dim, inter_dim, dtype=dtype)
|
||||
self.swiglu_limit = swiglu_limit
|
||||
|
||||
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
gate = self.w1(x).float()
|
||||
up = self.w3(x).float()
|
||||
if self.swiglu_limit > 0:
|
||||
up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
|
||||
gate = torch.clamp(gate, max=self.swiglu_limit)
|
||||
x = F.silu(gate) * up
|
||||
if weights is not None:
|
||||
x = weights * x
|
||||
return self.w2(x.to(dtype))
|
||||
|
||||
|
||||
class MoE(nn.Module):
|
||||
"""Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
|
||||
Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.dim = args.dim
|
||||
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
|
||||
self.n_routed_experts = args.n_routed_experts
|
||||
self.n_local_experts = args.n_routed_experts // world_size
|
||||
self.n_activated_experts = args.n_activated_experts
|
||||
self.experts_start_idx = rank * self.n_local_experts
|
||||
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
|
||||
self.gate = Gate(layer_id, args)
|
||||
expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
|
||||
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
|
||||
for i in range(self.n_routed_experts)])
|
||||
assert args.n_shared_experts == 1
|
||||
self.shared_experts = Expert(args.dim, args.moe_inter_dim, swiglu_limit=args.swiglu_limit)
|
||||
|
||||
def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
shape = x.size()
|
||||
x = x.view(-1, self.dim)
|
||||
weights, indices = self.gate(x, input_ids.flatten())
|
||||
y = torch.zeros_like(x, dtype=torch.float32)
|
||||
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
|
||||
for i in range(self.experts_start_idx, self.experts_end_idx):
|
||||
if counts[i] == 0:
|
||||
continue
|
||||
expert = self.experts[i]
|
||||
idx, top = torch.where(indices == i)
|
||||
y[idx] += expert(x[idx], weights[idx, top, None])
|
||||
if world_size > 1:
|
||||
dist.all_reduce(y)
|
||||
y += self.shared_experts(x)
|
||||
return y.type_as(x).view(shape)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer block with Hyper-Connections (HC) mixing.
|
||||
Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
|
||||
hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
|
||||
hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.norm_eps = args.norm_eps
|
||||
self.attn = Attention(layer_id, args)
|
||||
self.ffn = MoE(layer_id, args)
|
||||
self.attn_norm = RMSNorm(args.dim, self.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
|
||||
self.hc_mult = hc_mult = args.hc_mult
|
||||
self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
|
||||
self.hc_eps = args.hc_eps
|
||||
mix_hc = (2 + hc_mult) * hc_mult
|
||||
hc_dim = hc_mult * args.dim
|
||||
with set_dtype(torch.float32):
|
||||
self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
|
||||
self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
|
||||
self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
|
||||
self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
|
||||
self.hc_attn_scale = nn.Parameter(torch.empty(3))
|
||||
self.hc_ffn_scale = nn.Parameter(torch.empty(3))
|
||||
|
||||
def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
|
||||
# x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
|
||||
shape, dtype = x.size(), x.dtype
|
||||
x = x.flatten(2).float()
|
||||
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
|
||||
mixes = F.linear(x, hc_fn) * rsqrt
|
||||
pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
|
||||
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
|
||||
return y.to(dtype), post, comb
|
||||
|
||||
def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
|
||||
# x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
|
||||
y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
|
||||
return y.type_as(x)
|
||||
|
||||
def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
residual = x
|
||||
x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
|
||||
x = self.attn_norm(x)
|
||||
x = self.attn(x, start_pos)
|
||||
x = self.hc_post(x, residual, post, comb)
|
||||
|
||||
residual = x
|
||||
x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
|
||||
x = self.ffn_norm(x)
|
||||
x = self.ffn(x, input_ids)
|
||||
x = self.hc_post(x, residual, post, comb)
|
||||
return x
|
||||
|
||||
|
||||
class ParallelHead(nn.Module):
|
||||
|
||||
def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim = dim
|
||||
self.norm_eps = norm_eps
|
||||
self.hc_eps = hc_eps
|
||||
self.part_vocab_size = (vocab_size // world_size)
|
||||
# lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
|
||||
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
|
||||
|
||||
def get_logits(self, x):
|
||||
return F.linear(x[:, -1].float(), self.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
|
||||
# x: [b,s,hc,d]
|
||||
x = self.hc_head(x, hc_fn, hc_scale, hc_base)
|
||||
logits = self.get_logits(norm(x))
|
||||
if world_size > 1:
|
||||
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
|
||||
dist.all_gather(all_logits, logits)
|
||||
logits = torch.cat(all_logits, dim=-1)
|
||||
return logits
|
||||
|
||||
def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
|
||||
shape, dtype = x.size(), x.dtype
|
||||
x = x.flatten(2).float()
|
||||
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
|
||||
mixes = F.linear(x, hc_fn) * rsqrt
|
||||
pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
|
||||
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
|
||||
return y.to(dtype)
|
||||
|
||||
|
||||
class MTPBlock(Block):
|
||||
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__(layer_id, args)
|
||||
self.e_proj = Linear(args.dim, args.dim)
|
||||
self.h_proj = Linear(args.dim, args.dim)
|
||||
self.enorm = RMSNorm(args.dim, args.norm_eps)
|
||||
self.hnorm = RMSNorm(args.dim, args.norm_eps)
|
||||
self.norm = RMSNorm(args.dim, args.norm_eps)
|
||||
self.hc_mult = hc_mult = args.hc_mult
|
||||
hc_dim = hc_mult * args.dim
|
||||
with set_dtype(torch.float32):
|
||||
self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
|
||||
self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
|
||||
self.hc_head_scale = nn.Parameter(torch.empty(1))
|
||||
self.embed: ParallelEmbedding = None
|
||||
self.head: ParallelHead = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
# x: [b,s,hc,d]
|
||||
assert self.embed is not None and self.head is not None
|
||||
e = self.embed(input_ids)
|
||||
e = self.enorm(e)
|
||||
x = self.hnorm(x)
|
||||
x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
|
||||
x = super().forward(x, start_pos, input_ids)
|
||||
logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
|
||||
return logits
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
|
||||
Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
|
||||
def __init__(self, args: ModelArgs):
|
||||
global world_size, rank, default_dtype, scale_fmt, scale_dtype
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
|
||||
scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
|
||||
scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
|
||||
super().__init__()
|
||||
self.max_seq_len = args.max_seq_len
|
||||
self.norm_eps = args.norm_eps
|
||||
self.hc_eps = args.hc_eps
|
||||
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_layers):
|
||||
self.layers.append(Block(layer_id, args))
|
||||
self.norm = RMSNorm(args.dim, self.norm_eps)
|
||||
self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
|
||||
self.mtp = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_mtp_layers):
|
||||
self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
|
||||
self.mtp[-1].embed = self.embed
|
||||
self.mtp[-1].head = self.head
|
||||
self.hc_mult = hc_mult = args.hc_mult
|
||||
hc_dim = hc_mult * args.dim
|
||||
with set_dtype(torch.float32):
|
||||
self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
|
||||
self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
|
||||
self.hc_head_scale = nn.Parameter(torch.empty(1))
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
|
||||
h = self.embed(input_ids)
|
||||
# Expand to hc_mult copies for Hyper-Connections
|
||||
h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, input_ids)
|
||||
logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
|
||||
return logits
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(0)
|
||||
args = ModelArgs(n_hash_layers=0)
|
||||
x = torch.randint(0, args.vocab_size, (2, 128))
|
||||
model = Transformer(args)
|
||||
|
||||
print(model(x).size())
|
||||
for i in range(128, 150):
|
||||
print(i, model(x[:, 0:1], i).size())
|
||||
|
||||
h = torch.randn(2, 128, args.hc_mult, args.dim)
|
||||
mtp = model.mtp[0]
|
||||
print(mtp(h, 0, x).size())
|
||||
print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
|
||||
1
reference/vllm/__init__.py
Normal file
1
reference/vllm/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# vLLM reference — read only, do not modify
|
||||
67
reference/vllm/reasoning/deepseek_r1_reasoning_parser.py
Normal file
67
reference/vllm/reasoning/deepseek_r1_reasoning_parser.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
|
||||
|
||||
class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser):
|
||||
"""
|
||||
Reasoning parser for DeepSeek R1 model.
|
||||
|
||||
The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
|
||||
text. This parser extracts the reasoning content from the model output.
|
||||
"""
|
||||
|
||||
@property
|
||||
def start_token(self) -> str:
|
||||
"""The token that starts reasoning content."""
|
||||
return "<think>"
|
||||
|
||||
@property
|
||||
def end_token(self) -> str:
|
||||
"""The token that ends reasoning content."""
|
||||
return "</think>"
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> DeltaMessage | None:
|
||||
ret = super().extract_reasoning_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
)
|
||||
if (
|
||||
ret is not None
|
||||
and self.start_token_id not in previous_token_ids
|
||||
and self.start_token_id not in delta_token_ids
|
||||
):
|
||||
if self.end_token_id in delta_token_ids:
|
||||
# end token in delta with more tokens,
|
||||
# extract reasoning content and content
|
||||
end_index = delta_text.find(self.end_token)
|
||||
reasoning = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.end_token) :]
|
||||
return DeltaMessage(
|
||||
reasoning=reasoning,
|
||||
content=content if content else None,
|
||||
)
|
||||
elif self.end_token_id in previous_token_ids:
|
||||
# end token in previous, thinking content ends
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
# no end token in previous or delta, reasoning content continues
|
||||
return DeltaMessage(reasoning=delta_text)
|
||||
|
||||
return ret
|
||||
99
reference/vllm/reasoning/deepseek_v3_reasoning_parser.py
Normal file
99
reference/vllm/reasoning/deepseek_v3_reasoning_parser.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
|
||||
from .identity_reasoning_parser import IdentityReasoningParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV3ReasoningParser(ReasoningParser):
|
||||
"""
|
||||
V3 parser that delegates to either DeepSeekR1ReasoningParser or
|
||||
IdentityReasoningParser based on `thinking` and `separate_reasoning`.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
|
||||
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
|
||||
thinking = bool(chat_kwargs.get("thinking", False))
|
||||
enable_thinking = bool(chat_kwargs.get("enable_thinking", False))
|
||||
thinking = thinking or enable_thinking
|
||||
|
||||
self._parser: ReasoningParser
|
||||
if thinking:
|
||||
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
|
||||
else:
|
||||
self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def reasoning_start_str(self) -> str | None:
|
||||
return self._parser.reasoning_start_str
|
||||
|
||||
@property
|
||||
def reasoning_end_str(self) -> str | None:
|
||||
return self._parser.reasoning_end_str
|
||||
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
return self._parser.is_reasoning_end(input_ids)
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: Sequence[int], delta_ids: Iterable[int]
|
||||
) -> bool:
|
||||
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
return self._parser.extract_content_ids(input_ids)
|
||||
|
||||
def extract_reasoning(
|
||||
self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
|
||||
) -> tuple[str | None, str | None]:
|
||||
return self._parser.extract_reasoning(model_output, request)
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> "DeltaMessage | None":
|
||||
return self._parser.extract_reasoning_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
)
|
||||
|
||||
|
||||
class DeepSeekV3ReasoningWithThinkingParser(DeepSeekV3ReasoningParser):
|
||||
"""
|
||||
DeepSeekV3ReasoningParser that defaults to thinking mode.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
|
||||
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
|
||||
thinking = chat_kwargs.get("thinking", None)
|
||||
enable_thinking = chat_kwargs.get("enable_thinking", None)
|
||||
if thinking is None and enable_thinking is None:
|
||||
chat_kwargs["thinking"] = True
|
||||
chat_kwargs["enable_thinking"] = True
|
||||
kwargs["chat_template_kwargs"] = chat_kwargs
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
96
reference/vllm/tokenizers/deepseek_v4.py
Normal file
96
reference/vllm/tokenizers/deepseek_v4.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
from .deepseek_v4_encoding import encode_messages
|
||||
from .hf import HfTokenizer, get_cached_tokenizer
|
||||
from .protocol import TokenizerLike
|
||||
|
||||
|
||||
def get_deepseek_v4_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
|
||||
"""
|
||||
Wraps a tokenizer to use the custom DeepSeek V4 chat template encoding.
|
||||
"""
|
||||
dsv4_tokenizer = copy.copy(tokenizer)
|
||||
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
added_vocab_size = len(added_vocab)
|
||||
tokenizer_vocab_size = tokenizer.vocab_size
|
||||
|
||||
class _DeepseekV4Tokenizer(tokenizer.__class__): # type: ignore
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
**kwargs,
|
||||
) -> str | list[int]:
|
||||
thinking = kwargs.get("thinking", False)
|
||||
enable_thinking = kwargs.get("enable_thinking", False)
|
||||
thinking = thinking or enable_thinking
|
||||
thinking_mode = "thinking" if thinking else "chat"
|
||||
|
||||
conversation = kwargs.get("conversation", messages)
|
||||
messages = conversation.copy()
|
||||
if tools is not None and len(tools) > 0:
|
||||
messages.insert(0, {"role": "system"})
|
||||
messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key]
|
||||
|
||||
reasoning_effort = kwargs.get("reasoning_effort")
|
||||
if not isinstance(reasoning_effort, str):
|
||||
reasoning_effort = None
|
||||
elif reasoning_effort == "none":
|
||||
thinking_mode = "chat"
|
||||
reasoning_effort = None
|
||||
elif reasoning_effort in ("max", "xhigh"):
|
||||
reasoning_effort = "max"
|
||||
else:
|
||||
reasoning_effort = "high"
|
||||
|
||||
encode_config = dict(
|
||||
thinking_mode=thinking_mode,
|
||||
drop_thinking=kwargs.get("drop_thinking", True),
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
prompt_str = encode_messages(messages, **encode_config) # type: ignore
|
||||
|
||||
if kwargs.get("tokenize", True):
|
||||
tokenizer_kwargs = {
|
||||
k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs
|
||||
}
|
||||
return self.encode(
|
||||
prompt_str,
|
||||
add_special_tokens=False,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
return prompt_str
|
||||
|
||||
def num_special_tokens_to_add(self) -> int:
|
||||
return len(self.encode(""))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return tokenizer_vocab_size + added_vocab_size
|
||||
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
return added_vocab.copy()
|
||||
|
||||
def __reduce__(self):
|
||||
return get_deepseek_v4_tokenizer, (tokenizer,)
|
||||
|
||||
_DeepseekV4Tokenizer.__name__ = f"DSV4{tokenizer.__class__.__name__}"
|
||||
|
||||
dsv4_tokenizer.__class__ = _DeepseekV4Tokenizer
|
||||
return dsv4_tokenizer
|
||||
|
||||
|
||||
class DeepseekV4Tokenizer(TokenizerLike):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs) -> HfTokenizer:
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(*args, **kwargs)
|
||||
return get_cached_tokenizer(get_deepseek_v4_tokenizer(tokenizer))
|
||||
757
reference/vllm/tokenizers/deepseek_v4_encoding.py
Normal file
757
reference/vllm/tokenizers/deepseek_v4_encoding.py
Normal file
@@ -0,0 +1,757 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
# fmt: off
|
||||
|
||||
"""
|
||||
DeepSeek-V4 Encoding
|
||||
|
||||
A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
|
||||
with tool calling, thinking mode, and quick instruction task support.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional, Tuple
|
||||
import copy
|
||||
import json
|
||||
|
||||
import regex as re
|
||||
|
||||
# ============================================================
|
||||
# Special Tokens
|
||||
# ============================================================
|
||||
|
||||
bos_token: str = "<|begin▁of▁sentence|>"
|
||||
eos_token: str = "<|end▁of▁sentence|>"
|
||||
thinking_start_token: str = "<think>"
|
||||
thinking_end_token: str = "</think>"
|
||||
dsml_token: str = "|DSML|"
|
||||
|
||||
USER_SP_TOKEN = "<|User|>"
|
||||
ASSISTANT_SP_TOKEN = "<|Assistant|>"
|
||||
LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
|
||||
|
||||
# Task special tokens for internal classification tasks
|
||||
DS_TASK_SP_TOKENS = {
|
||||
"action": "<|action|>",
|
||||
"query": "<|query|>",
|
||||
"authority": "<|authority|>",
|
||||
"domain": "<|domain|>",
|
||||
"title": "<|title|>",
|
||||
"read_url": "<|read_url|>",
|
||||
}
|
||||
VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
|
||||
|
||||
# ============================================================
|
||||
# Templates
|
||||
# ============================================================
|
||||
|
||||
system_msg_template: str = "{content}"
|
||||
user_msg_template: str = "{content}"
|
||||
latest_reminder_msg_template: str = "{content}"
|
||||
assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
|
||||
assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
|
||||
thinking_template: str = "{reasoning}"
|
||||
|
||||
response_format_template: str = (
|
||||
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
|
||||
)
|
||||
tool_call_template: str = (
|
||||
"<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
|
||||
)
|
||||
tool_calls_template = (
|
||||
"<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
|
||||
)
|
||||
tool_calls_block_name: str = "tool_calls"
|
||||
|
||||
tool_output_template: str = (
|
||||
"<tool_result>{content}</tool_result>"
|
||||
)
|
||||
|
||||
REASONING_EFFORT_MAX = (
|
||||
"Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
|
||||
"You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
|
||||
"Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
|
||||
)
|
||||
|
||||
TOOLS_TEMPLATE = """## Tools
|
||||
|
||||
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
|
||||
|
||||
<{dsml_token}tool_calls>
|
||||
<{dsml_token}invoke name="$TOOL_NAME">
|
||||
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
|
||||
...
|
||||
</{dsml_token}invoke>
|
||||
<{dsml_token}invoke name="$TOOL_NAME2">
|
||||
...
|
||||
</{dsml_token}invoke>
|
||||
</{dsml_token}tool_calls>
|
||||
|
||||
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
||||
|
||||
If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
|
||||
|
||||
Otherwise, output directly after {thinking_end_token} with tool calls or final response.
|
||||
|
||||
### Available Tool Schemas
|
||||
|
||||
{tool_schemas}
|
||||
|
||||
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
||||
"""
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def to_json(value: Any) -> str:
|
||||
"""Serialize a value to JSON string."""
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
except Exception:
|
||||
return json.dumps(value, ensure_ascii=True)
|
||||
|
||||
|
||||
def tools_from_openai_format(tools):
|
||||
"""Extract function definitions from OpenAI-format tool list."""
|
||||
return [tool["function"] for tool in tools]
|
||||
|
||||
|
||||
def tool_calls_from_openai_format(tool_calls):
|
||||
"""Convert OpenAI-format tool calls to internal format."""
|
||||
return [
|
||||
{
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": tool_call["function"]["arguments"],
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def tool_calls_to_openai_format(tool_calls):
|
||||
"""Convert internal tool calls to OpenAI format."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": tool_call["arguments"],
|
||||
}
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Encode tool call arguments into DSML parameter format.
|
||||
|
||||
Args:
|
||||
tool_call: Dict with "name" and "arguments" keys.
|
||||
|
||||
Returns:
|
||||
DSML-formatted parameter string.
|
||||
"""
|
||||
p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
|
||||
P_dsml_strs = []
|
||||
|
||||
if isinstance(tool_call["arguments"], str):
|
||||
arguments = json.loads(tool_call["arguments"])
|
||||
else:
|
||||
arguments = tool_call["arguments"]
|
||||
|
||||
for k, v in arguments.items():
|
||||
p_dsml_str = p_dsml_template.format(
|
||||
dsml_token=dsml_token,
|
||||
key=k,
|
||||
is_str="true" if isinstance(v, str) else "false",
|
||||
value=v if isinstance(v, str) else to_json(v),
|
||||
)
|
||||
P_dsml_strs.append(p_dsml_str)
|
||||
|
||||
return "\n".join(P_dsml_strs)
|
||||
|
||||
|
||||
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
|
||||
"""
|
||||
Decode DSML parameters back to a tool call dict.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool.
|
||||
tool_args: Dict mapping param_name -> (value, is_string_flag).
|
||||
|
||||
Returns:
|
||||
Dict with "name" and "arguments" (JSON string) keys.
|
||||
"""
|
||||
def _decode_value(key: str, value: str, string: str):
|
||||
if string == "true":
|
||||
value = to_json(value)
|
||||
return f"{to_json(key)}: {value}"
|
||||
|
||||
tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
|
||||
return dict(name=tool_name, arguments=tool_args_json)
|
||||
|
||||
|
||||
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
|
||||
"""
|
||||
Render tool schemas into the system prompt format.
|
||||
|
||||
Args:
|
||||
tools: List of tool schema dicts (each with name, description, parameters).
|
||||
|
||||
Returns:
|
||||
Formatted tools section string.
|
||||
"""
|
||||
tools_json = [to_json(t) for t in tools]
|
||||
|
||||
return TOOLS_TEMPLATE.format(
|
||||
tool_schemas="\n".join(tools_json),
|
||||
dsml_token=dsml_token,
|
||||
thinking_start_token=thinking_start_token,
|
||||
thinking_end_token=thinking_end_token,
|
||||
)
|
||||
|
||||
|
||||
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
|
||||
"""Find the index of the last user/developer message."""
|
||||
last_user_index = -1
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
if messages[idx].get("role") in ["user", "developer"]:
|
||||
last_user_index = idx
|
||||
break
|
||||
return last_user_index
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Message Rendering
|
||||
# ============================================================
|
||||
|
||||
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
|
||||
"""
|
||||
Render a single message at the given index into its encoded string form.
|
||||
|
||||
This is the core function that converts each message in the conversation
|
||||
into the DeepSeek-V4 format.
|
||||
|
||||
Args:
|
||||
index: Index of the message to render.
|
||||
messages: Full list of messages in the conversation.
|
||||
thinking_mode: Either "chat" or "thinking".
|
||||
drop_thinking: Whether to drop reasoning content from earlier turns.
|
||||
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
||||
|
||||
Returns:
|
||||
Encoded string for this message.
|
||||
"""
|
||||
assert 0 <= index < len(messages)
|
||||
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
|
||||
|
||||
prompt = ""
|
||||
msg = messages[index]
|
||||
last_user_idx = find_last_user_index(messages)
|
||||
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
tools = msg.get("tools")
|
||||
response_format = msg.get("response_format")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
reasoning = msg.get("reasoning")
|
||||
wo_eos = msg.get("wo_eos", False)
|
||||
|
||||
if tools:
|
||||
tools = tools_from_openai_format(tools)
|
||||
if tool_calls:
|
||||
tool_calls = tool_calls_from_openai_format(tool_calls)
|
||||
|
||||
# Reasoning effort prefix (only at index 0 in thinking mode with max effort)
|
||||
assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
|
||||
if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
|
||||
prompt += REASONING_EFFORT_MAX
|
||||
|
||||
if role == "system":
|
||||
prompt += system_msg_template.format(content=content or "")
|
||||
if tools:
|
||||
prompt += "\n\n" + render_tools(tools)
|
||||
if response_format:
|
||||
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
||||
|
||||
elif role == "developer":
|
||||
assert content, f"Invalid message for role `{role}`: {msg}"
|
||||
|
||||
content_developer = USER_SP_TOKEN
|
||||
content_developer += content
|
||||
|
||||
if tools:
|
||||
content_developer += "\n\n" + render_tools(tools)
|
||||
if response_format:
|
||||
content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
||||
|
||||
prompt += user_msg_template.format(content=content_developer)
|
||||
|
||||
elif role == "user":
|
||||
prompt += USER_SP_TOKEN
|
||||
|
||||
# Handle content blocks (tool results mixed with text)
|
||||
content_blocks = msg.get("content_blocks")
|
||||
if content_blocks:
|
||||
parts = []
|
||||
for block in content_blocks:
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif block_type == "tool_result":
|
||||
tool_content = block.get("content", "")
|
||||
if isinstance(tool_content, list):
|
||||
text_parts = []
|
||||
for b in tool_content:
|
||||
if b.get("type") == "text":
|
||||
text_parts.append(b.get("text", ""))
|
||||
else:
|
||||
text_parts.append(f"[Unsupported {b.get('type')}]")
|
||||
tool_content = "\n\n".join(text_parts)
|
||||
parts.append(tool_output_template.format(content=tool_content))
|
||||
else:
|
||||
parts.append(f"[Unsupported {block_type}]")
|
||||
prompt += "\n\n".join(parts)
|
||||
else:
|
||||
prompt += content or ""
|
||||
|
||||
elif role == "latest_reminder":
|
||||
prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
|
||||
|
||||
elif role == "tool":
|
||||
raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
|
||||
|
||||
elif role == "assistant":
|
||||
thinking_part = ""
|
||||
tc_content = ""
|
||||
|
||||
if tool_calls:
|
||||
tc_list = [
|
||||
tool_call_template.format(
|
||||
dsml_token=dsml_token,
|
||||
name=tc.get("name"),
|
||||
arguments=encode_arguments_to_dsml(tc)
|
||||
)
|
||||
for tc in tool_calls
|
||||
]
|
||||
tc_content += '\n\n' + tool_calls_template.format(
|
||||
dsml_token=dsml_token,
|
||||
tool_calls="\n".join(tc_list),
|
||||
tc_block_name=tool_calls_block_name,
|
||||
)
|
||||
|
||||
summary_content = content or ""
|
||||
reasoning = reasoning or ""
|
||||
|
||||
# Check if previous message has a task - if so, this is a task output (no thinking)
|
||||
prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
|
||||
|
||||
if thinking_mode == "thinking" and not prev_has_task:
|
||||
if not drop_thinking or index > last_user_idx:
|
||||
thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token
|
||||
else:
|
||||
thinking_part = ""
|
||||
|
||||
if wo_eos:
|
||||
prompt += assistant_msg_wo_eos_template.format(
|
||||
reasoning=thinking_part,
|
||||
content=summary_content,
|
||||
tool_calls=tc_content,
|
||||
)
|
||||
else:
|
||||
prompt += assistant_msg_template.format(
|
||||
reasoning=thinking_part,
|
||||
content=summary_content,
|
||||
tool_calls=tc_content,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown role: {role}")
|
||||
|
||||
# Append transition tokens based on what follows
|
||||
if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
|
||||
return prompt
|
||||
|
||||
task = messages[index].get("task")
|
||||
if task is not None:
|
||||
# Task special token for internal classification tasks
|
||||
assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
|
||||
task_sp_token = DS_TASK_SP_TOKENS[task]
|
||||
|
||||
if task != "action":
|
||||
# Non-action tasks: append task sp token directly after the message
|
||||
prompt += task_sp_token
|
||||
else:
|
||||
# Action task: append Assistant + thinking token + action sp token
|
||||
prompt += ASSISTANT_SP_TOKEN
|
||||
prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
|
||||
prompt += task_sp_token
|
||||
|
||||
elif messages[index].get("role") in ["user", "developer"]:
|
||||
# Normal generation: append Assistant + thinking token
|
||||
prompt += ASSISTANT_SP_TOKEN
|
||||
if not drop_thinking and thinking_mode == "thinking":
|
||||
prompt += thinking_start_token
|
||||
elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
|
||||
prompt += thinking_start_token
|
||||
else:
|
||||
prompt += thinking_end_token
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Preprocessing
|
||||
# ============================================================
|
||||
|
||||
def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge tool messages into the preceding user message using content_blocks format.
|
||||
|
||||
DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
|
||||
are encoded as <tool_result> blocks within user messages.
|
||||
|
||||
This function converts a standard OpenAI-format conversation (with separate
|
||||
"tool" role messages) into V4 format where tool results are merged into
|
||||
user messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts in OpenAI format.
|
||||
|
||||
Returns:
|
||||
Processed message list with tool messages merged into user messages.
|
||||
"""
|
||||
merged: List[Dict[str, Any]] = []
|
||||
|
||||
for msg in messages:
|
||||
msg = copy.deepcopy(msg)
|
||||
role = msg.get("role")
|
||||
|
||||
if role == "tool":
|
||||
# Convert tool message to a user message with tool_result block
|
||||
tool_block = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id", ""),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
# Merge into previous message if it's already a user (merged tool)
|
||||
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
|
||||
merged[-1]["content_blocks"].append(tool_block)
|
||||
else:
|
||||
merged.append({
|
||||
"role": "user",
|
||||
"content_blocks": [tool_block],
|
||||
})
|
||||
elif role == "user":
|
||||
text_block = {"type": "text", "text": msg.get("content", "")}
|
||||
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
|
||||
merged[-1]["content_blocks"].append(text_block)
|
||||
else:
|
||||
new_msg = {
|
||||
"role": "user",
|
||||
"content": msg.get("content", ""),
|
||||
"content_blocks": [text_block],
|
||||
}
|
||||
# Preserve extra fields (task, wo_eos, mask, etc.)
|
||||
for key in ("task", "wo_eos", "mask"):
|
||||
if key in msg:
|
||||
new_msg[key] = msg[key]
|
||||
merged.append(new_msg)
|
||||
else:
|
||||
merged.append(msg)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Sort tool_result blocks within user messages by the order of tool_calls
|
||||
in the preceding assistant message.
|
||||
|
||||
Args:
|
||||
messages: Preprocessed message list (after merge_tool_messages).
|
||||
|
||||
Returns:
|
||||
Message list with sorted tool result blocks.
|
||||
"""
|
||||
last_tool_call_order: Dict[str, int] = {}
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
if role == "assistant" and msg.get("tool_calls"):
|
||||
last_tool_call_order = {}
|
||||
for idx, tc in enumerate(msg["tool_calls"]):
|
||||
tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
|
||||
if tc_id:
|
||||
last_tool_call_order[tc_id] = idx
|
||||
|
||||
elif role == "user" and msg.get("content_blocks"):
|
||||
tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
|
||||
if len(tool_blocks) > 1 and last_tool_call_order:
|
||||
sorted_blocks = sorted(
|
||||
tool_blocks,
|
||||
key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
|
||||
)
|
||||
sorted_idx = 0
|
||||
new_blocks = []
|
||||
for block in msg["content_blocks"]:
|
||||
if block.get("type") == "tool_result":
|
||||
new_blocks.append(sorted_blocks[sorted_idx])
|
||||
sorted_idx += 1
|
||||
else:
|
||||
new_blocks.append(block)
|
||||
msg["content_blocks"] = new_blocks
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Encoding Function
|
||||
# ============================================================
|
||||
|
||||
def encode_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
thinking_mode: str,
|
||||
context: Optional[List[Dict[str, Any]]] = None,
|
||||
drop_thinking: bool = True,
|
||||
add_default_bos_token: bool = True,
|
||||
reasoning_effort: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Encode a list of messages into the DeepSeek-V4 prompt format.
|
||||
|
||||
This is the main entry point for encoding conversations. It handles:
|
||||
- BOS token insertion
|
||||
- Thinking mode with optional reasoning content dropping
|
||||
- Tool message merging into user messages
|
||||
- Multi-turn conversation context
|
||||
|
||||
Args:
|
||||
messages: List of message dicts to encode.
|
||||
thinking_mode: Either "chat" or "thinking".
|
||||
context: Optional preceding context messages (already encoded prefix).
|
||||
drop_thinking: If True, drop reasoning from earlier assistant turns
|
||||
(only keep reasoning for messages after the last user message).
|
||||
add_default_bos_token: Whether to prepend BOS token at conversation start.
|
||||
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
||||
|
||||
Returns:
|
||||
The encoded prompt string.
|
||||
"""
|
||||
context = context if context else []
|
||||
|
||||
# Preprocess: merge tool messages and sort tool results
|
||||
messages = merge_tool_messages(messages)
|
||||
messages = sort_tool_results_by_call_order(context + messages)[len(context):]
|
||||
if context:
|
||||
context = merge_tool_messages(context)
|
||||
context = sort_tool_results_by_call_order(context)
|
||||
|
||||
full_messages = context + messages
|
||||
|
||||
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
|
||||
|
||||
# Resolve drop_thinking: if any message has tools defined, don't drop thinking
|
||||
effective_drop_thinking = drop_thinking
|
||||
if any(m.get("tools") for m in full_messages):
|
||||
effective_drop_thinking = False
|
||||
|
||||
if thinking_mode == "thinking" and effective_drop_thinking:
|
||||
full_messages = _drop_thinking_messages(full_messages)
|
||||
# After dropping, recalculate how many messages to render
|
||||
# (context may have shrunk too)
|
||||
num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
|
||||
context_len = len(full_messages) - num_to_render
|
||||
else:
|
||||
num_to_render = len(messages)
|
||||
context_len = len(context)
|
||||
|
||||
for idx in range(num_to_render):
|
||||
prompt += render_message(
|
||||
idx + context_len,
|
||||
full_messages,
|
||||
thinking_mode=thinking_mode,
|
||||
drop_thinking=effective_drop_thinking,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Drop reasoning and non-essential messages before the last user message.
|
||||
|
||||
Behavior:
|
||||
- Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
|
||||
- Messages at or after the last user index are always kept.
|
||||
- Assistant messages before the last user get reasoning removed.
|
||||
- Developer messages before the last user are dropped entirely.
|
||||
"""
|
||||
last_user_idx = find_last_user_index(messages)
|
||||
result = []
|
||||
keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role in keep_roles or idx >= last_user_idx:
|
||||
result.append(msg)
|
||||
elif role == "assistant":
|
||||
msg = copy.copy(msg)
|
||||
msg.pop("reasoning", None)
|
||||
result.append(msg)
|
||||
# developer and other roles before last_user_idx are dropped
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Parsing (Decoding model output)
|
||||
# ============================================================
|
||||
|
||||
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
|
||||
"""
|
||||
Read text from index until one of the stop strings is found.
|
||||
|
||||
Returns:
|
||||
Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
|
||||
"""
|
||||
min_pos = len(text)
|
||||
matched_stop = None
|
||||
|
||||
for s in stop:
|
||||
pos = text.find(s, index)
|
||||
if pos != -1 and pos < min_pos:
|
||||
min_pos = pos
|
||||
matched_stop = s
|
||||
|
||||
if matched_stop:
|
||||
content = text[index:min_pos]
|
||||
return min_pos + len(matched_stop), content, matched_stop
|
||||
else:
|
||||
content = text[index:]
|
||||
return len(text), content, None
|
||||
|
||||
|
||||
def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
|
||||
"""
|
||||
Parse DSML tool calls from text starting at the given index.
|
||||
|
||||
Args:
|
||||
index: Starting position in text.
|
||||
text: The full text to parse.
|
||||
|
||||
Returns:
|
||||
Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
|
||||
Each tool call dict has "name" and "arguments" keys.
|
||||
"""
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
stop_token = None
|
||||
tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
|
||||
|
||||
while index < len(text):
|
||||
index, content_before, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
|
||||
if content_before != ">\n":
|
||||
raise ValueError(f"Tool call format error: expected '>\\n' but got '{content_before}'")
|
||||
|
||||
if stop_token == tool_calls_end_token:
|
||||
break
|
||||
|
||||
if stop_token is None:
|
||||
raise ValueError("Missing special token in tool calls")
|
||||
|
||||
index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
||||
|
||||
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
|
||||
if len(p_tool_name) != 1:
|
||||
raise ValueError(f"Tool name format error: '{tool_name_content}'")
|
||||
tool_name = p_tool_name[0]
|
||||
|
||||
tool_args: Dict[str, Tuple[str, str]] = {}
|
||||
while stop_token == f"<{dsml_token}parameter":
|
||||
index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
|
||||
|
||||
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
|
||||
if len(param_kv) != 1:
|
||||
raise ValueError(f"Parameter format error: '{param_content}'")
|
||||
param_name, string, param_value = param_kv[0]
|
||||
|
||||
if param_name in tool_args:
|
||||
raise ValueError(f"Duplicate parameter name: '{param_name}'")
|
||||
tool_args[param_name] = (param_value, string)
|
||||
|
||||
index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
||||
if content != ">\n":
|
||||
raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
|
||||
|
||||
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return index, stop_token, tool_calls
|
||||
|
||||
|
||||
def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse a model completion text into a structured assistant message.
|
||||
|
||||
This function takes the raw text output from the model (a single assistant turn)
|
||||
and extracts:
|
||||
- reasoning (thinking block)
|
||||
- content (summary/response)
|
||||
- tool_calls (if any)
|
||||
|
||||
NOTE: This function is designed to parse only correctly formatted strings and
|
||||
will raise ValueError for malformed output.
|
||||
|
||||
Args:
|
||||
text: The raw completion text (including EOS token).
|
||||
thinking_mode: Either "chat" or "thinking".
|
||||
|
||||
Returns:
|
||||
Dict with keys: "role", "content", "reasoning", "tool_calls".
|
||||
tool_calls are in OpenAI format.
|
||||
"""
|
||||
summary_content, reasoning = "", ""
|
||||
tool_calls: List[Dict[str, str]] = []
|
||||
index, stop_token = 0, None
|
||||
tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
|
||||
|
||||
is_thinking = thinking_mode == "thinking"
|
||||
is_tool_calling = False
|
||||
|
||||
if is_thinking:
|
||||
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
|
||||
reasoning = content_delta
|
||||
if stop_token != thinking_end_token:
|
||||
raise ValueError("Invalid thinking format: missing </think>")
|
||||
|
||||
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
|
||||
summary_content = content_delta
|
||||
if stop_token == tool_calls_start_token:
|
||||
is_tool_calling = True
|
||||
else:
|
||||
if stop_token != eos_token:
|
||||
raise ValueError("Invalid format: missing EOS token")
|
||||
|
||||
if is_tool_calling:
|
||||
index, stop_token, tool_calls = parse_tool_calls(index, text)
|
||||
|
||||
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
|
||||
if tool_ends_text:
|
||||
raise ValueError("Unexpected content after tool calls")
|
||||
|
||||
if len(text) != index or stop_token not in [eos_token, None]:
|
||||
raise ValueError("Unexpected content at end")
|
||||
|
||||
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
|
||||
if sp_token in summary_content or sp_token in reasoning:
|
||||
raise ValueError(f"Unexpected special token '{sp_token}' in content")
|
||||
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": summary_content,
|
||||
"reasoning": reasoning,
|
||||
"tool_calls": tool_calls_to_openai_format(tool_calls)
|
||||
}
|
||||
|
||||
# fmt: on
|
||||
322
reference/vllm/tool_parsers/deepseekv32_tool_parser.py
Normal file
322
reference/vllm/tool_parsers/deepseekv32_tool_parser.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import partial_tag_overlap
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV32ToolParser(ToolParser):
|
||||
"""
|
||||
example tool call content:
|
||||
<|DSML|function_calls>
|
||||
<|DSML|invoke name="get_weather">
|
||||
<|DSML|parameter name="location" string="true">杭州</|DSML|parameter>
|
||||
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
|
||||
</|DSML|invoke>
|
||||
<|DSML|invoke name="get_weather">
|
||||
<|DSML|parameter name="location" string="true">北京</|DSML|parameter>
|
||||
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
|
||||
</|DSML|invoke>
|
||||
</|DSML|function_calls>
|
||||
"""
|
||||
|
||||
tool_call_start_token: str = "<|DSML|function_calls>"
|
||||
tool_call_end_token: str = "</|DSML|function_calls>"
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||
super().__init__(tokenizer, tools)
|
||||
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
# Streaming state
|
||||
self.current_tool_index: int = 0
|
||||
self._sent_content_idx: int = 0
|
||||
|
||||
# Regex patterns for complete parsing
|
||||
self.tool_call_complete_regex = re.compile(
|
||||
re.escape(self.tool_call_start_token)
|
||||
+ r"(.*?)"
|
||||
+ re.escape(self.tool_call_end_token),
|
||||
re.DOTALL,
|
||||
)
|
||||
self.invoke_complete_regex = re.compile(
|
||||
r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>', re.DOTALL
|
||||
)
|
||||
self.parameter_complete_regex = re.compile(
|
||||
r'<|DSML|parameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</|DSML|parameter>',
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"vLLM Successfully import tool parser %s !", self.__class__.__name__
|
||||
)
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> ChatCompletionRequest | ResponsesRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# Ensure tool call tokens
|
||||
# (e.g. <|DSML|function_calls>, </|DSML|function_calls>)
|
||||
# are not skippedduring decoding.
|
||||
# Even though they are not marked as special tokens,
|
||||
# setting skip_special_tokens=False ensures proper handling in
|
||||
# transformers 5.x where decoding behavior may have changed.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _parse_invoke_params(self, invoke_str: str) -> dict:
|
||||
param_dict = dict()
|
||||
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
|
||||
param_dict[param_name] = param_val
|
||||
return param_dict
|
||||
|
||||
def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
|
||||
"""Convert parameter value to the correct type."""
|
||||
if value.lower() == "null":
|
||||
return None
|
||||
|
||||
param_type = param_type.lower()
|
||||
if param_type in ["string", "str", "text"]:
|
||||
return value
|
||||
elif param_type in ["integer", "int"]:
|
||||
return int(value)
|
||||
elif param_type in ["number", "float"]:
|
||||
val = float(value)
|
||||
return val if val != int(val) else int(val)
|
||||
elif param_type in ["boolean", "bool"]:
|
||||
value = value.strip()
|
||||
if value.lower() not in ["false", "0", "true", "1"]:
|
||||
raise ValueError("Invalid boolean value")
|
||||
return value.lower() in ["true", "1"]
|
||||
elif param_type in ["object", "array"]:
|
||||
return json.loads(value)
|
||||
else:
|
||||
return json.loads(value)
|
||||
|
||||
def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
|
||||
"""Convert parameter value to the correct type."""
|
||||
if not isinstance(param_type, list):
|
||||
param_type = [param_type]
|
||||
for current_type in param_type:
|
||||
try:
|
||||
return self._convert_param_value_checked(value, current_type)
|
||||
except Exception:
|
||||
continue
|
||||
# return value as fallback
|
||||
return value
|
||||
|
||||
def _convert_params_with_schema(
|
||||
self,
|
||||
function_name: str,
|
||||
param_dict: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
"""Convert raw string param values using the tool schema types."""
|
||||
param_config: dict = {}
|
||||
if self.tools:
|
||||
for tool in self.tools:
|
||||
if (
|
||||
hasattr(tool, "function")
|
||||
and tool.function.name == function_name
|
||||
and hasattr(tool.function, "parameters")
|
||||
):
|
||||
schema = tool.function.parameters
|
||||
if isinstance(schema, dict) and "properties" in schema:
|
||||
param_config = schema["properties"]
|
||||
break
|
||||
|
||||
converted: dict[str, Any] = {}
|
||||
for name, value in param_dict.items():
|
||||
param_type = "string"
|
||||
if name in param_config and isinstance(param_config[name], dict):
|
||||
param_type = param_config[name].get("type", "string")
|
||||
converted[name] = self._convert_param_value(value, param_type)
|
||||
return converted
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""Extract tool calls from complete model output (non-streaming)."""
|
||||
# Quick check
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
try:
|
||||
tool_calls = []
|
||||
|
||||
# Find all complete tool_call blocks
|
||||
for tool_call_match in self.tool_call_complete_regex.findall(model_output):
|
||||
# Find all invokes within this tool_call
|
||||
for invoke_name, invoke_content in self.invoke_complete_regex.findall(
|
||||
tool_call_match
|
||||
):
|
||||
param_dict = self._parse_invoke_params(invoke_content)
|
||||
params = self._convert_params_with_schema(invoke_name, param_dict)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=invoke_name,
|
||||
arguments=json.dumps(params, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
# Extract content before first tool call
|
||||
first_tool_idx = model_output.find(self.tool_call_start_token)
|
||||
content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True, tool_calls=tool_calls, content=content
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error extracting tool calls")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self._sent_content_idx = 0
|
||||
self.prev_tool_call_arr.clear()
|
||||
self.streamed_args_for_tool.clear()
|
||||
|
||||
def _extract_delta_tool_calls(
|
||||
self,
|
||||
current_text: str,
|
||||
request: ChatCompletionRequest | None,
|
||||
) -> list[DeltaToolCall]:
|
||||
"""Extract DeltaToolCalls from newly completed <invoke> blocks.
|
||||
|
||||
Tracks progress via ``current_tool_index`` so each block is
|
||||
extracted exactly once across successive streaming calls.
|
||||
"""
|
||||
complete_invokes = self.invoke_complete_regex.findall(current_text)
|
||||
delta_tool_calls: list[DeltaToolCall] = []
|
||||
|
||||
while len(complete_invokes) > self.current_tool_index:
|
||||
invoke_name, invoke_body = complete_invokes[self.current_tool_index]
|
||||
param_dict = self._parse_invoke_params(invoke_body)
|
||||
|
||||
converted = self._convert_params_with_schema(invoke_name, param_dict)
|
||||
args_json = json.dumps(converted, ensure_ascii=False)
|
||||
idx = self.current_tool_index
|
||||
self.current_tool_index += 1
|
||||
|
||||
self.prev_tool_call_arr.append(
|
||||
{"name": invoke_name, "arguments": converted}
|
||||
)
|
||||
self.streamed_args_for_tool.append(args_json)
|
||||
|
||||
delta_tool_calls.append(
|
||||
DeltaToolCall(
|
||||
index=idx,
|
||||
id=self._generate_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=invoke_name,
|
||||
arguments=args_json,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
)
|
||||
|
||||
return delta_tool_calls
|
||||
|
||||
def _extract_content(self, current_text: str) -> str | None:
|
||||
"""Return unsent non-tool-call text, or None.
|
||||
|
||||
Holds back any suffix that could be a partial start marker
|
||||
so that split markers are never leaked as content.
|
||||
"""
|
||||
if self.tool_call_start_token not in current_text:
|
||||
overlap = partial_tag_overlap(current_text, self.tool_call_start_token)
|
||||
sendable_idx = len(current_text) - overlap
|
||||
else:
|
||||
sendable_idx = current_text.index(self.tool_call_start_token)
|
||||
|
||||
if sendable_idx > self._sent_content_idx:
|
||||
content = current_text[self._sent_content_idx : sendable_idx]
|
||||
self._sent_content_idx = sendable_idx
|
||||
return content
|
||||
return None
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
current_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract tool calls from streaming model output.
|
||||
|
||||
Uses a buffer-until-complete-invoke strategy: tokens are buffered
|
||||
until a complete invoke block is available, then parsed and emitted
|
||||
in one shot.
|
||||
"""
|
||||
|
||||
# First chunk of a new stream — reset state from prior request.
|
||||
if not previous_text:
|
||||
self._reset_streaming_state()
|
||||
|
||||
content = self._extract_content(current_text)
|
||||
delta_tool_calls = self._extract_delta_tool_calls(current_text, request)
|
||||
|
||||
if delta_tool_calls or content:
|
||||
return DeltaMessage(content=content, tool_calls=delta_tool_calls)
|
||||
|
||||
# Empty delta with token ids means EOS or closing tag; return
|
||||
# non-None so the serving framework can finalize finish_reason.
|
||||
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
|
||||
return DeltaMessage(content="")
|
||||
|
||||
return None
|
||||
31
reference/vllm/tool_parsers/deepseekv4_tool_parser.py
Normal file
31
reference/vllm/tool_parsers/deepseekv4_tool_parser.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser
|
||||
from vllm.tool_parsers.structural_tag_registry import (
|
||||
get_enable_structured_outputs_in_reasoning,
|
||||
get_model_structural_tag,
|
||||
)
|
||||
|
||||
|
||||
class DeepSeekV4ToolParser(DeepSeekV32ToolParser):
|
||||
"""
|
||||
DeepSeek V4 DSML tool parser.
|
||||
|
||||
V4 keeps the V3.2 DSML invoke/parameter grammar, but wraps tool calls in
|
||||
``<|DSML|tool_calls>`` instead of ``<|DSML|function_calls>``.
|
||||
"""
|
||||
|
||||
tool_call_start_token: str = "<|DSML|tool_calls>"
|
||||
tool_call_end_token: str = "</|DSML|tool_calls>"
|
||||
|
||||
def get_structural_tag(self, request: ChatCompletionRequest):
|
||||
return get_model_structural_tag(
|
||||
model="deepseek_v4",
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
reasoning=get_enable_structured_outputs_in_reasoning(),
|
||||
)
|
||||
@@ -24,6 +24,8 @@ def parse_args():
|
||||
p.add_argument('--top-k', type=int, default=50, help='Top-k filtering (0=disabled)')
|
||||
p.add_argument('--top-p', type=float, default=0.95, help='Top-p (nucleus) filtering (1.0=disabled)')
|
||||
p.add_argument('--prompt', type=str, default=None)
|
||||
p.add_argument('--thinking-mode', choices=['thinking', 'chat'], default='thinking',
|
||||
help='Thinking mode: "thinking" = model reasons first, "chat" = model generates directly')
|
||||
p.add_argument('--seed', type=int, default=42)
|
||||
p.add_argument('--verbose', type=int, default=1)
|
||||
p.add_argument('--prefill-only', action='store_true')
|
||||
@@ -45,8 +47,16 @@ PROMPT = _args.prompt or "The capital of France is"
|
||||
NUM_GPUS = _args.num_gpus
|
||||
SEED = _args.seed
|
||||
VERBOSE = _args.verbose
|
||||
THINK_START, THINK_END = 128821, 128822
|
||||
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804
|
||||
# Special token IDs — derived from official encoding module strings + tokenizer.
|
||||
# Do NOT hardcode these; the encoding module defines the canonical token strings.
|
||||
from encoding.deepseek_v4_encoding import (
|
||||
thinking_start_token as _THINK_START_STR,
|
||||
thinking_end_token as _THINK_END_STR,
|
||||
USER_SP_TOKEN as _USER_STR,
|
||||
ASSISTANT_SP_TOKEN as _ASSISTANT_STR,
|
||||
eos_token as _EOS_STR,
|
||||
bos_token as _BOS_STR,
|
||||
)
|
||||
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
|
||||
|
||||
# =====================================================================
|
||||
@@ -157,7 +167,7 @@ class CUDAGraphDecoder:
|
||||
def pre_allocate(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
|
||||
kv_caches, compressors, indexers, moe_runners, se_runners,
|
||||
routers, prod_lins, layer_w, rope_caches, hc_head,
|
||||
final_norm_w, lm_head_lin):
|
||||
final_norm_w, lm_head_lin, comp_rope_caches=None):
|
||||
"""Pre-allocate all I/O buffers with fixed addresses."""
|
||||
for li in range(self.n_layers):
|
||||
dev = self.devices[li % self.num_gpus]
|
||||
@@ -169,7 +179,7 @@ class CUDAGraphDecoder:
|
||||
def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
|
||||
kv_caches, compressors, indexers, moe_runners, se_runners,
|
||||
routers, prod_lins, layer_w, rope_caches, hc_head,
|
||||
final_norm_w, lm_head_lin, positions, token_id):
|
||||
final_norm_w, lm_head_lin, positions, token_id, comp_rope_caches=None):
|
||||
"""Capture CUDA graphs for all layers + lm_head.
|
||||
|
||||
Must be called after one warmup step so that:
|
||||
@@ -198,7 +208,9 @@ class CUDAGraphDecoder:
|
||||
compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li),
|
||||
_use_fused_rmsnorm_quantize=True
|
||||
_use_fused_rmsnorm_quantize=True,
|
||||
comp_rope_cos=comp_rope_caches[gpu][0] if comp_rope_caches else None,
|
||||
comp_rope_sin=comp_rope_caches[gpu][1] if comp_rope_caches else None,
|
||||
)
|
||||
# Copy output to fixed buffer
|
||||
self.x_out_bufs[li].copy_(X_out)
|
||||
@@ -364,8 +376,10 @@ class Compressor:
|
||||
n_comp = compressed.shape[0]
|
||||
|
||||
# Vectorized position computation — no Python loop, no .item()
|
||||
# Block-aligned: use FIRST position of each block (vLLM cross-check confirmed)
|
||||
# Wrong: ((bi+1)*r - 1) uses LAST position → off by r-1 (3 for CSA, 127 for HCA)
|
||||
bi = torch.arange(n_comp, device=dev)
|
||||
pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1)
|
||||
pos_idx = (bi * r).clamp(max=positions.numel() - 1)
|
||||
comp_pos = positions[pos_idx]
|
||||
|
||||
# Return FP32 compressed output — caller handles RoPE + NVFP4 quantize
|
||||
@@ -750,23 +764,32 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
|
||||
|
||||
def _run_production_fmha_mixed(q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
|
||||
n_h, hd, T, seq_len, scale, dev, li, w, pfx, rope_dim):
|
||||
"""B1 storage-native mixed FP8/BF16 decode FMHA. No BF16 KV staging."""
|
||||
if T != 1:
|
||||
raise RuntimeError(f"B1 mixed FP8 FMHA is decode-only (T==1); got T={T}")
|
||||
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
|
||||
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, 1, hd)
|
||||
"""B1 storage-native mixed FP8/BF16 FMHA. Supports decode (T=1) and prefill (T>1)."""
|
||||
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode, dsv4_attention_mixed_fp8_prefill
|
||||
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
|
||||
sinks = w.get(f"{pfx}.sinks"); sink_bias = None
|
||||
if sinks is not None:
|
||||
sink_bias = sinks.to(device=dev).float().reshape(n_h)
|
||||
attn_out = dsv4_attention_mixed_fp8_decode(
|
||||
q=q,
|
||||
k_nope_fp8=kv_nope_fp8,
|
||||
k_nope_scale=kv_nope_scale,
|
||||
k_rope_bf16=kv_rope_bf16,
|
||||
scale=scale,
|
||||
sink_bias=sink_bias,
|
||||
rope_dim=rope_dim,
|
||||
)
|
||||
if T == 1:
|
||||
attn_out = dsv4_attention_mixed_fp8_decode(
|
||||
q=q,
|
||||
k_nope_fp8=kv_nope_fp8,
|
||||
k_nope_scale=kv_nope_scale,
|
||||
k_rope_bf16=kv_rope_bf16,
|
||||
scale=scale,
|
||||
sink_bias=sink_bias,
|
||||
rope_dim=rope_dim,
|
||||
)
|
||||
else:
|
||||
attn_out = dsv4_attention_mixed_fp8_prefill(
|
||||
q=q,
|
||||
k_nope_fp8=kv_nope_fp8,
|
||||
k_nope_scale=kv_nope_scale,
|
||||
k_rope_bf16=kv_rope_bf16,
|
||||
scale=scale,
|
||||
sink_bias=sink_bias,
|
||||
rope_dim=rope_dim,
|
||||
)
|
||||
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
|
||||
# =====================================================================
|
||||
@@ -775,7 +798,8 @@ def _run_production_fmha_mixed(q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16
|
||||
def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, compressor, indexer, prod_lin,
|
||||
x_quant=None,
|
||||
_profile_detail=False, _profile_times=None):
|
||||
_profile_detail=False, _profile_times=None,
|
||||
comp_rope_cos=None, comp_rope_sin=None):
|
||||
dev = x_normed.device; T = x_normed.shape[0]
|
||||
n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64)
|
||||
o_groups = cfg.get("o_groups", 16); o_rank = cfg.get("o_lora_rank", 1024)
|
||||
@@ -843,7 +867,10 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous() # (n_comp, 64) BF16
|
||||
# Apply RoPE on BF16 rope dims (existing BF16 RoPE kernel)
|
||||
rope_3d = rope_bf16.unsqueeze(1) # (n_comp, 1, 64)
|
||||
rope_3d = _apply_rope(rope_3d, comp_pos, rope_cos, rope_sin, rd)
|
||||
# Use compress_rope_theta cache for compressed entries if available
|
||||
crc = comp_rope_cos if comp_rope_cos is not None else rope_cos
|
||||
crs = comp_rope_sin if comp_rope_sin is not None else rope_sin
|
||||
rope_3d = _apply_rope(rope_3d, comp_pos, crc, crs, rd)
|
||||
rope_bf16 = rope_3d.squeeze(1) # (n_comp, 64) BF16
|
||||
# Quantize non-RoPE part FP32 → FP8_E4M3
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
@@ -885,9 +912,10 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
# 6. Production FMHA — B1 mixed FP8/BF16 decode path.
|
||||
_pt('fmha_start')
|
||||
if li == 0:
|
||||
print(f" L0 B1 verify: kv_nope_fp8 dtype={kv_nope_fp8.dtype} shape={tuple(kv_nope_fp8.shape)} "
|
||||
f"kv_nope_scale dtype={kv_nope_scale.dtype} shape={tuple(kv_nope_scale.shape)} "
|
||||
f"kv_rope_bf16 dtype={kv_rope_bf16.dtype} shape={tuple(kv_rope_bf16.shape)}", flush=True)
|
||||
if VERBOSE >= 2:
|
||||
print(f" L0 B1 verify: kv_nope_fp8 dtype={kv_nope_fp8.dtype} shape={tuple(kv_nope_fp8.shape)} "
|
||||
f"kv_nope_scale dtype={kv_nope_scale.dtype} shape={tuple(kv_nope_scale.shape)} "
|
||||
f"kv_rope_bf16 dtype={kv_rope_bf16.dtype} shape={tuple(kv_rope_bf16.shape)}", flush=True)
|
||||
assert kv_nope_fp8.dtype in (torch.uint8, torch.float8_e4m3fn), f"kv_nope_fp8 wrong dtype: {kv_nope_fp8.dtype}"
|
||||
assert kv_nope_scale.dtype == torch.float32, f"kv_nope_scale wrong dtype: {kv_nope_scale.dtype}"
|
||||
assert kv_rope_bf16.dtype == torch.bfloat16, f"kv_rope_bf16 wrong dtype: {kv_rope_bf16.dtype}"
|
||||
@@ -979,6 +1007,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
moe_runner=None, se_runner=None, router=None,
|
||||
prod_lin=None, _profile_detail=False, _profile_times=None,
|
||||
_use_fused_rmsnorm_quantize=True,
|
||||
comp_rope_cos=None, comp_rope_sin=None,
|
||||
):
|
||||
"""Forward one transformer layer.
|
||||
"""
|
||||
@@ -1011,7 +1040,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, compressor, indexer, prod_lin,
|
||||
x_quant=x_quant_attn,
|
||||
_profile_detail=_profile_detail, _profile_times=_profile_times)
|
||||
_profile_detail=_profile_detail, _profile_times=_profile_times,
|
||||
comp_rope_cos=comp_rope_cos, comp_rope_sin=comp_rope_sin)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter()
|
||||
X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)
|
||||
|
||||
@@ -1395,6 +1425,14 @@ def main():
|
||||
rtheta = cfg.get("rope_theta", 10000.); romax = rp.get("original_max_position_embeddings", 65536)
|
||||
rbfast, rbslow = rp.get("beta_fast", 32), rp.get("beta_slow", 1)
|
||||
rope_caches = {g: build_rope_cache(romax, rd, f"cuda:{g}", rtheta, rt, rf, romax, rbfast, rbslow) for g in range(NUM_GPUS)}
|
||||
# Compressed-entry RoPE uses separate theta (vLLM cross-check: compress_rope_theta)
|
||||
# If compress_rope_theta differs from rope_theta, compressed KV entries need their own cache
|
||||
comp_rtheta = cfg.get("compress_rope_theta", rtheta)
|
||||
if comp_rtheta != rtheta:
|
||||
comp_rope_caches = {g: build_rope_cache(romax, rd, f"cuda:{g}", comp_rtheta, rt, rf, romax, rbfast, rbslow) for g in range(NUM_GPUS)}
|
||||
print(f" Compressed RoPE theta: {comp_rtheta} (different from normal: {rtheta})")
|
||||
else:
|
||||
comp_rope_caches = rope_caches # Same theta, reuse normal cache
|
||||
|
||||
# KV caches, compressors, indexers
|
||||
kv_caches, compressors, indexers = {}, {}, {}
|
||||
@@ -1431,6 +1469,14 @@ def main():
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
|
||||
# Derive special token IDs from official encoding strings + tokenizer.
|
||||
# This is the ONLY source of truth — never hardcode these IDs.
|
||||
THINK_START = tokenizer.convert_tokens_to_ids(_THINK_START_STR)
|
||||
THINK_END = tokenizer.convert_tokens_to_ids(_THINK_END_STR)
|
||||
USER_TOKEN = tokenizer.convert_tokens_to_ids(_USER_STR)
|
||||
ASSISTANT_TOKEN = tokenizer.convert_tokens_to_ids(_ASSISTANT_STR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
|
||||
# A1: Build explicit stop set — DSV4 uses special turn-end tokens beyond eos
|
||||
STOP_IDS = set()
|
||||
eos_id = tokenizer.eos_token_id
|
||||
@@ -1446,36 +1492,50 @@ def main():
|
||||
print(f" Special tokens: {tokenizer.special_tokens_map}")
|
||||
print(f" THINK_START={THINK_START} THINK_END={THINK_END} USER={USER_TOKEN} ASST={ASSISTANT_TOKEN}")
|
||||
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
if _args.prefill_tokens:
|
||||
generated = [int(x) for x in _args.prefill_tokens.split(',')]
|
||||
else:
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
# DSV4 reasoning model: must prime with ◇ (think_start) after Assistant token.
|
||||
# Without this, the model is out-of-distribution — it expects to be inside a
|
||||
# thinking block but never received the think-start sentinel.
|
||||
# Symptom: degenerate output from step 0 (e.g. "France" instead of "Paris",
|
||||
# looping on newlines/repeated tokens). With ◇, the model generates thinking
|
||||
# content, emits ◇ (think_end), then produces the actual answer.
|
||||
input_ids.append(THINK_START)
|
||||
generated = input_ids
|
||||
# Official DeepSeek V4 encoding — canonical path, no hand-rolled alternatives.
|
||||
# Uses encoding/deepseek_v4_encoding.py (copied from vLLM tree) to build
|
||||
# the prompt. This is the ONLY way to construct prompts — the official
|
||||
# encoder handles BOS, User/Assistant tokens, thinking mode, and all
|
||||
# special token placement. It can't drift because it's the same code
|
||||
# the inference engines will use.
|
||||
from encoding.deepseek_v4_encoding import encode_messages
|
||||
messages = [{"role": "user", "content": PROMPT}]
|
||||
thinking_mode = _args.thinking_mode # 'thinking' or 'chat'
|
||||
encoded_str = encode_messages(messages, thinking_mode=thinking_mode)
|
||||
generated = tokenizer.encode(encoded_str, add_special_tokens=False)
|
||||
# Ensure BOS token is present at the start
|
||||
if generated[0] != bos:
|
||||
generated = [bos] + generated
|
||||
all_tokens = generated.copy()
|
||||
print(f"Input: {len(generated)} tokens")
|
||||
print(f"Input: {len(generated)} tokens (thinking_mode={_args.thinking_mode})")
|
||||
|
||||
# Prefill — one token at a time (decode-style; TODO: batched prefill)
|
||||
print(f"Prefilling {len(generated)} tokens...")
|
||||
# Pre-allocate prefill buffers — no per-step torch.tensor()
|
||||
pre_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
|
||||
pre_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0')
|
||||
pre_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
|
||||
for pi, tid_val in enumerate(generated):
|
||||
# Batched prefill — process tokens in chunks of up to 128 (FMHA T≤128 constraint)
|
||||
PREFILL_CHUNK = 128 # max T per FMHA launch; split larger prefills into chunks
|
||||
n_prefill = len(generated)
|
||||
print(f"Batched prefill: {n_prefill} tokens, chunk_size={PREFILL_CHUNK}")
|
||||
prefill_ids = torch.tensor(generated, dtype=torch.long, device='cuda:0')
|
||||
prefill_ids32 = prefill_ids.to(torch.int32)
|
||||
all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0')
|
||||
|
||||
# Process chunks: each chunk goes through ALL 61 layers before the next chunk.
|
||||
# This ensures KV cache is populated correctly for each layer.
|
||||
chunk_starts = list(range(0, n_prefill, PREFILL_CHUNK))
|
||||
X = None # will be set by first chunk's embedding
|
||||
for ci, cs in enumerate(chunk_starts):
|
||||
ce = min(cs + PREFILL_CHUNK, n_prefill)
|
||||
chunk_len = ce - cs
|
||||
t1 = time.time()
|
||||
pre_tid_buf[0] = tid_val
|
||||
pre_tid32_buf[0] = tid_val
|
||||
pre_pos_buf[0] = pi
|
||||
X = mHCLayer.init_state(embed(pre_tid_buf))
|
||||
|
||||
# Embed chunk tokens: (chunk_len, d)
|
||||
chunk_ids = prefill_ids[cs:ce]
|
||||
chunk_ids32 = prefill_ids32[cs:ce]
|
||||
chunk_positions = all_positions[cs:ce]
|
||||
chunk_embed = embed(chunk_ids) # (chunk_len, d) BF16
|
||||
X = mHCLayer.init_state(chunk_embed) # (chunk_len, n_hc, d) BF16
|
||||
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
@@ -1484,23 +1544,23 @@ def main():
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li),
|
||||
attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], pre_pos_buf, pre_tid32_buf,
|
||||
kv_caches[li], chunk_positions, chunk_ids32,
|
||||
compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li),
|
||||
_use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
|
||||
comp_rope_cos=comp_rope_caches[gpu][0], comp_rope_sin=comp_rope_caches[gpu][1],
|
||||
)
|
||||
except Exception as e:
|
||||
torch.cuda.synchronize()
|
||||
err = torch.cuda.current_stream(gpu).query()
|
||||
print(f" CRASH at token {pi} layer {li} gpu {gpu}: {e}", flush=True)
|
||||
print(f" CRASH at chunk {ci} (tokens {cs}-{ce-1}) layer {li} gpu {gpu}: {e}", flush=True)
|
||||
raise
|
||||
if VERBOSE >= 2 and pi == 0 and li < 3:
|
||||
if VERBOSE >= 2 and ci == 0 and li < 3:
|
||||
torch.cuda.synchronize(gpu)
|
||||
print(f" Token {pi} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True)
|
||||
print(f" Chunk {ci} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True)
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True)
|
||||
print(f" Prefill done ({time.time()-t0:.1f}s)")
|
||||
print(f" Chunk {ci+1}/{len(chunk_starts)} tokens {cs}-{ce-1} ({chunk_len} tok): {time.time()-t1:.2f}s", flush=True)
|
||||
print(f" Batched prefill done ({time.time()-t0:.1f}s)")
|
||||
|
||||
if _args.prefill_only: print("Prefill-only mode, stopping."); return
|
||||
|
||||
@@ -1570,6 +1630,7 @@ def main():
|
||||
_profile_detail=(profile and step == 1),
|
||||
_profile_times=cuda_layer_events if (profile and step == 1) else None,
|
||||
_use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
|
||||
comp_rope_cos=comp_rope_caches[gpu][0], comp_rope_sin=comp_rope_caches[gpu][1],
|
||||
)
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
t_layers = time.perf_counter()
|
||||
@@ -1703,12 +1764,35 @@ def main():
|
||||
print(f" L{li} {tag}: {dt_ms:.2f}ms")
|
||||
prev_t = t
|
||||
|
||||
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
print(f"Output: '{out}'")
|
||||
print(f"Total: {time.time()-t0:.1f}s")
|
||||
print(f"{'='*70}")
|
||||
out_raw = tokenizer.decode(all_tokens, skip_special_tokens=False)
|
||||
# Use official DSV4 parser for structured output
|
||||
try:
|
||||
from encoding.deepseek_v4_encoding import parse_message_from_completion_text
|
||||
# Find the assistant portion — after the last ASSISTANT token
|
||||
assistant_start = out_raw.find(_ASSISTANT_STR)
|
||||
if assistant_start >= 0:
|
||||
assistant_text = out_raw[assistant_start + len(_ASSISTANT_STR):]
|
||||
else:
|
||||
assistant_text = out_raw
|
||||
parsed = parse_message_from_completion_text(assistant_text, thinking_mode=_args.thinking_mode)
|
||||
reasoning = parsed.get('reasoning', '')
|
||||
content = parsed.get('content', '')
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
if reasoning:
|
||||
print(f"Reasoning: {reasoning[:500]}{'...' if len(reasoning) > 500 else ''}")
|
||||
print(f"Content: {content}")
|
||||
print(f"Total: {time.time()-t0:.1f}s")
|
||||
print(f"{'='*70}")
|
||||
except Exception as e:
|
||||
# Fallback: raw decode (shouldn't happen with correct output)
|
||||
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
print(f"Output (raw): '{out}'")
|
||||
print(f"Parse error: {e}")
|
||||
print(f"Total: {time.time()-t0:.1f}s")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
202
tests/unit/test_b1_mixed_fp8_prefill.py
Normal file
202
tests/unit/test_b1_mixed_fp8_prefill.py
Normal file
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python3
|
||||
"""B1 mixed FP8/BF16 prefill FMHA — unit test.
|
||||
|
||||
Tests the T>1 prefill kernel at production values:
|
||||
HD=512, NOPE=448, ROPE=64, H=128, T=1..64, N=128..2048.
|
||||
|
||||
1. T=1 prefill vs decode kernel (should be identical)
|
||||
2. T>1 prefill vs PyTorch SDPA reference
|
||||
3. T>1 with attention sinks
|
||||
4. Large N (production context lengths)
|
||||
5. Multi-batch
|
||||
|
||||
No model weights needed — uses synthetic random data.
|
||||
"""
|
||||
import sys, math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def quantize_fp8_e4m3(x_fp32):
|
||||
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
scale = amax / 448.0
|
||||
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
|
||||
return fp8.view(torch.uint8), scale.squeeze(-1)
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
|
||||
def main():
|
||||
HD = 512; NOPE = 448; ROPE = 64; H = 128
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
|
||||
print("=" * 70)
|
||||
print("B1 MIXED FP8 PREFILL FMHA — UNIT TEST")
|
||||
print(f"Production values: HD={HD}, NOPE={NOPE}, ROPE={ROPE}, H={H}")
|
||||
print("=" * 70)
|
||||
|
||||
results = {}
|
||||
|
||||
# ---- Test 1: T=1 prefill vs decode kernel ----
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 1: T=1 prefill vs T=1 decode (should be identical)")
|
||||
print("=" * 70)
|
||||
try:
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw
|
||||
|
||||
torch.manual_seed(42)
|
||||
B = 1; T = 1; N = 256
|
||||
q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
|
||||
k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
|
||||
q_bf16 = q_fp32.bfloat16().cuda()
|
||||
k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
|
||||
k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
|
||||
k_nope_fp8 = k_nope_fp8.cuda()
|
||||
k_nope_scale = k_nope_scale.cuda()
|
||||
k_rope_bf16 = k_rope_bf16.cuda()
|
||||
|
||||
o_decode, _ = fmha_mixed_fp8_decode_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
o_prefill, _ = fmha_mixed_fp8_prefill_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
|
||||
cos_val = cosine(o_decode, o_prefill)
|
||||
print(f" T=1 decode vs prefill: cos={cos_val:.8f}")
|
||||
assert cos_val >= 0.999, f"T=1 decode vs prefill cos={cos_val:.6f} < 0.999"
|
||||
results["1_t1_vs_decode"] = True
|
||||
print(" PASS")
|
||||
except Exception as e:
|
||||
print(f" FAIL: {e}")
|
||||
results["1_t1_vs_decode"] = False
|
||||
|
||||
# ---- Test 2: T>1 prefill vs SDPA reference ----
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 2: T>1 prefill vs PyTorch SDPA")
|
||||
print("=" * 70)
|
||||
all_pass = True
|
||||
for T in [1, 2, 4, 8, 16, 32]:
|
||||
for N in [128, 512]:
|
||||
print(f"\n T={T} N={N}")
|
||||
try:
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw
|
||||
|
||||
torch.manual_seed(42)
|
||||
q_fp32 = torch.randn(1, H, T, HD, dtype=torch.float32) * 0.5
|
||||
k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
|
||||
q_bf16 = q_fp32.bfloat16().cuda()
|
||||
k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
|
||||
k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
|
||||
k_nope_fp8 = k_nope_fp8.cuda()
|
||||
k_nope_scale = k_nope_scale.cuda()
|
||||
k_rope_bf16 = k_rope_bf16.cuda()
|
||||
|
||||
o_prefill, lse = fmha_mixed_fp8_prefill_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
|
||||
# Reference: dequantize, run SDPA per query position
|
||||
nope_dequant = k_nope_fp8.view(torch.float8_e4m3fn).cpu().float() * k_nope_scale.cpu().unsqueeze(-1).float()
|
||||
k_full = torch.cat([nope_dequant, k_fp32[:, NOPE:]], dim=-1).bfloat16().cuda()
|
||||
k_4d = k_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1)
|
||||
v_4d = k_4d.clone()
|
||||
o_ref = F.scaled_dot_product_attention(q_bf16, k_4d, v_4d, scale=scale)
|
||||
|
||||
cos_val = cosine(o_prefill, o_ref)
|
||||
print(f" cos={cos_val:.6f} |prod|={o_prefill.float().abs().max().item():.4f} "
|
||||
f"|ref|={o_ref.float().abs().max().item():.4f}")
|
||||
if cos_val < 0.999:
|
||||
all_pass = False
|
||||
print(f" FAIL")
|
||||
else:
|
||||
print(f" PASS")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
all_pass = False
|
||||
results["2_t1_vs_sdpa"] = all_pass
|
||||
|
||||
# ---- Test 3: T>1 with attention sinks ----
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 3: T>1 with attention sinks")
|
||||
print("=" * 70)
|
||||
try:
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw
|
||||
T = 4; N = 256
|
||||
torch.manual_seed(42)
|
||||
q_fp32 = torch.randn(1, H, T, HD, dtype=torch.float32) * 0.5
|
||||
k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
|
||||
q_bf16 = q_fp32.bfloat16().cuda()
|
||||
k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
|
||||
k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
|
||||
k_nope_fp8 = k_nope_fp8.cuda(); k_nope_scale = k_nope_scale.cuda(); k_rope_bf16 = k_rope_bf16.cuda()
|
||||
sink_bias = torch.randn(H, dtype=torch.float32) * 2.0
|
||||
|
||||
o_with, _ = fmha_mixed_fp8_prefill_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale,
|
||||
attn_sink=sink_bias, rope_dim=ROPE)
|
||||
o_no, _ = fmha_mixed_fp8_prefill_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
diff = (o_with - o_no).float().abs().max().item()
|
||||
print(f" Max diff with/without sink: {diff:.6f}")
|
||||
assert diff > 1e-4, "Sink bias has no effect"
|
||||
results["3_sinks"] = True
|
||||
print(" PASS")
|
||||
except Exception as e:
|
||||
print(f" FAIL: {e}")
|
||||
results["3_sinks"] = False
|
||||
|
||||
# ---- Test 4: Large N ----
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 4: Large N (production context)")
|
||||
print("=" * 70)
|
||||
all_pass = True
|
||||
for N in [1024, 2048, 4096]:
|
||||
for T in [4, 16]:
|
||||
print(f"\n T={T} N={N}")
|
||||
try:
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw
|
||||
torch.manual_seed(42)
|
||||
q_fp32 = torch.randn(1, H, T, HD, dtype=torch.float32) * 0.5
|
||||
k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
|
||||
q_bf16 = q_fp32.bfloat16().cuda()
|
||||
k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
|
||||
k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
|
||||
k_nope_fp8 = k_nope_fp8.cuda(); k_nope_scale = k_nope_scale.cuda(); k_rope_bf16 = k_rope_bf16.cuda()
|
||||
|
||||
o_prefill, lse = fmha_mixed_fp8_prefill_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
|
||||
nope_dequant = k_nope_fp8.view(torch.float8_e4m3fn).cpu().float() * k_nope_scale.cpu().unsqueeze(-1).float()
|
||||
k_full = torch.cat([nope_dequant, k_fp32[:, NOPE:]], dim=-1).bfloat16().cuda()
|
||||
k_4d = k_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1)
|
||||
v_4d = k_4d.clone()
|
||||
o_ref = F.scaled_dot_product_attention(q_bf16, k_4d, v_4d, scale=scale)
|
||||
|
||||
cos_val = cosine(o_prefill, o_ref)
|
||||
print(f" cos={cos_val:.6f}")
|
||||
if cos_val < 0.999:
|
||||
all_pass = False
|
||||
print(f" FAIL")
|
||||
else:
|
||||
print(f" PASS")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
all_pass = False
|
||||
results["4_large_n"] = all_pass
|
||||
|
||||
# ---- Summary ----
|
||||
print("\n" + "=" * 70)
|
||||
print("SUMMARY")
|
||||
print("=" * 70)
|
||||
all_ok = True
|
||||
for name, passed in results.items():
|
||||
status = "PASS" if passed else "FAIL"
|
||||
if not passed: all_ok = False
|
||||
print(f" {name}: {status}")
|
||||
print()
|
||||
sys.exit(0 if all_ok else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
572
tests/unit/test_decode_fmha_layer.py
Normal file
572
tests/unit/test_decode_fmha_layer.py
Normal file
@@ -0,0 +1,572 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Production FMHA layer comparison test — DECODE phase.
|
||||
|
||||
The key difference from test_production_fmha_layer.py:
|
||||
- That test checks FMHA cos during PREFILL (or with random Q after prefill)
|
||||
- This test checks FMHA cos during the FIRST DECODE STEP
|
||||
|
||||
Why this matters:
|
||||
During decode, the KV cache has compressed entries (CSA/HCA) + SWA window.
|
||||
The CSA path uses indexer top-k to select which compressed entries to attend to.
|
||||
The HCA path gathers ALL compressed entries. The SWA-only path has no compression.
|
||||
If the per-layer cos is 0.999993 during prefill but drops during decode,
|
||||
the bug is in the decode-time KV gathering or compressed/SWA parity.
|
||||
|
||||
Strategy:
|
||||
1. Run full production pipeline (single_shot_inference.py forward_layer)
|
||||
for ALL prefill tokens through layers 0-4, populating KV caches.
|
||||
2. Run the FIRST decode token through forward_layer, but capture the
|
||||
production FMHA inputs (q_heads, gathered KV) at each layer.
|
||||
3. For each layer, ALSO run reference FMHA (dequantize KV to BF16, PyTorch SDPA)
|
||||
on the SAME gathered KV that the production kernel saw.
|
||||
4. Compare raw FMHA output (before inverse RoPE, before output projection).
|
||||
|
||||
Production values: HD=512, NOPE=448, ROPE=64, H=128, 61 layers, 8 GPUs.
|
||||
"""
|
||||
import os, sys, json, math, time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
CHECKPOINT_DIR = os.environ.get(
|
||||
"CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
NUM_GPUS = int(os.environ.get("NUM_GPUS", "8"))
|
||||
DEVICE = "cuda:0"
|
||||
# How many layers to test (first N layers)
|
||||
TEST_LAYERS = int(os.environ.get("TEST_LAYERS", "5"))
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
print("=" * 70)
|
||||
print("DECODE FMHA LAYER COMPARISON TEST")
|
||||
print("Tests FMHA accuracy during DECODE (not prefill)")
|
||||
print("=" * 70)
|
||||
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
n_layers = cfg["num_hidden_layers"]
|
||||
H = cfg["hidden_size"]
|
||||
hd = cfg["head_dim"]
|
||||
n_h = cfg["num_attention_heads"]
|
||||
rd = cfg.get("qk_rope_head_dim", 64)
|
||||
nope_dim = hd - rd
|
||||
cr = cfg.get("compress_ratios", [128] * n_layers)
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}, nope_dim={nope_dim}")
|
||||
print(f"Compress ratios (first {TEST_LAYERS}): {cr[:TEST_LAYERS]}")
|
||||
|
||||
# Import from single_shot_inference.py
|
||||
from single_shot_inference import (
|
||||
load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
|
||||
rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache,
|
||||
KVCache, Compressor, Indexer, forward_layer, moe_forward,
|
||||
_load_moe_weights_stacked, _load_shared_expert_weights,
|
||||
_cache_layer_weights_no_experts,
|
||||
)
|
||||
from dsv4.layers.mhc import mHCLayer, mHCContext
|
||||
from dsv4.layers.router import Router
|
||||
from dsv4.layers.moe import Nvfp4MoE
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.ops.quantize import (
|
||||
rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
|
||||
print("Loading weights...")
|
||||
all_w = load_all_weights(CHECKPOINT_DIR)
|
||||
|
||||
o_groups = cfg.get("o_groups", 16)
|
||||
o_rank = cfg.get("o_lora_rank", 1024)
|
||||
n_ih = cfg.get("index_n_heads", 64)
|
||||
ihd = cfg.get("index_head_dim", 128)
|
||||
itk = cfg.get("index_topk", 1024)
|
||||
|
||||
rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
|
||||
for g in range(NUM_GPUS)}
|
||||
|
||||
# Build all production components
|
||||
prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
|
||||
attn_norms, ffn_norms = {}, {}
|
||||
compressors, indexers, kv_caches = {}, {}, {}
|
||||
routers, moe_runners, se_runners = {}, {}, {}
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
pfx = f"model.layers.{li}.self_attn"
|
||||
mlp_pfx = f"model.layers.{li}.mlp"
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
|
||||
# Attention linears
|
||||
pl = {}
|
||||
pl['q_a'] = make_nvfp4_linear(H, 1536, dev, all_w, pfx, 'q_a_proj')
|
||||
pl['q_b'] = make_nvfp4_linear(1536, H * hd, dev, all_w, pfx, 'q_b_proj')
|
||||
pl['kv'] = make_nvfp4_linear(H, hd, dev, all_w, pfx, 'kv_proj')
|
||||
hpg = n_h // o_groups
|
||||
wo_a = Nvfp4GroupedLinear(n_local_groups=o_groups, heads_per_group=hpg,
|
||||
head_dim=hd, o_lora_rank=o_rank, max_num_tokens=8192, device=dev)
|
||||
oa_w, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
|
||||
if oa_w is not None and oa_ws is not None:
|
||||
wo_a.load_nvfp4_weight(oa_w.to(dev), oa_ws.to(dev),
|
||||
oa_ws2.to(dev) if oa_ws2 is not None else None,
|
||||
oa_isc.to(dev) if oa_isc is not None else None)
|
||||
else:
|
||||
oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_bf is not None:
|
||||
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
|
||||
pl['o_a'] = wo_a; wo_a._use_runtime_gsa = True
|
||||
pl['o_b'] = make_nvfp4_linear(o_groups * o_rank, H, dev, all_w, pfx, 'o_b_proj')
|
||||
prod_lins[li] = pl
|
||||
|
||||
# mHC
|
||||
for tag, blocks, fn_s, base_s, scale_s in [
|
||||
("attn", attn_mhcs, f"model.layers.{li}.attn_hc.fn",
|
||||
f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"),
|
||||
("ffn", ffn_mhcs, f"model.layers.{li}.ffn_hc.fn",
|
||||
f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"),
|
||||
]:
|
||||
fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s)
|
||||
if fn is not None and base is not None and scale is not None:
|
||||
m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev)
|
||||
n = 4
|
||||
m.load_weights(
|
||||
W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
|
||||
W_comb=fn[2*n:].to(dev, torch.float32),
|
||||
S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
|
||||
S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
|
||||
S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
|
||||
alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item())
|
||||
blocks[li] = m
|
||||
|
||||
an_k = f"model.layers.{li}.input_layernorm.weight"
|
||||
if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32)
|
||||
fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
|
||||
if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)
|
||||
|
||||
max_comp = (8192 + ratio - 1) // ratio if ratio > 0 else 0
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp,
|
||||
device=dev, indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk, rope_dim=rd)
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
# Router
|
||||
is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{mlp_pfx}.gate.tid2eid" in all_w)
|
||||
router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
|
||||
top_k=cfg.get("num_experts_per_tok", 6),
|
||||
routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
|
||||
mode="hash" if is_hash else "dense",
|
||||
vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
|
||||
if is_hash:
|
||||
router.load_weights(hash_lut=all_w[f"{mlp_pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{mlp_pfx}.gate.e_score_correction_bias")
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, mlp_pfx, 'gate')
|
||||
E = cfg["n_routed_experts"]
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)]
|
||||
gate_lin.sf = [gate_ws.to(dev)]
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
else:
|
||||
gw = all_w.get(f"{mlp_pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
top_k=cfg.get("num_experts_per_tok", 6), device=dev)
|
||||
moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)); moe.set_fused_swiglu(True)
|
||||
_load_moe_weights_stacked(all_w, li, mlp_pfx, dev, moe, cfg)
|
||||
moe._ensure_stacked(); moe._use_runtime_gsa = True; moe_runners[li] = moe
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
|
||||
se.set_fused_swiglu(True)
|
||||
_load_shared_expert_weights(all_w, li, mlp_pfx, dev, se, cfg)
|
||||
se._ensure_initialized(); se._use_runtime_gsa = True; se_runners[li] = se
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
pfx = f"model.layers.{li}.self_attn.compressor"
|
||||
dev = f"cuda:{li % NUM_GPUS}"
|
||||
if li in compressors: compressors[li].load(all_w, pfx, dev=dev)
|
||||
if li in indexers: indexers[li].load(all_w, f"{pfx}.indexer", dev=dev)
|
||||
print("Components built")
|
||||
|
||||
# Embedding + tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
USER_TOKEN, ASSISTANT_TOKEN, THINK_START = 128803, 128804, 128821
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\nThe capital of France is', add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
input_ids.append(THINK_START)
|
||||
print(f"Input: {len(input_ids)} tokens: {input_ids}")
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
|
||||
devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
|
||||
layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
|
||||
del all_w; import gc; gc.collect()
|
||||
for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# ================================================================
|
||||
# PHASE 1: Run full production pipeline to populate KV caches
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("PHASE 1: Populating KV caches (prefill)")
|
||||
print(f"{'='*70}")
|
||||
for pi, tid_val in enumerate(input_ids):
|
||||
t1 = time.time()
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
|
||||
pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
|
||||
tid32 = torch.tensor([tid_val], dtype=torch.int32, device=DEVICE)
|
||||
|
||||
X = mHCLayer.init_state(embed(tid))
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
torch.cuda.set_device(gpu)
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
|
||||
if pi % 5 == 0:
|
||||
print(f" Token {pi}/{len(input_ids)}: {time.time()-t1:.2f}s", flush=True)
|
||||
|
||||
# Print KV cache state after prefill
|
||||
print(f"\nKV cache state after prefill ({len(input_ids)} tokens):")
|
||||
for li in range(TEST_LAYERS):
|
||||
kc = kv_caches[li]
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len} "
|
||||
f"total_KV={kc.n_comp + kc.swa_len}")
|
||||
|
||||
# ================================================================
|
||||
# PHASE 2: Run ONE decode step, capturing FMHA inputs/outputs
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("PHASE 2: Decode FMHA comparison per layer")
|
||||
print(f"{'='*70}")
|
||||
|
||||
# Use a real next token — the model's own greedy output would require
|
||||
# a full forward pass to get logits. Instead, use a reasonable continuation
|
||||
# token. For "The capital of France is" → the space token or a letter.
|
||||
# Actually, we need to run the FULL decode forward pass (all layers) to get
|
||||
# the actual Q at each layer. So we'll intercept inside forward_attention.
|
||||
#
|
||||
# Approach: duplicate the forward_attention logic, capturing FMHA inputs
|
||||
# at each layer, then compare against reference SDPA.
|
||||
|
||||
# First, we need the hidden state X at the decode position.
|
||||
# We'll re-run the decode step manually, layer by layer, capturing
|
||||
# the production FMHA inputs and comparing against reference.
|
||||
|
||||
# Decode token: use the actual next position
|
||||
decode_pos = len(input_ids)
|
||||
# Use a common token — the " " (space) token
|
||||
decode_tid = tokenizer.encode(" the", add_special_tokens=False)
|
||||
if len(decode_tid) > 0:
|
||||
decode_tid = decode_tid[0]
|
||||
else:
|
||||
decode_tid = tokenizer.convert_tokens_to_ids(" ")
|
||||
print(f"Decode token: id={decode_tid} pos={decode_pos}")
|
||||
|
||||
# Get initial hidden state from embedding
|
||||
dec_tid = torch.tensor([decode_tid], dtype=torch.long, device=DEVICE)
|
||||
dec_tid32 = torch.tensor([decode_tid], dtype=torch.int32, device=DEVICE)
|
||||
dec_pos = torch.tensor([decode_pos], dtype=torch.long, device=DEVICE)
|
||||
|
||||
X = mHCLayer.init_state(embed(dec_tid))
|
||||
print(f"Initial X: shape={tuple(X.shape)} |X|={X.abs().max().item():.4f}")
|
||||
|
||||
results = {}
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(dev)
|
||||
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
kc = kv_caches[li]
|
||||
pfx = f"model.layers.{li}.self_attn"
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
# ---- mHC pre_block + rmsnorm (same as forward_layer) ----
|
||||
attn_mhc = attn_mhcs.get(li)
|
||||
ffn_mhc = ffn_mhcs.get(li)
|
||||
attn_norm_w = attn_norms.get(li)
|
||||
ffn_norm_w = ffn_norms.get(li)
|
||||
|
||||
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X)
|
||||
ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a)
|
||||
|
||||
# Fused mHC + rmsnorm + NVFP4 quantize (production path)
|
||||
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X, A_l_a, attn_norm_w.to(dev, torch.float32))
|
||||
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
||||
|
||||
# ---- Manually replicate forward_attention to capture FMHA inputs ----
|
||||
T = x_normed.shape[0]
|
||||
pl = prod_lins[li]
|
||||
|
||||
# 1. Q projection
|
||||
q_a = pl['q_a'].run_from_quantized(x_quant_attn)
|
||||
q_norm_w = layer_w[li].get(f"{pfx}.q_a_norm.weight")
|
||||
if q_norm_w is not None:
|
||||
q_a_quant = rmsnorm_quantize_nvfp4(q_a, q_norm_w.to(dev, torch.float32))
|
||||
q_a = dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
|
||||
q = pl['q_b'].run_from_quantized(q_a_quant)
|
||||
else:
|
||||
q = pl['q_b'](q_a)
|
||||
q = unweighted_rmsnorm(q).bfloat16()
|
||||
q_heads = q.reshape(T, n_h, hd)
|
||||
q_heads = _apply_rope(q_heads, dec_pos.to(dev), *rope_caches[gpu][:2], rd)
|
||||
|
||||
# 2. KV projection + cache
|
||||
kv = pl['kv'].run_from_quantized(x_quant_attn)
|
||||
kv_norm_w = layer_w[li].get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
kv_3d = kv.reshape(T, 1, hd)
|
||||
kv_3d = _apply_rope(kv_3d, dec_pos.to(dev), *rope_caches[gpu][:2], rd)
|
||||
kv_roped = kv_3d.reshape(T, hd)
|
||||
kc.append_swa(kv_roped, dec_pos.to(dev))
|
||||
|
||||
# 3. Compressor → compressed KV
|
||||
compressor = compressors.get(li)
|
||||
indexer = indexers.get(li)
|
||||
comp_pos, block_bias = None, None
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos.to(dev))
|
||||
if comp_kv_fp32 is not None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous()
|
||||
rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous()
|
||||
rope_3d = rope_bf16.unsqueeze(1)
|
||||
rope_3d = _apply_rope(rope_3d, comp_pos, *rope_caches[gpu][:2], rd)
|
||||
rope_bf16 = rope_3d.squeeze(1)
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
kc.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
|
||||
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos.to(dev))
|
||||
kc.set_indexer_keys_fp8(comp_idx_kv)
|
||||
|
||||
# 4. Indexer top-k (CSA layers)
|
||||
topk_idx = None
|
||||
if indexer is not None and ratio == 4:
|
||||
topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos.to(dev), layer_idx=li)
|
||||
if topk_idx is not None:
|
||||
print(f" L{li} CSA: indexer topk shape={tuple(topk_idx.shape)} "
|
||||
f"range=[{topk_idx.min().item()}, {topk_idx.max().item()}] "
|
||||
f"n_comp={kc.n_comp}", flush=True)
|
||||
|
||||
# 5. Gather KV — same logic as forward_attention
|
||||
swa_kv, _swa_pos = kc.get_swa()
|
||||
swa_len = swa_kv.shape[0]
|
||||
|
||||
if kc.n_comp > 0:
|
||||
if ratio == 4:
|
||||
# CSA: gather top-k compressed rows
|
||||
assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k"
|
||||
tk = topk_idx[0].clamp(0, kc.n_comp - 1).int()
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_selective(tk)
|
||||
gather_mode = f"CSA top-k ({tk.numel()} comp + {swa_len} SWA)"
|
||||
elif ratio > 4:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_all()
|
||||
gather_mode = f"HCA all ({kc.n_comp} comp + {swa_len} SWA)"
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_swa_only()
|
||||
gather_mode = f"SWA-only ({swa_len} SWA)"
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_swa_only()
|
||||
gather_mode = f"SWA-only ({swa_len} SWA)"
|
||||
|
||||
seq_len = kv_nope_scale.shape[0]
|
||||
if seq_len == 0:
|
||||
print(f" L{li}: SKIPPED (seq_len=0)")
|
||||
continue
|
||||
|
||||
print(f" L{li}: {gather_mode} → seq_len={seq_len}", flush=True)
|
||||
|
||||
# 6. Run production mixed FP8 FMHA
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
|
||||
q_4d = q_heads.permute(1, 0, 2).unsqueeze(0).contiguous() # (1, n_h, T, hd)
|
||||
sinks = layer_w[li].get(f"{pfx}.sinks")
|
||||
sink_bias = None
|
||||
if sinks is not None:
|
||||
sink_bias = sinks.to(device=dev).float().reshape(n_h)
|
||||
|
||||
try:
|
||||
o_prod_4d, lse_prod = fmha_mixed_fp8_decode_raw(
|
||||
q_4d, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
|
||||
scale, attn_sink=sink_bias, rope_dim=rd)
|
||||
except Exception as e:
|
||||
print(f" L{li}: PROD FMHA FAILED: {e}")
|
||||
results[li] = {'cos': -1.0, 'error': str(e)}
|
||||
continue
|
||||
o_prod = o_prod_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
# 7. Reference: dequantize mixed KV to BF16, run reference with sink bias
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd)
|
||||
k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd)
|
||||
v_4d = k_4d.clone()
|
||||
if sink_bias is not None:
|
||||
# DSV4 sink is denominator-only: O = sum(P*V) / (sum(P) + exp(sb))
|
||||
# where P = softmax(QK*scale). The sink has NO V contribution.
|
||||
# Reference: compute O_no_sink, then scale by correction factor.
|
||||
q_ref = q_4d.float() # (1, H, T, hd)
|
||||
k_ref = k_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
|
||||
v_ref = v_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
|
||||
scores = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale # (1, H, T, N)
|
||||
# O_no_sink = softmax(scores) @ V
|
||||
O_no_sink = F.softmax(scores, dim=-1) @ v_ref # (1, H, T, hd)
|
||||
# Correction: O_with_sink = O_no_sink * Z / (Z + exp(sb))
|
||||
# Z = sum(exp(scores - max)) per head, but more conveniently:
|
||||
# Z / (Z + exp(sb)) = 1 / (1 + exp(sb) / Z) = 1 / (1 + exp(sb - log(Z)))
|
||||
# log(Z) = logsumexp(scores)
|
||||
lse = torch.logsumexp(scores, dim=-1, keepdim=True) # (1, H, T, 1)
|
||||
# sb shape: (n_h,) → (1, n_h, 1, 1)
|
||||
sb_4d = sink_bias.reshape(1, n_h, 1, 1)
|
||||
correction = 1.0 / (1.0 + torch.exp(sb_4d - lse))
|
||||
o_ref_4d = (O_no_sink * correction).bfloat16() # (1, H, T, hd)
|
||||
else:
|
||||
o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd)
|
||||
o_ref = o_ref_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
# 8. Compare
|
||||
cos_val = cosine(o_prod, o_ref)
|
||||
mag_prod = o_prod.float().abs().max().item()
|
||||
mag_ref = o_ref.float().abs().max().item()
|
||||
|
||||
# Per-head cosine AND magnitude ratio
|
||||
o_prod_h = o_prod.float().squeeze(1) # (n_h, hd)
|
||||
o_ref_h = o_ref.float().squeeze(1)
|
||||
per_head_cos = F.cosine_similarity(o_prod_h, o_ref_h, dim=-1)
|
||||
per_head_mag_prod = o_prod_h.abs().max(dim=-1).values # (n_h,)
|
||||
per_head_mag_ref = o_ref_h.abs().max(dim=-1).values # (n_h,)
|
||||
per_head_mag_ratio = (per_head_mag_prod / (per_head_mag_ref + 1e-8)) # (n_h,)
|
||||
min_head = per_head_cos.min().item()
|
||||
mean_head = per_head_cos.mean().item()
|
||||
worst_heads = per_head_cos.argsort()[:5]
|
||||
# Find heads with worst magnitude ratio
|
||||
worst_mag = per_head_mag_ratio.sub(1.0).abs().argsort(descending=True)[:5]
|
||||
|
||||
results[li] = {
|
||||
'cos': cos_val, 'mag_prod': mag_prod, 'mag_ref': mag_ref,
|
||||
'seq_len': seq_len, 'ratio': ratio, 'gather_mode': gather_mode,
|
||||
'n_comp': kc.n_comp, 'swa_len': swa_len,
|
||||
'min_head_cos': min_head, 'mean_head_cos': mean_head,
|
||||
}
|
||||
|
||||
status = "PASS" if cos_val >= 0.999 else "FAIL"
|
||||
print(f" L{li}: {status} cos={cos_val:.6f} min_head={min_head:.6f} mean_head={mean_head:.6f} "
|
||||
f"|prod|={mag_prod:.4f} |ref|={mag_ref:.4f} seq={seq_len} {gather_mode}", flush=True)
|
||||
if cos_val < 0.999:
|
||||
cos_list = [f'{c:.4f}' for c in per_head_cos[worst_heads].tolist()]
|
||||
mag_list = [f'{r:.4f}' for r in per_head_mag_ratio[worst_mag].tolist()]
|
||||
print(f" Worst heads (cos): {worst_heads.tolist()} cos={cos_list}")
|
||||
print(f" Worst heads (mag): {worst_mag.tolist()} ratio={mag_list}")
|
||||
print(f" Mag ratio range: [{per_head_mag_ratio.min().item():.4f}, {per_head_mag_ratio.max().item():.4f}]")
|
||||
|
||||
# ---- Continue through the rest of the layer (so subsequent layers get correct X) ----
|
||||
# Apply inverse RoPE to production output
|
||||
attn_out = o_prod.permute(1, 0, 2) # (T, n_h, hd)
|
||||
attn_out = _apply_rope(attn_out, dec_pos.to(dev), *rope_caches[gpu][:2], rd, inverse=True)
|
||||
|
||||
# Output projection
|
||||
wo_a_lin = pl.get('o_a')
|
||||
if wo_a_lin is not None:
|
||||
g_3d = wo_a_lin.run(attn_out)
|
||||
g_flat = g_3d.reshape(T, -1)
|
||||
F_attn = pl['o_b'](g_flat)
|
||||
else:
|
||||
hpg_fb = n_h // o_groups; gid_fb = hpg_fb * hd
|
||||
oa_full = layer_w[li].get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_full is not None:
|
||||
oa_bf = oa_full.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd)
|
||||
a_grp = a_flat.reshape(T, o_groups, gid_fb); oa_3d = oa_bf.reshape(o_groups, o_rank, gid_fb)
|
||||
g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
|
||||
g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
|
||||
F_attn = pl['o_b'](g_flat)
|
||||
else:
|
||||
F_attn = torch.zeros(T, H, dtype=torch.bfloat16, device=dev)
|
||||
|
||||
# mHC post_block
|
||||
X_mid = attn_mhc.post_block(X, F_attn, ctx_a)
|
||||
|
||||
# FFN mHC + MoE
|
||||
A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid)
|
||||
ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f)
|
||||
x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X_mid, A_l_f, ffn_norm_w.to(dev, torch.float32))
|
||||
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
|
||||
F_ffn = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
|
||||
routers.get(li), dec_tid32.to(dev))
|
||||
X = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)
|
||||
|
||||
# ================================================================
|
||||
# Summary
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("DECODE FMHA COMPARISON SUMMARY")
|
||||
print(f"{'='*70}")
|
||||
all_pass = True
|
||||
for li in sorted(results.keys()):
|
||||
r = results[li]
|
||||
c = r.get('cos', -1.0)
|
||||
status = "PASS" if c >= 0.999 else "FAIL"
|
||||
if c < 0.999: all_pass = False
|
||||
print(f" L{li}: {status} cos={c:.6f} seq={r.get('seq_len','?')} "
|
||||
f"mode={r.get('gather_mode','?')} "
|
||||
f"n_comp={r.get('n_comp','?')} swa={r.get('swa_len','?')}")
|
||||
|
||||
print()
|
||||
if all_pass:
|
||||
print("ALL DECODE LAYERS PASSED (cos >= 0.999)")
|
||||
else:
|
||||
print("SOME DECODE LAYERS FAILED — investigate KV gathering or compressed/SWA parity")
|
||||
print()
|
||||
print("If prefill cos was 0.999993 but decode cos < 0.999:")
|
||||
print(" → Bug is in decode-time KV gathering or compressed/SWA parity")
|
||||
print(" → Check: gather_mixed_selective (CSA), gather_mixed_all (HCA)")
|
||||
print(" → Check: SWA positions vs compressed positions (causality)")
|
||||
print(" → Check: indexer top-k indices validity")
|
||||
return 0 if all_pass else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
168
tests/unit/test_degeneration_1_chat_template.py
Normal file
168
tests/unit/test_degeneration_1_chat_template.py
Normal file
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""DEGENERATION TEST 1 v2 — Chat-template token-ID diff using official encoding.
|
||||
|
||||
Uses the official DeepSeek V4 encoding reference from encoding/encoding_dsv4.py
|
||||
to build the canonical prompt, then diffs against our hand-rolled construction.
|
||||
|
||||
Official format (from DeepSeek-V4-Pro/encoding/README.md):
|
||||
Thinking mode: <BOS>{system}<|User|>{msg}<|Assistant|>ately{reasoning}heroically{response}<EOS>
|
||||
Chat mode: <BOS>{system}<|User|>{msg}<|Assistant|>heroically{response}<EOS>
|
||||
|
||||
Key differences from our hand-rolled:
|
||||
1. No \n\n between User token and content
|
||||
2. System prompt goes directly after BOS (no User token for system)
|
||||
"""
|
||||
import os, sys
|
||||
|
||||
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
PROMPT = os.environ.get("TEST_PROMPT", "The capital of France is")
|
||||
|
||||
THINK_START, THINK_END = 128821, 128822
|
||||
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804
|
||||
|
||||
def main():
|
||||
from transformers import AutoTokenizer
|
||||
print("=" * 70)
|
||||
print("DEGENERATION TEST 1 v2 — Official encoding diff")
|
||||
print("=" * 70)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
|
||||
# === 1. Hand-rolled (current single_shot_inference.py) ===
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
input_ids.append(THINK_START)
|
||||
|
||||
print(f"\n1. HAND-ROLLED ({len(input_ids)} tokens):")
|
||||
for i, tid in enumerate(input_ids):
|
||||
print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}")
|
||||
print(f" Full: {repr(tokenizer.decode(input_ids))}")
|
||||
|
||||
# === 2. Official encoding (thinking mode, no system prompt) ===
|
||||
# Format: <BOS><|User|>{msg}<|Assistant|>ately
|
||||
# NO \n\n between User token and message
|
||||
canonical_thinking = [bos, USER_TOKEN]
|
||||
canonical_thinking += tokenizer.encode(PROMPT, add_special_tokens=False)
|
||||
canonical_thinking.append(ASSISTANT_TOKEN)
|
||||
canonical_thinking.append(THINK_START)
|
||||
|
||||
print(f"\n2. OFFICIAL (thinking, no \\n\\n) ({len(canonical_thinking)} tokens):")
|
||||
for i, tid in enumerate(canonical_thinking):
|
||||
print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}")
|
||||
print(f" Full: {repr(tokenizer.decode(canonical_thinking))}")
|
||||
|
||||
# === 3. Official encoding (chat mode — THINK_END closes thinking) ===
|
||||
canonical_chat = [bos, USER_TOKEN]
|
||||
canonical_chat += tokenizer.encode(PROMPT, add_special_tokens=False)
|
||||
canonical_chat.append(ASSISTANT_TOKEN)
|
||||
canonical_chat.append(THINK_END)
|
||||
|
||||
print(f"\n3. OFFICIAL (chat mode, THINK_END) ({len(canonical_chat)} tokens):")
|
||||
for i, tid in enumerate(canonical_chat):
|
||||
print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}")
|
||||
print(f" Full: {repr(tokenizer.decode(canonical_chat))}")
|
||||
|
||||
# === 4. Official encoding with system prompt ===
|
||||
# Format: <BOS>{system}<|User|>{msg}<|Assistant|>ately
|
||||
system_prompt = "You are a helpful assistant."
|
||||
canonical_sys = tokenizer.encode(system_prompt, add_special_tokens=False)
|
||||
canonical_sys_thinking = [bos] + canonical_sys + [USER_TOKEN]
|
||||
canonical_sys_thinking += tokenizer.encode(PROMPT, add_special_tokens=False)
|
||||
canonical_sys_thinking.append(ASSISTANT_TOKEN)
|
||||
canonical_sys_thinking.append(THINK_START)
|
||||
|
||||
print(f"\n4. OFFICIAL (thinking + system prompt) ({len(canonical_sys_thinking)} tokens):")
|
||||
for i, tid in enumerate(canonical_sys_thinking):
|
||||
print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}")
|
||||
print(f" Full: {repr(tokenizer.decode(canonical_sys_thinking))}")
|
||||
|
||||
# === 5. Diff ===
|
||||
print(f"\n{'='*70}")
|
||||
print("DIFF: hand-rolled vs official (thinking, no \\n\\n)")
|
||||
print(f"{'='*70}")
|
||||
if input_ids == canonical_thinking:
|
||||
print(" IDENTICAL")
|
||||
else:
|
||||
print(f" DIFFERENT: hand_rolled={len(input_ids)} tokens, canonical={len(canonical_thinking)} tokens")
|
||||
min_len = min(len(input_ids), len(canonical_thinking))
|
||||
for i in range(min_len):
|
||||
if input_ids[i] != canonical_thinking[i]:
|
||||
print(f" First diff at position {i}:")
|
||||
print(f" hand_rolled[{i}] = {input_ids[i]} ({repr(tokenizer.decode([input_ids[i]]))})")
|
||||
print(f" canonical[{i}] = {canonical_thinking[i]} ({repr(tokenizer.decode([canonical_thinking[i]]))})")
|
||||
# Show context
|
||||
for j in range(max(0,i-2), min(len(input_ids), i+3)):
|
||||
hr = input_ids[j] if j < len(input_ids) else "—"
|
||||
cn = canonical_thinking[j] if j < len(canonical_thinking) else "—"
|
||||
mark = " <<<" if j == i else ""
|
||||
print(f" [{j}] hand={hr} canon={cn}{mark}")
|
||||
break
|
||||
else:
|
||||
if len(input_ids) != len(canonical_thinking):
|
||||
print(f" Same prefix but different lengths: {len(input_ids)} vs {len(canonical_thinking)}")
|
||||
longer = input_ids if len(input_ids) > len(canonical_thinking) else canonical_thinking
|
||||
shorter_len = min(len(input_ids), len(canonical_thinking))
|
||||
label = "hand_rolled" if len(input_ids) > len(canonical_thinking) else "canonical"
|
||||
for j in range(shorter_len, len(longer)):
|
||||
print(f" Extra in {label}: [{j}] = {longer[j]} ({repr(tokenizer.decode([longer[j]]))})")
|
||||
|
||||
# === 6. The key question: does the \n\n matter? ===
|
||||
# Check what token 271 decodes to (it's our \n\n)
|
||||
print(f"\n{'='*70}")
|
||||
print("ANALYSIS")
|
||||
print(f"{'='*70}")
|
||||
# The only difference should be the \n\n token (id 271) in hand-rolled
|
||||
# Check if tokenizer encodes PROMPT differently with/without \n\n prefix
|
||||
enc_with_prefix = tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
|
||||
enc_no_prefix = tokenizer.encode(PROMPT, add_special_tokens=False)
|
||||
print(f" encode('\\n\\n' + PROMPT) = {enc_with_prefix} ({len(enc_with_prefix)} tokens)")
|
||||
print(f" encode(PROMPT) = {enc_no_prefix} ({len(enc_no_prefix)} tokens)")
|
||||
if len(enc_with_prefix) > len(enc_no_prefix):
|
||||
diff_tokens = enc_with_prefix[:len(enc_with_prefix) - len(enc_no_prefix)]
|
||||
print(f" Extra tokens from \\n\\n: {diff_tokens}")
|
||||
for t in diff_tokens:
|
||||
print(f" id={t}: {repr(tokenizer.decode([t]))}")
|
||||
# Check if the remaining tokens match
|
||||
if enc_with_prefix[len(diff_tokens):] == enc_no_prefix:
|
||||
print(f" Remaining tokens MATCH — \\n\\n only adds prefix tokens")
|
||||
else:
|
||||
print(f" WARNING: remaining tokens DIFFER — \\n\\n changes tokenization!")
|
||||
print(f" with prefix tail: {enc_with_prefix[len(diff_tokens):]}")
|
||||
print(f" without prefix: {enc_no_prefix}")
|
||||
|
||||
# === 7. What does SGLang use? ===
|
||||
# From the SGLang docs: --reasoning-parser deepseek-v4 and SGLANG_DEFAULT_THINKING=1
|
||||
# This should use the same encoding. Let's check the raw tokenizer.json for special tokens
|
||||
print(f"\n{'='*70}")
|
||||
print("SPECIAL TOKEN INVENTORY")
|
||||
print(f"{'='*70}")
|
||||
if hasattr(tokenizer, 'added_tokens_decoder'):
|
||||
for tid_str, tok in sorted(tokenizer.added_tokens_decoder.items(), key=lambda x: int(x[0])):
|
||||
tid = int(tid_str)
|
||||
s = str(tok)
|
||||
if tid >= 128000 or any(x in s.lower() for x in ['think', 'user', 'assistant', 'end', 'sentence', 'dsml']):
|
||||
print(f" id={tid:>7d}: {repr(s)}")
|
||||
if hasattr(tokenizer, 'special_tokens_map'):
|
||||
for k, v in tokenizer.special_tokens_map.items():
|
||||
tid = tokenizer.convert_tokens_to_ids(v) if isinstance(v, str) else '—'
|
||||
print(f" special: {k} = {repr(v)} (id={tid})")
|
||||
|
||||
# === 8. Verdict ===
|
||||
print(f"\n{'='*70}")
|
||||
print("VERDICT")
|
||||
print(f"{'='*70}")
|
||||
if input_ids == canonical_thinking:
|
||||
print(" Hand-rolled matches official thinking-mode encoding.")
|
||||
print(" Prompt is CORRECT per the official spec.")
|
||||
print(" Degeneration is NOT caused by prompt format → look at Test 2.")
|
||||
else:
|
||||
print(" Hand-rolled DIFFERS from official encoding!")
|
||||
print(" This is likely contributing to degenerate output.")
|
||||
print(" FIX: Use canonical_thinking encoding in single_shot_inference.py.")
|
||||
print(f" Also try: canonical_chat (THINK_END after Assistant) for non-reasoning mode.")
|
||||
print(f" Also try: canonical_sys_thinking (with system prompt).")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
397
tests/unit/test_degeneration_2_mhc_falsify.py
Normal file
397
tests/unit/test_degeneration_2_mhc_falsify.py
Normal file
@@ -0,0 +1,397 @@
|
||||
#!/usr/bin/env python3
|
||||
"""DEGENERATION TEST 2 — Falsify the mHC "root cause".
|
||||
|
||||
Claim: "|X|=860 compresses the logit range so the model can't distinguish tokens."
|
||||
Test: RMSNorm is scale-invariant, so |X|=860 and |X|=8 should give the same logits.
|
||||
If they differ, the final norm is missing/broken, NOT mHC.
|
||||
|
||||
This test runs single_shot_inference.py with a monkey-patch that intercepts
|
||||
the final-layer residual and does the scale-invariance comparison.
|
||||
"""
|
||||
import os, sys, time
|
||||
|
||||
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
|
||||
def main():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# We'll import single_shot and monkey-patch the decode loop to capture X
|
||||
# after all layers and before hc_head/final_norm/lm_head.
|
||||
# Then we do the scale-invariance test on the captured X.
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Load everything through single_shot's infrastructure
|
||||
# Strategy: import single_shot, call its setup functions, then do our own decode
|
||||
# with interception at the hc_head point.
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from single_shot_inference import (
|
||||
load_all_weights, build_rope_cache, rmsnorm, unweighted_rmsnorm,
|
||||
HcHead, KVCache, Compressor, Indexer,
|
||||
make_nvfp4_linear, get_nvfp4_weight,
|
||||
forward_layer, moe_forward,
|
||||
_cache_layer_weights_no_experts,
|
||||
_load_moe_weights_stacked, _load_shared_expert_weights,
|
||||
FP4_LUT, HC_EPS, THINK_START, THINK_END, USER_TOKEN, ASSISTANT_TOKEN,
|
||||
kill_stale_gpu_processes,
|
||||
)
|
||||
from dsv4.layers.mhc import mHCLayer
|
||||
from dsv4.layers.router import Router
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
|
||||
from dsv4.layers.moe import Nvfp4MoE
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4, quantize_to_nvfp4
|
||||
|
||||
NUM_GPUS = 8
|
||||
PROMPT = "The capital of France is"
|
||||
HIDDEN = 7168
|
||||
|
||||
print("=" * 70)
|
||||
print("DEGENERATION TEST 2 — Falsify mHC residual growth root cause")
|
||||
print("=" * 70)
|
||||
|
||||
t0 = time.time(); torch.manual_seed(42)
|
||||
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
n_layers = cfg["num_hidden_layers"]; H = cfg["hidden_size"]
|
||||
hd = cfg["head_dim"]; n_h = cfg["num_attention_heads"]
|
||||
rd = cfg.get("qk_rope_head_dim", 64)
|
||||
cr = cfg.get("compress_ratios", [128] * n_layers)
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}")
|
||||
|
||||
# Load weights
|
||||
print(f"\nLoading weights..."); all_w = load_all_weights(CHECKPOINT_DIR)
|
||||
kill_stale_gpu_processes()
|
||||
for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# Build mHC + norms
|
||||
attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {}
|
||||
for li in range(n_layers):
|
||||
dev = f"cuda:{li % NUM_GPUS}"
|
||||
for tag, blocks, fn_s, base_s, scale_s in [
|
||||
("attn", attn_mhcs, f"model.layers.{li}.attn_hc.fn", f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"),
|
||||
("ffn", ffn_mhcs, f"model.layers.{li}.ffn_hc.fn", f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"),
|
||||
]:
|
||||
fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s)
|
||||
if fn is not None and base is not None and scale is not None:
|
||||
m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev)
|
||||
n = 4
|
||||
m.load_weights(
|
||||
W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
|
||||
W_comb=fn[2*n:].to(dev, torch.float32),
|
||||
S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
|
||||
S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
|
||||
S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
|
||||
alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item(),
|
||||
)
|
||||
blocks[li] = m
|
||||
an_k = f"model.layers.{li}.input_layernorm.weight"
|
||||
if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32)
|
||||
fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
|
||||
if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)
|
||||
|
||||
# Attention linears
|
||||
prod_lins = {}
|
||||
for li in range(n_layers):
|
||||
dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.self_attn"
|
||||
torch.cuda.set_device(li % NUM_GPUS)
|
||||
pl = {}
|
||||
pl['q_a'] = make_nvfp4_linear(7168, 1536, dev, all_w, pfx, 'q_a_proj')
|
||||
pl['q_b'] = make_nvfp4_linear(1536, 65536, dev, all_w, pfx, 'q_b_proj')
|
||||
pl['kv'] = make_nvfp4_linear(7168, 512, dev, all_w, pfx, 'kv_proj')
|
||||
n_local_groups = cfg.get('o_groups', 16)
|
||||
heads_per_group = n_h // n_local_groups
|
||||
o_rank_val = cfg.get('o_lora_rank', 1024)
|
||||
wo_a = Nvfp4GroupedLinear(n_local_groups=n_local_groups, heads_per_group=heads_per_group,
|
||||
head_dim=hd, o_lora_rank=o_rank_val, max_num_tokens=8192, device=dev)
|
||||
oa_w_nvfp4, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
|
||||
if oa_w_nvfp4 is not None and oa_ws is not None:
|
||||
wo_a.load_nvfp4_weight(oa_w_nvfp4.to(dev), oa_ws.to(dev),
|
||||
oa_ws2.to(dev) if oa_ws2 is not None else None,
|
||||
oa_isc.to(dev) if oa_isc is not None else None)
|
||||
else:
|
||||
oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_bf is not None: wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
|
||||
pl['o_a'] = wo_a; wo_a._use_runtime_gsa = True
|
||||
pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj')
|
||||
prod_lins[li] = pl
|
||||
|
||||
# Routers, MoE, shared experts
|
||||
routers, moe_runners, se_runners = {}, {}, {}
|
||||
for li in range(n_layers):
|
||||
dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.mlp"
|
||||
torch.cuda.set_device(li % NUM_GPUS)
|
||||
is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{pfx}.gate.tid2eid" in all_w)
|
||||
router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
|
||||
top_k=cfg.get("num_experts_per_tok", 6),
|
||||
routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
|
||||
mode="hash" if is_hash else "dense",
|
||||
vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
|
||||
if is_hash:
|
||||
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
E = cfg["n_routed_experts"]
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)
|
||||
gate_lin.fp4 = [gate_w_view]; gate_lin.sf = [gate_ws.to(dev)]
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]; gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v; gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights(); router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
else:
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]; gate_lin.sf = [g_sf]; gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0); gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights(); router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
top_k=cfg.get("num_experts_per_tok", 6), device=dev)
|
||||
moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)); moe.set_fused_swiglu(True)
|
||||
_load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg)
|
||||
moe._ensure_stacked(); moe._use_runtime_gsa = True; moe_runners[li] = moe
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
|
||||
se.set_fused_swiglu(True)
|
||||
_load_shared_expert_weights(all_w, li, pfx, dev, se, cfg)
|
||||
se._ensure_initialized(); se._use_runtime_gsa = True; se_runners[li] = se
|
||||
if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Global weights
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
|
||||
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
|
||||
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
|
||||
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
|
||||
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
|
||||
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
|
||||
lm_head_lin.gs = [lm_gs]; lm_head_lin.ws2 = [None]
|
||||
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
lm_head_lin._use_runtime_gsa = True; lm_head_lin.finalize_weights()
|
||||
|
||||
final_norm_w = all_w.get("model.norm.weight")
|
||||
if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)
|
||||
|
||||
hc_head = HcHead(H, 4, 'cuda:0')
|
||||
hc_fn = all_w.get("model.hc_head.hc_fn"); hc_base = all_w.get("model.hc_head.hc_base")
|
||||
hc_scale = all_w.get("model.hc_head.hc_scale")
|
||||
if hc_fn is not None and hc_base is not None: hc_head.load(hc_fn, hc_base, hc_scale)
|
||||
|
||||
rp = cfg.get("rope_scaling", cfg.get("rope_parameters", {}))
|
||||
rt = rp.get("type", rp.get("rope_type", "yarn")); rf = rp.get("factor", 16.0)
|
||||
rtheta = cfg.get("rope_theta", 10000.)
|
||||
romax = rp.get("original_max_position_embeddings", 65536)
|
||||
rbfast, rbslow = rp.get("beta_fast", 32), rp.get("beta_slow", 1)
|
||||
rope_caches = {g: build_rope_cache(romax, rd, f"cuda:{g}", rtheta, rt, rf, romax, rbfast, rbslow)
|
||||
for g in range(NUM_GPUS)}
|
||||
|
||||
kv_caches, compressors, indexers = {}, {}, {}
|
||||
n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128)
|
||||
itk = cfg.get("index_topk", 1024)
|
||||
for li in range(n_layers):
|
||||
dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128
|
||||
max_comp = (8192 + ratio - 1) // ratio if ratio > 0 else 0
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp, device=dev,
|
||||
indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk, rope_dim=rd)
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
devs = [f"cuda:{g}" for g in range(NUM_GPUS)]
|
||||
layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs)
|
||||
del all_w; import gc; gc.collect()
|
||||
for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
for li in range(n_layers):
|
||||
pfx = f"model.layers.{li}.self_attn.compressor"
|
||||
if li in compressors: compressors[li].load(layer_w[li], pfx, dev=f"cuda:{li % NUM_GPUS}")
|
||||
if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer", dev=f"cuda:{li % NUM_GPUS}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
# FIXED: no \n\n (official DSV4 encoding spec)
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode(PROMPT, add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
input_ids.append(THINK_START)
|
||||
|
||||
print(f"\nPrefill + 1 decode step...")
|
||||
PREFILL_CHUNK = 128
|
||||
n_prefill = len(input_ids)
|
||||
prefill_ids = torch.tensor(input_ids, dtype=torch.long, device='cuda:0')
|
||||
prefill_ids32 = prefill_ids.to(torch.int32)
|
||||
all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0')
|
||||
|
||||
chunk_starts = list(range(0, n_prefill, PREFILL_CHUNK))
|
||||
X = None
|
||||
for ci, cs in enumerate(chunk_starts):
|
||||
ce = min(cs + PREFILL_CHUNK, n_prefill)
|
||||
chunk_ids = prefill_ids[cs:ce]
|
||||
chunk_ids32 = prefill_ids32[cs:ce]
|
||||
chunk_positions = all_positions[cs:ce]
|
||||
X = mHCLayer.init_state(embed(chunk_ids))
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
torch.cuda.set_device(gpu)
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li),
|
||||
attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], chunk_positions, chunk_ids32,
|
||||
compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li))
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
print(f" Chunk {ci+1}/{len(chunk_starts)}: OK |X|={X.abs().max().item():.1f}", flush=True)
|
||||
|
||||
# Decode step 1
|
||||
dec_tid = torch.tensor([input_ids[-1]], dtype=torch.long, device='cuda:0')
|
||||
dec_tid32 = dec_tid.to(torch.int32)
|
||||
dec_pos = torch.tensor([n_prefill - 1], dtype=torch.long, device='cuda:0')
|
||||
|
||||
X = mHCLayer.init_state(embed(dec_tid))
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
torch.cuda.set_device(gpu)
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li),
|
||||
attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], dec_pos, dec_tid32,
|
||||
compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li))
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# ================================================================
|
||||
# TEST 2: Falsification
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("TEST 2 — Falsify mHC residual growth root cause")
|
||||
print(f"{'='*70}")
|
||||
|
||||
# Step 1: Confirm final norm exists
|
||||
print(f"\n1. FINAL NORM CHECK:")
|
||||
print(f" final_norm_w exists: {final_norm_w is not None}")
|
||||
if final_norm_w is not None:
|
||||
print(f" final_norm_w shape: {final_norm_w.shape}, dtype: {final_norm_w.dtype}")
|
||||
print(f" final_norm_w range: [{final_norm_w.min().item():.6f}, {final_norm_w.max().item():.6f}]")
|
||||
else:
|
||||
print(f" *** CRITICAL: final_norm_w is MISSING! ***")
|
||||
|
||||
# Step 2: Residual inspection
|
||||
print(f"\n2. RESIDUAL INSPECTION:")
|
||||
X_max = X.abs().max().item()
|
||||
print(f" |X| (final layer residual) = {X_max:.4f}")
|
||||
print(f" X shape: {X.shape}, dtype: {X.dtype}")
|
||||
|
||||
# Step 3: Trace full path X → hc_head → final_norm → lm_head → logits
|
||||
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
print(f" |x_out| (after hc_head) = {x_out.abs().max().item():.4f}")
|
||||
|
||||
if final_norm_w is not None:
|
||||
x_normed = rmsnorm(x_out, final_norm_w)
|
||||
print(f" |x_normed| (after final_norm) = {x_normed.abs().max().item():.4f}")
|
||||
# Verify scale invariance of RMSNorm alone
|
||||
x_out_tiny = x_out / 100.0
|
||||
x_normed_tiny = rmsnorm(x_out_tiny, final_norm_w)
|
||||
cos_norm = F.cosine_similarity(x_normed.flatten().float(), x_normed_tiny.flatten().float(), dim=0).item()
|
||||
print(f" RMSNorm scale invariance: cos(x_normed, x_normed_tiny) = {cos_norm:.8f}")
|
||||
else:
|
||||
x_normed = x_out
|
||||
print(f" *** NO FINAL NORM — logits will be magnitude-dependent! ***")
|
||||
|
||||
# Step 4: FALSIFICATION — logits with X vs X/100
|
||||
print(f"\n3. FALSIFICATION: logits with |X|={X_max:.1f} vs |X/100|={X_max/100:.2f}")
|
||||
|
||||
# Path A: X as-is
|
||||
x_out_A = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
if final_norm_w is not None: x_out_A = rmsnorm(x_out_A, final_norm_w)
|
||||
logits_A = lm_head_lin(x_out_A)
|
||||
|
||||
# Path B: X scaled down by 100
|
||||
X_scaled = X / 100.0
|
||||
x_out_B = hc_head.forward(X_scaled) if hc_head is not None else X_scaled[:, 0, :]
|
||||
if final_norm_w is not None: x_out_B = rmsnorm(x_out_B, final_norm_w)
|
||||
logits_B = lm_head_lin(x_out_B)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
logits_A_f = logits_A.float(); logits_B_f = logits_B.float()
|
||||
argmax_A = logits_A_f.argmax().item(); argmax_B = logits_B_f.argmax().item()
|
||||
cos_AB = F.cosine_similarity(logits_A_f.flatten(), logits_B_f.flatten(), dim=0).item()
|
||||
top5_A_vals, top5_A_ids = logits_A_f.topk(5)
|
||||
top5_B_vals, top5_B_ids = logits_B_f.topk(5)
|
||||
top5_A_ids = top5_A_ids.flatten(); top5_A_vals = top5_A_vals.flatten()
|
||||
top5_B_ids = top5_B_ids.flatten(); top5_B_vals = top5_B_vals.flatten()
|
||||
|
||||
print(f"\n logits_A (|X|={X_max:.1f}):")
|
||||
print(f" range: [{logits_A_f.min().item():.2f}, {logits_A_f.max().item():.2f}]")
|
||||
print(f" argmax: {argmax_A} ('{tokenizer.decode([argmax_A])}')")
|
||||
print(f" top-5: {[(tokenizer.decode([tid.item()]), f'{val.item():.2f}') for tid, val in zip(top5_A_ids, top5_A_vals)]}")
|
||||
|
||||
print(f"\n logits_B (|X/100|={X_max/100:.2f}):")
|
||||
print(f" range: [{logits_B_f.min().item():.2f}, {logits_B_f.max().item():.2f}]")
|
||||
print(f" argmax: {argmax_B} ('{tokenizer.decode([argmax_B])}')")
|
||||
print(f" top-5: {[(tokenizer.decode([tid.item()]), f'{val.item():.2f}') for tid, val in zip(top5_B_ids, top5_B_vals)]}")
|
||||
|
||||
print(f"\n cos(logits_A, logits_B) = {cos_AB:.8f}")
|
||||
print(f" argmax_A == argmax_B: {argmax_A == argmax_B}")
|
||||
|
||||
# Step 5: hc_head magnitude sensitivity
|
||||
print(f"\n4. HC_HEAD MAGNITUDE SENSITIVITY:")
|
||||
x_out_A_raw = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
x_out_B_raw = hc_head.forward(X / 100.0) if hc_head is not None else (X / 100.0)[:, 0, :]
|
||||
cos_hc = F.cosine_similarity(x_out_A_raw.flatten().float(), (x_out_B_raw * 100.0).flatten().float(), dim=0).item()
|
||||
print(f" cos(hc_head(X), hc_head(X/100)*100) = {cos_hc:.8f}")
|
||||
print(f" |hc_head(X)| = {x_out_A_raw.abs().max().item():.4f}")
|
||||
print(f" |hc_head(X/100)| = {x_out_B_raw.abs().max().item():.6f}")
|
||||
print(f" |hc_head(X/100)*100| = {(x_out_B_raw * 100.0).abs().max().item():.4f}")
|
||||
|
||||
# Step 6: Verdict
|
||||
print(f"\n{'='*70}")
|
||||
print("VERDICT:")
|
||||
print(f"{'='*70}")
|
||||
if final_norm_w is None:
|
||||
print(" *** CRITICAL: FINAL NORM IS MISSING! ***")
|
||||
print(" The model has no RMSNorm before the LM head.")
|
||||
print(" FIX: Apply the final norm before lm_head.")
|
||||
elif cos_AB >= 0.999:
|
||||
print(" mHC residual growth is EXONERATED.")
|
||||
print(f" cos(logits_A, logits_B) = {cos_AB:.8f} ≈ 1.0")
|
||||
print(f" argmax_A={argmax_A}, argmax_B={argmax_B}")
|
||||
print(" |X| magnitude does NOT affect logits (RMSNorm divides it out).")
|
||||
print(" The degeneration cause is elsewhere — likely the prompt format (Test 1).")
|
||||
elif argmax_A != argmax_B:
|
||||
print(" mHC residual growth IS magnitude-sensitive despite final norm.")
|
||||
print(f" argmax_A={argmax_A} ≠ argmax_B={argmax_B}, cos={cos_AB:.8f}")
|
||||
print(" Something downstream is magnitude-sensitive.")
|
||||
else:
|
||||
print(f" Inconclusive: argmax matches but cos={cos_AB:.8f} < 0.999")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
115
tests/unit/test_part_a_compressor_kv.py
Normal file
115
tests/unit/test_part_a_compressor_kv.py
Normal file
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
"""PART A diagnostic: Compressor + FMHA at production scale."""
|
||||
import sys, math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
def main():
|
||||
HD = 512; NOPE = 448; ROPE = 64; n_h = 128
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
device = "cuda:0"
|
||||
torch.manual_seed(42)
|
||||
|
||||
print("=" * 70)
|
||||
print("PART A: Compressor + FMHA at Production Scale")
|
||||
print("=" * 70)
|
||||
|
||||
all_pass = True
|
||||
|
||||
# ---- Test 1: CSA compression round-trip ----
|
||||
print("\n--- Test 1: CSA compression (ratio=4) ---")
|
||||
from dsv4.kernels.compressor.production_compress import csa_compress_production_fp32
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
|
||||
for T in [4, 16, 32, 64]:
|
||||
m = 4; n_blocks = T // m; kv_dim = HD * 2
|
||||
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
compressed = csa_compress_production_fp32(kv_proj, gate_proj, None, None, m=4)
|
||||
if compressed.shape[0] == 0: print(f" T={T}: SKIP"); continue
|
||||
comp_kv = compressed[:, :HD]
|
||||
nope_fp32 = comp_kv[:, :NOPE].contiguous()
|
||||
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
|
||||
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
|
||||
cos = cosine(comp_kv, comp_kv_rt)
|
||||
ok = cos > 0.999
|
||||
if not ok: all_pass = False
|
||||
print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
|
||||
|
||||
# ---- Test 2: HCA compression round-trip ----
|
||||
print("\n--- Test 2: HCA compression (ratio=128) ---")
|
||||
from dsv4.kernels.compressor.production_compress import hca_compress_production_fp32
|
||||
|
||||
for T in [128, 256]:
|
||||
m = 128; n_blocks = T // m
|
||||
if n_blocks == 0: print(f" T={T}: SKIP"); continue
|
||||
kv_dim = HD * 2
|
||||
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
compressed = hca_compress_production_fp32(kv_proj, gate_proj, None, None, m=128)
|
||||
comp_kv = compressed[:, :HD]
|
||||
nope_fp32 = comp_kv[:, :NOPE].contiguous()
|
||||
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
|
||||
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
|
||||
cos = cosine(comp_kv, comp_kv_rt)
|
||||
ok = cos > 0.999
|
||||
if not ok: all_pass = False
|
||||
print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
|
||||
|
||||
# ---- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ----
|
||||
print("\n--- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ---")
|
||||
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
|
||||
|
||||
for N in [128, 512, 1024]:
|
||||
# Realistic FP8 quantized KV
|
||||
kv_nope_fp32 = torch.randn(N, NOPE, dtype=torch.float32, device=device) * 0.3
|
||||
kv_rope_bf16 = torch.randn(N, ROPE, dtype=torch.bfloat16, device=device) * 0.3
|
||||
amax = kv_nope_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
nope_scale = (amax / 448.0).squeeze(-1)
|
||||
nope_clamped = (kv_nope_fp32 / nope_scale.unsqueeze(-1)).clamp(-448, 448)
|
||||
kv_nope_fp8 = nope_clamped.to(torch.float8_e4m3fn).view(torch.uint8).contiguous()
|
||||
kv_nope_scale = nope_scale.contiguous()
|
||||
|
||||
q = torch.randn(n_h, 1, HD, dtype=torch.bfloat16, device=device) * 0.3
|
||||
|
||||
# Production FMHA (128 heads, each attends to the same KV)
|
||||
attn_out = dsv4_attention_mixed_fp8_decode(
|
||||
q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale,
|
||||
k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=ROPE)
|
||||
|
||||
# Reference: dequantize, run SDPA per-head (MQA: all Q heads share 1 KV head)
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1)
|
||||
# MQA reference: expand K/V for all Q heads
|
||||
k_expanded = k_full.unsqueeze(0).expand(n_h, -1, -1) # (n_h, N, HD)
|
||||
# SDPA per head
|
||||
o_ref = torch.zeros_like(attn_out)
|
||||
for h in range(n_h):
|
||||
q_h = q[h:h+1] # (1, 1, HD)
|
||||
k_h = k_full.unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
|
||||
v_h = k_h.clone()
|
||||
q_4d = q_h.unsqueeze(0) # (1, 1, 1, HD)
|
||||
o_h = F.scaled_dot_product_attention(q_4d, k_h, v_h, scale=scale)
|
||||
o_ref[h] = o_h.squeeze()
|
||||
|
||||
cos = cosine(attn_out, o_ref)
|
||||
ok = cos > 0.999
|
||||
if not ok: all_pass = False
|
||||
print(f" N={N}: cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
|
||||
|
||||
# ---- Summary ----
|
||||
print("\n" + "=" * 70)
|
||||
print(f"OVERALL: {'PASS' if all_pass else 'FAIL'}")
|
||||
print("=" * 70)
|
||||
sys.exit(0 if all_pass else 1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
468
tests/unit/test_part_a_decode_diagnostics.py
Normal file
468
tests/unit/test_part_a_decode_diagnostics.py
Normal file
@@ -0,0 +1,468 @@
|
||||
#!/usr/bin/env python3
|
||||
"""PART A — Decode Diagnostics: Production pipeline per-layer diagnostics.
|
||||
|
||||
This test runs the FULL production pipeline (single_shot_inference.py forward_layer)
|
||||
for prefill tokens and the first decode step, printing per-layer diagnostics:
|
||||
- |X| per layer (mHC residual growth)
|
||||
- |F_attn| and |F_ffn| magnitudes
|
||||
- Compressed/SWA visible range diagnostics (causality, overlap)
|
||||
- KV cache state (n_comp, swa_len)
|
||||
|
||||
Production values: HD=512, NOPE=448, ROPE=64, H=128, 61 layers, 8 GPUs, 384 experts.
|
||||
"""
|
||||
import os, sys, json, math, time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
CHECKPOINT_DIR = os.environ.get(
|
||||
"CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
NUM_GPUS = int(os.environ.get("NUM_GPUS", "8"))
|
||||
DEVICE = "cuda:0"
|
||||
TEST_LAYERS = int(os.environ.get("TEST_LAYERS", "5"))
|
||||
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
print("=" * 70)
|
||||
print("PART A — DECODE DIAGNOSTICS (Production Pipeline)")
|
||||
print("=" * 70)
|
||||
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
n_layers = cfg["num_hidden_layers"]
|
||||
H = cfg["hidden_size"]
|
||||
hd = cfg["head_dim"]
|
||||
n_h = cfg["num_attention_heads"]
|
||||
rd = cfg.get("qk_rope_head_dim", 64)
|
||||
nope_dim = hd - rd
|
||||
cr = cfg.get("compress_ratios", [128] * n_layers)
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}, nope_dim={nope_dim}")
|
||||
print(f"Compress ratios (first {TEST_LAYERS}): {cr[:TEST_LAYERS]}")
|
||||
|
||||
from single_shot_inference import (
|
||||
load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
|
||||
rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache,
|
||||
KVCache, Compressor, Indexer, forward_layer, forward_attention, moe_forward,
|
||||
_load_moe_weights_stacked, _load_shared_expert_weights,
|
||||
_cache_layer_weights_no_experts,
|
||||
)
|
||||
from dsv4.layers.mhc import mHCLayer, mHCContext
|
||||
from dsv4.layers.router import Router
|
||||
from dsv4.layers.moe import Nvfp4MoE
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.ops.quantize import (
|
||||
rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
|
||||
print("Loading weights...")
|
||||
all_w = load_all_weights(CHECKPOINT_DIR)
|
||||
|
||||
o_groups = cfg.get("o_groups", 16)
|
||||
o_rank = cfg.get("o_lora_rank", 1024)
|
||||
n_ih = cfg.get("index_n_heads", 64)
|
||||
ihd = cfg.get("index_head_dim", 128)
|
||||
itk = cfg.get("index_topk", 1024)
|
||||
|
||||
rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
|
||||
for g in range(NUM_GPUS)}
|
||||
|
||||
# Build production components for TEST_LAYERS
|
||||
prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
|
||||
attn_norms, ffn_norms = {}, {}
|
||||
compressors, indexers, kv_caches = {}, {}, {}
|
||||
routers, moe_runners, se_runners = {}, {}, {}
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
pfx = f"model.layers.{li}.self_attn"
|
||||
mlp_pfx = f"model.layers.{li}.mlp"
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
|
||||
pl = {}
|
||||
pl['q_a'] = make_nvfp4_linear(H, 1536, dev, all_w, pfx, 'q_a_proj')
|
||||
pl['q_b'] = make_nvfp4_linear(1536, H * hd, dev, all_w, pfx, 'q_b_proj')
|
||||
pl['kv'] = make_nvfp4_linear(H, hd, dev, all_w, pfx, 'kv_proj')
|
||||
hpg = n_h // o_groups
|
||||
wo_a = Nvfp4GroupedLinear(n_local_groups=o_groups, heads_per_group=hpg,
|
||||
head_dim=hd, o_lora_rank=o_rank, max_num_tokens=8192, device=dev)
|
||||
oa_w, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
|
||||
if oa_w is not None and oa_ws is not None:
|
||||
wo_a.load_nvfp4_weight(oa_w.to(dev), oa_ws.to(dev),
|
||||
oa_ws2.to(dev) if oa_ws2 is not None else None,
|
||||
oa_isc.to(dev) if oa_isc is not None else None)
|
||||
else:
|
||||
oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_bf is not None:
|
||||
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
|
||||
pl['o_a'] = wo_a; wo_a._use_runtime_gsa = True
|
||||
pl['o_b'] = make_nvfp4_linear(o_groups * o_rank, H, dev, all_w, pfx, 'o_b_proj')
|
||||
prod_lins[li] = pl
|
||||
|
||||
for tag, blocks, fn_s, base_s, scale_s in [
|
||||
("attn", attn_mhcs, f"model.layers.{li}.attn_hc.fn",
|
||||
f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"),
|
||||
("ffn", ffn_mhcs, f"model.layers.{li}.ffn_hc.fn",
|
||||
f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"),
|
||||
]:
|
||||
fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s)
|
||||
if fn is not None and base is not None and scale is not None:
|
||||
m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev)
|
||||
n = 4
|
||||
m.load_weights(
|
||||
W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
|
||||
W_comb=fn[2*n:].to(dev, torch.float32),
|
||||
S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
|
||||
S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
|
||||
S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
|
||||
alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item())
|
||||
blocks[li] = m
|
||||
|
||||
an_k = f"model.layers.{li}.input_layernorm.weight"
|
||||
if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32)
|
||||
fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
|
||||
if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)
|
||||
|
||||
max_comp = (8192 + ratio - 1) // ratio if ratio > 0 else 0
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp,
|
||||
device=dev, indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk, rope_dim=rd)
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{mlp_pfx}.gate.tid2eid" in all_w)
|
||||
router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
|
||||
top_k=cfg.get("num_experts_per_tok", 6),
|
||||
routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
|
||||
mode="hash" if is_hash else "dense",
|
||||
vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
|
||||
if is_hash:
|
||||
router.load_weights(hash_lut=all_w[f"{mlp_pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{mlp_pfx}.gate.e_score_correction_bias")
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, mlp_pfx, 'gate')
|
||||
E = cfg["n_routed_experts"]
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)]
|
||||
gate_lin.sf = [gate_ws.to(dev)]
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
else:
|
||||
gw = all_w.get(f"{mlp_pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
top_k=cfg.get("num_experts_per_tok", 6), device=dev)
|
||||
moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)); moe.set_fused_swiglu(True)
|
||||
_load_moe_weights_stacked(all_w, li, mlp_pfx, dev, moe, cfg)
|
||||
moe._ensure_stacked(); moe._use_runtime_gsa = True; moe_runners[li] = moe
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
|
||||
se.set_fused_swiglu(True)
|
||||
_load_shared_expert_weights(all_w, li, mlp_pfx, dev, se, cfg)
|
||||
se._ensure_initialized(); se._use_runtime_gsa = True; se_runners[li] = se
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
pfx = f"model.layers.{li}.self_attn.compressor"
|
||||
dev = f"cuda:{li % NUM_GPUS}"
|
||||
if li in compressors: compressors[li].load(all_w, pfx, dev=dev)
|
||||
if li in indexers: indexers[li].load(all_w, f"{pfx}.indexer", dev=dev)
|
||||
# Verify compressor kv_norm_w loaded correctly
|
||||
for li in range(TEST_LAYERS):
|
||||
if li in compressors and compressors[li].kv_norm_w is not None:
|
||||
n = compressors[li].kv_norm_w
|
||||
print(f" L{li} compressor kv_norm_w: shape={tuple(n.shape)} |w|={n.abs().max().item():.4f}", flush=True)
|
||||
elif li in compressors:
|
||||
print(f" L{li} compressor kv_norm_w: MISSING!", flush=True)
|
||||
print("Production components built")
|
||||
|
||||
# Embedding + tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
USER_TOKEN, ASSISTANT_TOKEN, THINK_START = 128803, 128804, 128821
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\nThe capital of France is', add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
input_ids.append(THINK_START)
|
||||
print(f"Input: {len(input_ids)} tokens")
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
prod_embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
|
||||
devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
|
||||
layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
|
||||
del all_w; import gc; gc.collect()
|
||||
for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# ================================================================
|
||||
# PHASE 1: Prefill — production, with per-layer |X| tracking
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("PHASE 1: Prefill — PRODUCTION (per-layer |X| tracking)")
|
||||
print(f"{'='*70}")
|
||||
|
||||
print(f"\n {'tok':>3} {'L':>3} {'|X_in|':>12} {'|X_out|':>12} {'ratio':>5} {'n_comp':>6} {'swa':>4}")
|
||||
print(f" {'---':>3} {'---':>3} {'---':>12} {'---':>12} {'---':>5} {'---':>6} {'---':>4}")
|
||||
|
||||
for pi, tid_val in enumerate(input_ids):
|
||||
t1 = time.time()
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
|
||||
pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
|
||||
tid32 = torch.tensor([tid_val], dtype=torch.int32, device=DEVICE)
|
||||
|
||||
X = mHCLayer.init_state(prod_embed(tid))
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
if X.device != torch.device(dev): X = X.to(dev)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
X_prev = X.clone() # Save for blowup diagnostics
|
||||
X_in_mag = X.abs().max().item()
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
|
||||
X_out_mag = X.abs().max().item() if X.device == torch.device(DEVICE) else X.to(DEVICE).abs().max().item()
|
||||
|
||||
kc = kv_caches[li]
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
# Print per-token, per-layer for first 3 tokens, then only first and last layer
|
||||
if pi < 3 or pi == len(input_ids) - 1:
|
||||
print(f" {pi:>3} {li:>3} {X_in_mag:>12.2f} {X_out_mag:>12.2f} {ratio:>5} {kc.n_comp:>6} {kc.swa_len:>4}", flush=True)
|
||||
|
||||
# Early abort if |X| blows up — run detailed diagnostics on THIS layer
|
||||
if X_out_mag > 1e6:
|
||||
print(f" *** BLOWUP at token {pi} layer {li}: |X|={X_out_mag:.2e} ***", flush=True)
|
||||
print(f" Re-running layer {li} with detailed diagnostics...", flush=True)
|
||||
# Re-run the SAME input through forward_layer but capture intermediates
|
||||
X_diag = X_prev.clone() # X before this layer
|
||||
attn_mhc_d = attn_mhcs.get(li)
|
||||
ffn_mhc_d = ffn_mhcs.get(li)
|
||||
A_l_a, B_l_a, C_l_a = attn_mhc_d._dynamic_params(X_diag)
|
||||
ctx_a_d = mHCContext(B_l=B_l_a, C_l=C_l_a)
|
||||
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X_diag, A_l_a, attn_norms.get(li).to(dev, torch.float32))
|
||||
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
||||
print(f" |x_normed|={x_normed.abs().max().item():.2f} gsa={x_quant_attn.gsa}", flush=True)
|
||||
# Run compressor and print raw output
|
||||
comp_diag = compressors.get(li)
|
||||
if comp_diag is not None:
|
||||
comp_kv_d, comp_pos_d, _ = comp_diag.forward(x_normed, pos)
|
||||
if comp_kv_d is not None:
|
||||
print(f" Compressor output: |comp_kv|={comp_kv_d.abs().max().item():.2f} shape={tuple(comp_kv_d.shape)}", flush=True)
|
||||
else:
|
||||
print(f" Compressor output: None (n_complete=0)", flush=True)
|
||||
# Print KV cache state BEFORE calling forward_attention
|
||||
kc_diag = kv_caches[li]
|
||||
swa_kv_d, swa_pos_d = kc_diag.get_swa()
|
||||
print(f" KV: n_comp={kc_diag.n_comp} swa_len={swa_kv_d.shape[0]}", flush=True)
|
||||
# Gather KV and print
|
||||
ratio_diag = cr[li] if li < len(cr) else 128
|
||||
seq_len_d = 0
|
||||
if kc_diag.n_comp > 0:
|
||||
if ratio_diag == 4:
|
||||
# Need to compute indexer top-k first
|
||||
# Run Q projection to get q_a
|
||||
pl_diag = prod_lins.get(li)
|
||||
q_a_d = pl_diag['q_a'].run_from_quantized(x_quant_attn)
|
||||
q_norm_w_d = layer_w[li].get(f"model.layers.{li}.self_attn.q_a_norm.weight")
|
||||
if q_norm_w_d is not None:
|
||||
q_a_quant_d = rmsnorm_quantize_nvfp4(q_a_d, q_norm_w_d.to(dev, torch.float32))
|
||||
q_a_d = dequantize_nvfp4(q_a_quant_d.x_fp4, q_a_quant_d.x_sf, q_a_quant_d.gsa)
|
||||
topk_idx_d = None
|
||||
if indexers.get(li) is not None:
|
||||
topk_idx_d = indexers[li].forward(q_a_d, x_normed, kc_diag, pos, layer_idx=li)
|
||||
if topk_idx_d is not None:
|
||||
tk_d = topk_idx_d[0].clamp(0, kc_diag.n_comp - 1).int()
|
||||
kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_selective(tk_d)
|
||||
print(f" CSA topk: {tk_d.tolist()[:10]}", flush=True)
|
||||
else:
|
||||
kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_swa_only()
|
||||
elif ratio_diag > 4:
|
||||
kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_all()
|
||||
else:
|
||||
kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_swa_only()
|
||||
else:
|
||||
kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_swa_only()
|
||||
seq_len_d = kv_nope_scale_d.shape[0]
|
||||
nope_max = kv_nope_fp8_d.view(torch.float8_e4m3fn).float().abs().max().item()
|
||||
scale_max = kv_nope_scale_d.abs().max().item()
|
||||
rope_max = kv_rope_bf16_d.float().abs().max().item()
|
||||
print(f" Gathered KV: seq_len={seq_len_d} |nope_fp8|={nope_max:.2f} |nope_scale|={scale_max:.6f} |rope_bf16|={rope_max:.2f}", flush=True)
|
||||
nope_dequant_max = (kv_nope_fp8_d.view(torch.float8_e4m3fn).float() * kv_nope_scale_d.unsqueeze(-1).float()).abs().max().item()
|
||||
print(f" |nope_dequant_max|={nope_dequant_max:.4f}", flush=True)
|
||||
# Now run FMHA
|
||||
F_attn_d, q_a_d = forward_attention(
|
||||
x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
kv_caches[li], pos, compressors.get(li), indexers.get(li), prod_lins.get(li),
|
||||
x_quant=x_quant_attn)
|
||||
print(f" |F_attn|={F_attn_d.abs().max().item():.2f}", flush=True)
|
||||
# Check if Q heads are reasonable
|
||||
q_heads_diag = pl_diag['q_b'].run_from_quantized(rmsnorm_quantize_nvfp4(q_a_d, layer_w[li].get(f"model.layers.{li}.self_attn.q_a_norm.weight").to(dev, torch.float32)))
|
||||
q_heads_diag = unweighted_rmsnorm(q_heads_diag).bfloat16()
|
||||
print(f" |Q_heads|={q_heads_diag.abs().max().item():.4f}", flush=True)
|
||||
X_mid_d = attn_mhc_d.post_block(X_diag, F_attn_d, ctx_a_d)
|
||||
print(f" |X_mid|={X_mid_d.abs().max().item():.2f} B_l_row=[{B_l_a.sum(-1).min().item():.4f},{B_l_a.sum(-1).max().item():.4f}] C_l=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]", flush=True)
|
||||
A_l_f, B_l_f, C_l_f = ffn_mhc_d._dynamic_params(X_mid_d)
|
||||
ctx_f_d = mHCContext(B_l=B_l_f, C_l=C_l_f)
|
||||
x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X_mid_d, A_l_f, ffn_norms.get(li).to(dev, torch.float32))
|
||||
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
|
||||
F_ffn_d = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
|
||||
routers.get(li), tid32.to(dev))
|
||||
print(f" |F_ffn|={F_ffn_d.abs().max().item():.2f}", flush=True)
|
||||
X_next_d = ffn_mhc_d.post_block(X_mid_d, F_ffn_d, ctx_f_d)
|
||||
print(f" |X_next|={X_next_d.abs().max().item():.2e}", flush=True)
|
||||
# Check per-component magnitudes
|
||||
BX = torch.bmm(ctx_a_d.B_l.transpose(-1, -2), X_diag.float())
|
||||
CF = ctx_a_d.C_l.unsqueeze(-1) * F_attn_d.unsqueeze(1)
|
||||
print(f" |B@X|={BX.abs().max().item():.2f} |C*F|={CF.abs().max().item():.2f}", flush=True)
|
||||
BX_f = torch.bmm(ctx_f_d.B_l.transpose(-1, -2), X_mid_d.float())
|
||||
CF_f = ctx_f_d.C_l.unsqueeze(-1) * F_ffn_d.unsqueeze(1)
|
||||
print(f" FFN: |B@X|={BX_f.abs().max().item():.2f} |C*F|={CF_f.abs().max().item():.2f}", flush=True)
|
||||
return 1
|
||||
|
||||
if pi % 5 == 0:
|
||||
print(f" Token {pi}/{len(input_ids)} done: {time.time()-t1:.2f}s |X|={X.to(DEVICE).abs().max().item():.2f}", flush=True)
|
||||
|
||||
# KV cache state
|
||||
print(f"\nProduction KV cache state after prefill ({len(input_ids)} tokens):")
|
||||
for li in range(TEST_LAYERS):
|
||||
kc = kv_caches[li]
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len} total_KV={kc.n_comp + kc.swa_len}")
|
||||
|
||||
# ================================================================
|
||||
# PHASE 2: Decode step — per-layer diagnostics
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("PHASE 2: Decode step — per-layer diagnostics")
|
||||
print(f"{'='*70}")
|
||||
|
||||
decode_pos = len(input_ids)
|
||||
decode_tid = tokenizer.encode(" the", add_special_tokens=False)
|
||||
decode_tid = decode_tid[0] if len(decode_tid) > 0 else 2
|
||||
|
||||
dec_tid = torch.tensor([decode_tid], dtype=torch.long, device=DEVICE)
|
||||
dec_tid32 = torch.tensor([decode_tid], dtype=torch.int32, device=DEVICE)
|
||||
dec_pos = torch.tensor([decode_pos], dtype=torch.long, device=DEVICE)
|
||||
|
||||
X = mHCLayer.init_state(prod_embed(dec_tid))
|
||||
print(f"\nInitial X: shape={tuple(X.shape)} |X|={X.abs().max().item():.6f}")
|
||||
|
||||
print(f"\n {'L':>3} {'ratio':>5} {'|X_in|':>12} {'|X_out|':>12} {'|F_attn|':>10} {'|F_ffn|':>10} {'n_comp':>6} {'swa':>4} {'mode':>8} {'leak':>5}")
|
||||
print(f" {'-'*3} {'-'*5} {'-'*12} {'-'*12} {'-'*10} {'-'*10} {'-'*6} {'-'*4} {'-'*8} {'-'*5}")
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
if X.device != torch.device(dev): X = X.to(dev)
|
||||
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
kc = kv_caches[li]
|
||||
X_in_mag = X.abs().max().item()
|
||||
|
||||
# Production forward — capture intermediates
|
||||
attn_mhc = attn_mhcs.get(li)
|
||||
ffn_mhc = ffn_mhcs.get(li)
|
||||
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X)
|
||||
ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a)
|
||||
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X, A_l_a, attn_norms.get(li).to(dev, torch.float32))
|
||||
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
||||
|
||||
F_attn, q_a = forward_attention(
|
||||
x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
kc, dec_pos, compressors.get(li), indexers.get(li), prod_lins.get(li),
|
||||
x_quant=x_quant_attn)
|
||||
X_mid = attn_mhc.post_block(X, F_attn, ctx_a)
|
||||
|
||||
A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid)
|
||||
ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f)
|
||||
x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X_mid, A_l_f, ffn_norms.get(li).to(dev, torch.float32))
|
||||
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
|
||||
F_ffn = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
|
||||
routers.get(li), dec_tid32.to(dev))
|
||||
X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)
|
||||
|
||||
X_out_mag = X_next.to(DEVICE).abs().max().item()
|
||||
f_attn_mag = F_attn.to(DEVICE).abs().max().item()
|
||||
f_ffn_mag = F_ffn.to(DEVICE).abs().max().item()
|
||||
|
||||
swa_kv, swa_pos = kc.get_swa()
|
||||
swa_len = swa_kv.shape[0]
|
||||
n_comp = kc.n_comp
|
||||
mode = "CSA" if ratio == 4 else ("HCA" if ratio > 4 else "SWA")
|
||||
|
||||
# Causality check
|
||||
future_leak = False
|
||||
if n_comp > 0 and kc.comp_pos is not None and kc.comp_pos.numel() > 0:
|
||||
visible_comp_pos = kc.comp_pos[:n_comp]
|
||||
future_leak = (visible_comp_pos >= decode_pos).any().item()
|
||||
|
||||
print(f" {li:>3} {ratio:>5} {X_in_mag:>12.2f} {X_out_mag:>12.2f} "
|
||||
f"{f_attn_mag:>10.2f} {f_ffn_mag:>10.2f} {n_comp:>6} {swa_len:>4} {mode:>8} "
|
||||
f"{'YES!' if future_leak else 'no':>5}")
|
||||
|
||||
# mHC diagnostics
|
||||
B_a = B_l_a
|
||||
print(f" mHC: B_l row_sum=[{B_a.sum(-1).min().item():.4f},{B_a.sum(-1).max().item():.4f}] "
|
||||
f"col_sum=[{B_a.sum(-2).min().item():.4f},{B_a.sum(-2).max().item():.4f}] "
|
||||
f"A=[{A_l_a.min().item():.4f},{A_l_a.max().item():.4f}] "
|
||||
f"C=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]")
|
||||
|
||||
# CSA specifics
|
||||
if ratio == 4 and n_comp > 0:
|
||||
print(f" CSA: n_comp={n_comp} swa_len={swa_len} total_attend={n_comp + swa_len}")
|
||||
|
||||
X = X_next
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*70}")
|
||||
print("PART A SUMMARY")
|
||||
print(f"{'='*70}")
|
||||
print("Production pipeline diagnostics complete.")
|
||||
print("Check the |X| values above for:")
|
||||
print(" 1. Exponential growth (mHC residual blowup)")
|
||||
print(" 2. Sudden jumps (NVFP4 quantization error)")
|
||||
print(" 3. NaN/Inf (numerical instability)")
|
||||
print(" 4. future_leak=YES (causality violation in compressed KV)")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
247
tests/unit/test_part_a_pipeline.py
Normal file
247
tests/unit/test_part_a_pipeline.py
Normal file
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python3
|
||||
"""PART A diagnostic: full forward_attention pipeline comparison.
|
||||
|
||||
Tests each stage of the production attention pipeline against a PyTorch
|
||||
reference for the first few layers. Identifies exactly where the pipeline
|
||||
diverges from the reference.
|
||||
|
||||
Stages tested per layer:
|
||||
1. Q projection (q_a → q_a_norm → q_b → q_b_norm)
|
||||
2. KV projection + RoPE
|
||||
3. KV cache append + compressor
|
||||
4. KV gathering (compressed + SWA)
|
||||
5. FMHA (production vs SDPA)
|
||||
6. Inverse RoPE
|
||||
7. Output projection (o_a + o_b)
|
||||
8. Full forward_attention output vs reference
|
||||
|
||||
Uses REAL model weights and production values.
|
||||
"""
|
||||
import sys, os, time, math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
def cosine(a, b):
|
||||
a, b = a.flatten().float(), b.flatten().float()
|
||||
d = a @ b
|
||||
na, nb = a.norm(), b.norm()
|
||||
return (d / (na * nb + 1e-12)).item()
|
||||
|
||||
def rmsnorm(x, w, eps=1e-6):
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
rms = x.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
|
||||
return (x * rms).to(dtype) * w.to(dtype)
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────
|
||||
def main():
|
||||
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
NUM_GPUS = 8
|
||||
MAX_LAYERS = 3 # Test first 3 layers
|
||||
|
||||
print("=" * 70)
|
||||
print("PART A DIAGNOSTIC: Full Attention Pipeline Comparison")
|
||||
print(f"Model: {MODEL}, Layers: {MAX_LAYERS}, GPUs: {NUM_GPUS}")
|
||||
print("=" * 70)
|
||||
|
||||
# ── Load model config ──
|
||||
import json
|
||||
with open(os.path.join(MODEL, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
n_layers = cfg["num_hidden_layers"]
|
||||
n_h = cfg["num_attention_heads"]
|
||||
hd = cfg["head_dim"]
|
||||
hidden = cfg["hidden_size"]
|
||||
rd = cfg.get("qk_rope_head_dim", 64)
|
||||
nope_dim = hd - rd
|
||||
o_groups = cfg.get("o_groups", 16)
|
||||
o_rank = cfg.get("o_lora_rank", 1024)
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
print(f"Config: {n_layers}L, {n_h}H, hd={hd}, rope={rd}, nope={nope_dim}")
|
||||
print(f" o_groups={o_groups}, o_rank={o_rank}, hidden={hidden}")
|
||||
|
||||
# ── Load tokenizer ──
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
|
||||
prompt = "The capital of France is"
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
print(f"Prompt: '{prompt}' → {len(input_ids)} tokens: {input_ids}")
|
||||
|
||||
# ── Load RoPE caches ──
|
||||
from dsv4.ops.rope_cuda import build_rope_cache
|
||||
rope_caches = {}
|
||||
for gpu in range(NUM_GPUS):
|
||||
torch.cuda.set_device(gpu)
|
||||
rope_caches[gpu] = build_rope_cache(8192, hd, rd, device=f"cuda:{gpu}")
|
||||
|
||||
# ── Load weights and set up production layers ──
|
||||
from single_shot_inference import (
|
||||
load_layer_weights, setup_production_linear, setup_compressor,
|
||||
setup_indexer, KVCache, mHCLayer, rmsnorm as prod_rmsnorm,
|
||||
_apply_rope, forward_attention
|
||||
)
|
||||
|
||||
# ── Process prefill tokens one by one ──
|
||||
results = {}
|
||||
for li in range(MAX_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# Load weights for this layer
|
||||
w, prod_lin, compressor, indexer = None, None, None, None
|
||||
try:
|
||||
w = load_layer_weights(MODEL, li, f"cuda:{gpu}")
|
||||
prod_lin = setup_production_linear(w, li, cfg, f"cuda:{gpu}")
|
||||
compressor = setup_compressor(w, li, cfg, f"cuda:{gpu}")
|
||||
if compressor is not None and compressor.ratio == 4:
|
||||
indexer = setup_indexer(w, li, cfg, f"cuda:{gpu}")
|
||||
except Exception as e:
|
||||
print(f" L{li}: Failed to load weights: {e}")
|
||||
continue
|
||||
|
||||
pfx = f"model.layers.{li}.self_attn"
|
||||
ratio = compressor.ratio if compressor is not None else 0
|
||||
layer_type = "SWA" if ratio == 0 else ("CSA" if ratio == 4 else "HCA")
|
||||
print(f"\nL{li} (gpu={gpu}, type={layer_type}, ratio={ratio})")
|
||||
|
||||
# Set up KV cache
|
||||
kv_cache = KVCache(li, cfg, f"cuda:{gpu}")
|
||||
mhc_attn = mHCLayer(li, "attn", cfg, f"cuda:{gpu}")
|
||||
|
||||
# Initialize mHC state
|
||||
embed_w = torch.load(os.path.join(MODEL, "model.embed_tokens.weight.pt"),
|
||||
map_location=f"cuda:{gpu}", weights_only=True).bfloat16()
|
||||
embed_w = embed_w.to(f"cuda:{gpu}")
|
||||
|
||||
# Process each prefill token
|
||||
X = None
|
||||
for pi, tid in enumerate(input_ids):
|
||||
tid_t = torch.tensor([tid], dtype=torch.long, device=f"cuda:{gpu}")
|
||||
pos = torch.tensor([pi], dtype=torch.long, device=f"cuda:{gpu}")
|
||||
|
||||
if pi == 0:
|
||||
X = mHCLayer.init_state(F.embedding(tid_t, embed_w))
|
||||
else:
|
||||
X = mHCLayer.init_state(F.embedding(tid_t, embed_w))
|
||||
|
||||
# Forward through attention for this layer
|
||||
X_normed = rmsnorm(X, w.get(f"model.layers.{li}.input_layernorm.weight").to(f"cuda:{gpu}", torch.float32))
|
||||
|
||||
if pi == 0:
|
||||
# First token: run forward_attention and capture intermediate values
|
||||
# We need to run the full pipeline and compare
|
||||
dev = f"cuda:{gpu}"
|
||||
T = 1
|
||||
|
||||
# 1. Q projections
|
||||
q_a = prod_lin['q_a'](X_normed)
|
||||
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
|
||||
q_a_norm = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) if q_norm_w is not None else q_a
|
||||
q = prod_lin['q_b'](q_a_norm)
|
||||
q = rmsnorm(q, w.get(f"{pfx}.q_b_norm.weight").to(dev, torch.float32)).bfloat16()
|
||||
q_heads = q.reshape(T, n_h, hd)
|
||||
q_heads = _apply_rope(q_heads, pos, *rope_caches[gpu], rd)
|
||||
|
||||
# 2. KV projection
|
||||
kv = prod_lin['kv'](X_normed)
|
||||
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None:
|
||||
kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
kv_3d = kv.reshape(T, 1, hd)
|
||||
kv_3d = _apply_rope(kv_3d, pos, *rope_caches[gpu], rd)
|
||||
kv_roped = kv_3d.reshape(T, hd)
|
||||
kv_cache.append_swa(kv_roped, pos)
|
||||
|
||||
# 3. Compression (if applicable)
|
||||
comp_pos = None
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv_fp32, comp_pos, _ = compressor.forward(X_normed, pos)
|
||||
if comp_kv_fp32 is not None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous()
|
||||
rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous()
|
||||
rope_3d = rope_bf16.unsqueeze(1)
|
||||
rope_3d = _apply_rope(rope_3d, comp_pos, *rope_caches[gpu], rd)
|
||||
rope_bf16 = rope_3d.squeeze(1)
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
|
||||
if compressor.is_csa and indexer is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(X_normed, pos)
|
||||
kv_cache.set_indexer_keys_fp8(comp_idx_kv)
|
||||
|
||||
# 4. Indexer (CSA)
|
||||
topk_idx = None
|
||||
if indexer is not None and ratio == 4:
|
||||
topk_idx = indexer.forward(q_a, X_normed, kv_cache, pos, layer_idx=li)
|
||||
|
||||
# 5. Gather KV
|
||||
swa_kv, _swa_pos = kv_cache.get_swa()
|
||||
swa_len = swa_kv.shape[0]
|
||||
if kv_cache.n_comp > 0:
|
||||
if ratio == 4:
|
||||
tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int()
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_selective(tk)
|
||||
elif ratio > 4:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_all()
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only()
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only()
|
||||
seq_len = kv_nope_scale.shape[0]
|
||||
|
||||
print(f" Token 0: seq_len={seq_len} swa_len={swa_len} n_comp={kv_cache.n_comp}")
|
||||
print(f" kv_nope_fp8 shape={tuple(kv_nope_fp8.shape)} dtype={kv_nope_fp8.dtype}")
|
||||
print(f" kv_nope_scale shape={tuple(kv_nope_scale.shape)} dtype={kv_nope_scale.dtype}")
|
||||
print(f" kv_rope_bf16 shape={tuple(kv_rope_bf16.shape)} dtype={kv_rope_bf16.dtype}")
|
||||
else:
|
||||
# Non-first token: just run through and build KV cache
|
||||
dev = f"cuda:{gpu}"
|
||||
T = 1
|
||||
q_a = prod_lin['q_a'](X_normed)
|
||||
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
|
||||
q_a_norm = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) if q_norm_w is not None else q_a
|
||||
q = prod_lin['q_b'](q_a_norm)
|
||||
q = rmsnorm(q, w.get(f"{pfx}.q_b_norm.weight").to(dev, torch.float32)).bfloat16()
|
||||
q_heads = q.reshape(T, n_h, hd)
|
||||
q_heads = _apply_rope(q_heads, pos, *rope_caches[gpu], rd)
|
||||
|
||||
kv = prod_lin['kv'](X_normed)
|
||||
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None:
|
||||
kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
kv_3d = kv.reshape(T, 1, hd)
|
||||
kv_3d = _apply_rope(kv_3d, pos, *rope_caches[gpu], rd)
|
||||
kv_roped = kv_3d.reshape(T, hd)
|
||||
kv_cache.append_swa(kv_roped, pos)
|
||||
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv_fp32, comp_pos, _ = compressor.forward(X_normed, pos)
|
||||
if comp_kv_fp32 is not None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous()
|
||||
rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous()
|
||||
rope_3d = rope_bf16.unsqueeze(1)
|
||||
rope_3d = _apply_rope(rope_3d, comp_pos, *rope_caches[gpu], rd)
|
||||
rope_bf16 = rope_3d.squeeze(1)
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
|
||||
if compressor.is_csa and indexer is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(X_normed, pos)
|
||||
kv_cache.set_indexer_keys_fp8(comp_idx_kv)
|
||||
|
||||
# mHC forward
|
||||
# (simplified — the real single_shot uses forward_layer which handles mHC)
|
||||
|
||||
# After all prefill tokens, check KV state
|
||||
print(f" L{li} after prefill: n_comp={kv_cache.n_comp} swa_len={kv_cache.get_swa()[0].shape[0]}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("DONE")
|
||||
print("=" * 70)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
83
tests/unit/test_prefill_debug.py
Normal file
83
tests/unit/test_prefill_debug.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Debug test: compare T=1 prefill vs T=1 decode, step by step.
|
||||
|
||||
Uses synthetic data. Prints per-step comparisons to identify
|
||||
where the prefill kernel diverges from the decode kernel.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
HD = 512; NOPE = 448; ROPE = 64; H = 128
|
||||
B = 1; T = 1; N = 256
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
|
||||
def quantize_fp8_e4m3(x_fp32):
|
||||
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
s = amax / 448.0
|
||||
fp8 = (x_fp32 / s).clamp(-448, 448).to(torch.float8_e4m3fn)
|
||||
return fp8.view(torch.uint8), s.squeeze(-1)
|
||||
|
||||
def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
|
||||
k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
|
||||
q_bf16 = q_fp32.bfloat16().cuda()
|
||||
k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
|
||||
k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
|
||||
k_nope_fp8 = k_nope_fp8.cuda()
|
||||
k_nope_scale = k_nope_scale.cuda()
|
||||
k_rope_bf16 = k_rope_bf16.cuda()
|
||||
|
||||
# Reference SDPA
|
||||
nope_dequant = k_nope_fp8.view(torch.float8_e4m3fn).cpu().float() * k_nope_scale.cpu().unsqueeze(-1).float()
|
||||
k_full = torch.cat([nope_dequant, k_fp32[:, NOPE:]], dim=-1).bfloat16().cuda()
|
||||
k_4d = k_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1)
|
||||
v_4d = k_4d.clone()
|
||||
o_ref = F.scaled_dot_product_attention(q_bf16, k_4d, v_4d, scale=scale)
|
||||
print(f"Reference: |o|={o_ref.float().abs().max().item():.6f} mean={o_ref.float().mean().item():.6f}")
|
||||
|
||||
# Decode kernel
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
|
||||
o_decode, lse_decode = fmha_mixed_fp8_decode_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
print(f"Decode: |o|={o_decode.float().abs().max().item():.6f} mean={o_decode.float().mean().item():.6f}")
|
||||
print(f"Decode vs Ref: cos={cosine(o_decode, o_ref):.6f}")
|
||||
|
||||
# Prefill kernel
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw
|
||||
o_prefill, lse_prefill = fmha_mixed_fp8_prefill_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
print(f"Prefill: |o|={o_prefill.float().abs().max().item():.6f} mean={o_prefill.float().mean().item():.6f}")
|
||||
print(f"Prefill vs Ref: cos={cosine(o_prefill, o_ref):.6f}")
|
||||
print(f"Prefill vs Decode: cos={cosine(o_prefill, o_decode):.6f}")
|
||||
|
||||
# Check for NaN
|
||||
has_nan = torch.isnan(o_prefill).any().item()
|
||||
print(f"Prefill NaN: {has_nan}")
|
||||
|
||||
# Per-head cosine
|
||||
o_d_h = o_decode.float().squeeze(0).squeeze(1) # (H, HD)
|
||||
o_p_h = o_prefill.float().squeeze(0).squeeze(1)
|
||||
if o_d_h.dim() == 3: o_d_h = o_d_h.squeeze(0)
|
||||
if o_p_h.dim() == 3: o_p_h = o_p_h.squeeze(0)
|
||||
per_head_cos = F.cosine_similarity(o_d_h, o_p_h, dim=-1)
|
||||
print(f"Per-head cos: min={per_head_cos.min().item():.6f} mean={per_head_cos.mean().item():.6f} max={per_head_cos.max().item():.6f}")
|
||||
|
||||
# Value comparison for head 0
|
||||
if not has_nan:
|
||||
d0 = o_decode[0, 0, 0, :8].float()
|
||||
p0 = o_prefill[0, 0, 0, :8].float()
|
||||
r0 = o_ref[0, 0, 0, :8].float()
|
||||
print(f"Decode[0,0,0,:8]: {d0.tolist()}")
|
||||
print(f"Prefill[0,0,0,:8]: {p0.tolist()}")
|
||||
print(f"Ref[0,0,0,:8]: {r0.tolist()}")
|
||||
print(f"Ratio decode/ref: {(d0 / (r0 + 1e-10)).tolist()}")
|
||||
print(f"Ratio prefill/ref: {(p0 / (r0 + 1e-10)).tolist()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
557
tests/unit/test_prefill_t2_debug.cu
Normal file
557
tests/unit/test_prefill_t2_debug.cu
Normal file
@@ -0,0 +1,557 @@
|
||||
/**
|
||||
* Debug test for B1 prefill kernel T>1 path.
|
||||
*
|
||||
* Tests T=2 N=128 step by step:
|
||||
* 1. Compute QK (noPE + RoPE) for 2 query rows
|
||||
* 2. Verify QK logits against CPU reference
|
||||
* 3. Compute softmax
|
||||
* 4. Compute PV and verify against CPU reference
|
||||
* 5. Full T=2 prefill vs CPU reference
|
||||
*/
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
// Include kernel headers
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh"
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
// ---- CPU reference functions ----
|
||||
|
||||
static void cpu_fp8_e4m3_quantize(const float* src, uint8_t* dst, float* scale,
|
||||
int rows, int cols) {
|
||||
for (int r = 0; r < rows; r++) {
|
||||
float amax = 0.0f;
|
||||
for (int c = 0; c < cols; c++) amax = fmaxf(amax, fabsf(src[r * cols + c]));
|
||||
float s = amax / 448.0f;
|
||||
if (s < 1e-12f) s = 1.0f;
|
||||
scale[r] = s;
|
||||
for (int c = 0; c < cols; c++) {
|
||||
float v = src[r * cols + c] / s;
|
||||
v = fmaxf(-448.0f, fminf(448.0f, v));
|
||||
__nv_fp8_e4m3 fp8; fp8.__x = 0;
|
||||
// Simplest quantize: round to FP8
|
||||
memcpy(&fp8, &v, 1); // This won't work, use proper conversion
|
||||
dst[r * cols + c] = 0; // placeholder
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static float fp8_to_f32(uint8_t b) {
|
||||
__nv_fp8_e4m3 v; v.__x = b;
|
||||
return (float)v;
|
||||
}
|
||||
|
||||
static bf16_t f32_to_bf16_host(float f) {
|
||||
uint32_t u; memcpy(&u, &f, 4);
|
||||
uint16_t h = (u + 0x8000) >> 16;
|
||||
return h;
|
||||
}
|
||||
|
||||
static float bf16_to_f32_host(bf16_t h) {
|
||||
uint32_t u = (uint32_t)h << 16;
|
||||
float f; memcpy(&f, &u, 4);
|
||||
return f;
|
||||
}
|
||||
|
||||
// ---- Minimal T=2 kernel that prints intermediate values ----
|
||||
|
||||
__global__ void prefill_t2_debug_kernel(
|
||||
const uint8_t* __restrict__ q_nope_fp8,
|
||||
const float* __restrict__ q_nope_scale,
|
||||
const bf16_t* __restrict__ q_rope_bf16,
|
||||
const uint8_t* __restrict__ k_nope_fp8,
|
||||
const float* __restrict__ k_nope_scale,
|
||||
const bf16_t* __restrict__ k_rope_bf16,
|
||||
int T, int N, int HD, int NOPE, int ROPE,
|
||||
float scale)
|
||||
{
|
||||
// Only one CTA for debug
|
||||
if (blockIdx.x > 0 || blockIdx.y > 0 || blockIdx.z > 0) return;
|
||||
|
||||
constexpr int SK_TILE = 128;
|
||||
constexpr int MMA_K_F8 = 32;
|
||||
constexpr int MMA_K_F16 = 16;
|
||||
constexpr int NKT_NOPE = 448 / MMA_K_F8; // 14
|
||||
constexpr int NKT_ROPE = 64 / MMA_K_F16; // 4
|
||||
constexpr int N_SUB = 512 / 16; // 32
|
||||
constexpr int NKT_PV = SK_TILE / MMA_K_F16; // 8
|
||||
constexpr int TILE_F8 = 128 * MMA_K_F8; // 4096
|
||||
constexpr int TILE_F16 = 128 * MMA_K_F16; // 2048
|
||||
constexpr int V_SUB_SZ = 16 * MMA_K_F16; // 256
|
||||
constexpr int TMEM_COLS = 512;
|
||||
constexpr int T_ACT = 2;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid >> 5;
|
||||
const int lane = tid & 31;
|
||||
const bool is_mma_warp = (wid == 4);
|
||||
|
||||
extern __shared__ __align__(128) char sbuf[];
|
||||
size_t off = 0;
|
||||
uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sQ16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sK16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
float* sLogits = (float*)(sbuf + off); off += T_ACT * SK_TILE * sizeof(float);
|
||||
float* sP = (float*)(sbuf + off); off += T_ACT * SK_TILE * sizeof(float);
|
||||
float* sOacc = (float*)(sbuf + off); off += T_ACT * HD * sizeof(float);
|
||||
float* sRunningMax = (float*)(sbuf + off); off += T_ACT * sizeof(float);
|
||||
float* sRunningSum = (float*)(sbuf + off); off += T_ACT * sizeof(float);
|
||||
|
||||
// TMEM alloc
|
||||
if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS);
|
||||
asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
const uint32_t idesc_f8_qk = make_idesc_f8_e4m3(128, 128);
|
||||
const uint32_t idesc_f16_qk = make_idesc(128, 128);
|
||||
const uint32_t idesc_pv = make_idesc(128, 16);
|
||||
|
||||
// Init accumulators
|
||||
for (int i = tid; i < T_ACT * HD; i += blockDim.x) sOacc[i] = 0.0f;
|
||||
for (int t = tid; t < T_ACT; t += blockDim.x) {
|
||||
sRunningMax[t] = -INFINITY;
|
||||
sRunningSum[t] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Single KV tile (N=128)
|
||||
const int kv_len = min(SK_TILE, N);
|
||||
|
||||
// ---- QK noPE: FP8 ----
|
||||
for (int kt = 0; kt < NKT_NOPE; kt++) {
|
||||
for (int i = tid; i < TILE_F8; i += blockDim.x) { sQ8[i] = 0; sK8[i] = 0; }
|
||||
__syncthreads();
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
for (int c = 0; c < MMA_K_F8; c++) {
|
||||
int d = kt * MMA_K_F8 + c;
|
||||
if (d < NOPE) sQ8[_pfill_cidx_f8(r, c)] = q_nope_fp8[r * NOPE + d];
|
||||
}
|
||||
}
|
||||
for (int i = tid; i < kv_len * MMA_K_F8; i += blockDim.x) {
|
||||
int r = i / MMA_K_F8, c = i % MMA_K_F8;
|
||||
int d = kt * MMA_K_F8 + c;
|
||||
if (d < NOPE) sK8[_pfill_cidx_f8(r, c)] = k_nope_fp8[r * NOPE + d];
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128);
|
||||
umma_ss_f8f6f4(tb, dq, dk, idesc_f8_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Read QK noPE
|
||||
prefill_read_qk_rows<SK_TILE>(tb, sLogits, T_ACT, kv_len);
|
||||
__syncthreads();
|
||||
|
||||
// Print QK noPE logits for rows 0,1 (first 8 values)
|
||||
if (tid == 0) {
|
||||
printf("QK noPE (row 0, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c]);
|
||||
printf("\n");
|
||||
printf("QK noPE (row 1, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c]);
|
||||
printf("\n");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Apply scales
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
float q_s = q_nope_scale[r];
|
||||
for (int c = 0; c < kv_len; c++) {
|
||||
sLogits[r * SK_TILE + c] *= q_s * k_nope_scale[c];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
printf("QK noPE scaled (row 0, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c]);
|
||||
printf("\n");
|
||||
printf("QK noPE scaled (row 1, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c]);
|
||||
printf("\n");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- QK RoPE: BF16 ----
|
||||
for (int kt = 0; kt < NKT_ROPE; kt++) {
|
||||
for (int i = tid; i < TILE_F16; i += blockDim.x) { sQ16[i] = 0; sK16[i] = 0; }
|
||||
__syncthreads();
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
for (int c = 0; c < MMA_K_F16; c++) {
|
||||
int d = kt * MMA_K_F16 + c;
|
||||
if (d < ROPE) sQ16[_pfill_cidx_bf16_128(r, c)] = q_rope_bf16[r * ROPE + d];
|
||||
}
|
||||
}
|
||||
for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) {
|
||||
int r = i / MMA_K_F16, c = i % MMA_K_F16;
|
||||
int d = kt * MMA_K_F16 + c;
|
||||
if (d < ROPE) sK16[_pfill_cidx_bf16_128(r, c)] = k_rope_bf16[(int64_t)r * ROPE + d];
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ16), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK16), 128);
|
||||
umma_ss_f16(tb, dq, dk, idesc_f16_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Add RoPE to noPE
|
||||
prefill_read_qk_rows<SK_TILE>(tb, sP, T_ACT, kv_len);
|
||||
__syncthreads();
|
||||
for (int i = tid; i < T_ACT * kv_len; i += blockDim.x) {
|
||||
sLogits[i] += sP[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
printf("QK total (row 0, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c] * scale);
|
||||
printf("\n");
|
||||
printf("QK total (row 1, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c] * scale);
|
||||
printf("\n");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- Softmax ----
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
float tile_max = -INFINITY;
|
||||
for (int c = 0; c < kv_len; c++)
|
||||
tile_max = fmaxf(tile_max, sLogits[r * SK_TILE + c] * scale);
|
||||
|
||||
float tile_sum = 0.0f;
|
||||
for (int c = 0; c < kv_len; c++) {
|
||||
float pv = expf(sLogits[r * SK_TILE + c] * scale - tile_max);
|
||||
sP[r * SK_TILE + c] = pv;
|
||||
tile_sum += pv;
|
||||
}
|
||||
for (int c = kv_len; c < SK_TILE; c++) sP[r * SK_TILE + c] = 0.0f;
|
||||
|
||||
float old_max = sRunningMax[r];
|
||||
float new_max = fmaxf(old_max, tile_max);
|
||||
float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;
|
||||
for (int d = 0; d < HD; d++) sOacc[r * HD + d] *= rescale_old;
|
||||
float rescale_new = expf(tile_max - new_max);
|
||||
sRunningSum[r] = sRunningSum[r] * rescale_old + tile_sum * rescale_new;
|
||||
sRunningMax[r] = new_max;
|
||||
|
||||
sLogits[r * SK_TILE] = rescale_new;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
printf("Softmax P (row 0, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.6f ", sP[0 * SK_TILE + c]);
|
||||
printf(" sum=%.6f\n", sRunningSum[0]);
|
||||
printf("Softmax P (row 1, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.6f ", sP[1 * SK_TILE + c]);
|
||||
printf(" sum=%.6f\n", sRunningSum[1]);
|
||||
printf("Rescale: row0=%.6f row1=%.6f\n", sLogits[0 * SK_TILE], sLogits[1 * SK_TILE]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- PV: per query row ----
|
||||
for (int qr = 0; qr < T_ACT; qr++) {
|
||||
float p_rescale = sLogits[qr * SK_TILE];
|
||||
|
||||
if (tid == 0) printf("PV for qr=%d: p_rescale=%.6f\n", qr, p_rescale);
|
||||
|
||||
for (int n_sub = 0; n_sub < N_SUB; n_sub++) {
|
||||
int d_base = n_sub * 16;
|
||||
for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
|
||||
const int col_start = pv_kt * MMA_K_F16;
|
||||
for (int i = tid; i < TILE_F16; i += blockDim.x) sPk[i] = 0;
|
||||
for (int i = tid; i < V_SUB_SZ; i += blockDim.x) sV[i] = 0;
|
||||
__syncthreads();
|
||||
|
||||
for (int c = tid; c < MMA_K_F16; c += blockDim.x) {
|
||||
int gc = col_start + c;
|
||||
sPk[_pfill_cidx_bf16_128(qr, c)] = f32_to_bf16(sP[qr * SK_TILE + gc]);
|
||||
}
|
||||
|
||||
for (int i = tid; i < 16 * MMA_K_F16; i += blockDim.x) {
|
||||
int dd = i / MMA_K_F16, kk = i % MMA_K_F16;
|
||||
int row = col_start + kk;
|
||||
int g_row = row;
|
||||
int d = d_base + dd;
|
||||
bf16_t vbits = 0;
|
||||
if (row < kv_len) {
|
||||
if (d < NOPE) {
|
||||
uint8_t b = k_nope_fp8[(int64_t)g_row * NOPE + d];
|
||||
float v = _prefill_fp8_to_f32(b) * k_nope_scale[g_row];
|
||||
vbits = f32_to_bf16(v);
|
||||
} else {
|
||||
vbits = k_rope_bf16[(int64_t)g_row * ROPE + (d - NOPE)];
|
||||
}
|
||||
}
|
||||
sV[_pfill_cidx_bf16_16(dd, kk)] = vbits;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
bool first = (pv_kt == 0);
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128);
|
||||
uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16);
|
||||
umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, !first);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// Read PV result for row qr
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
prefill_read_pv_all_subs<512, 32>(tb, qr, sOacc, p_rescale);
|
||||
__syncthreads();
|
||||
|
||||
// Print first few accumulated values
|
||||
if (tid == 0 && qr == 0) {
|
||||
printf("sOacc qr=0 (first 8): ");
|
||||
for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[0 * HD + d]);
|
||||
printf("\n");
|
||||
}
|
||||
if (tid == 0 && qr == 1) {
|
||||
printf("sOacc qr=1 (first 8): ");
|
||||
for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[1 * HD + d]);
|
||||
printf("\n");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Normalize and print final output
|
||||
if (tid == 0) {
|
||||
printf("sRunningSum: row0=%.6f row1=%.6f\n", sRunningSum[0], sRunningSum[1]);
|
||||
printf("sRunningMax: row0=%.6f row1=%.6f\n", sRunningMax[0], sRunningMax[1]);
|
||||
printf("Final output row0 (first 8): ");
|
||||
for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[0 * HD + d] / sRunningSum[0]);
|
||||
printf("\n");
|
||||
printf("Final output row1 (first 8): ");
|
||||
for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[1 * HD + d] / sRunningSum[1]);
|
||||
printf("\n");
|
||||
|
||||
// Check for NaN
|
||||
bool has_nan0 = false, has_nan1 = false;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
if (isnan(sOacc[0 * HD + d])) has_nan0 = true;
|
||||
if (isnan(sOacc[1 * HD + d])) has_nan1 = true;
|
||||
}
|
||||
printf("NaN check: row0=%s row1=%s\n", has_nan0 ? "YES" : "no", has_nan1 ? "YES" : "no");
|
||||
}
|
||||
|
||||
if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS);
|
||||
}
|
||||
|
||||
int main() {
|
||||
constexpr int T = 2;
|
||||
constexpr int N = 128;
|
||||
constexpr int HD = 512;
|
||||
constexpr int NOPE = 448;
|
||||
constexpr int ROPE = 64;
|
||||
const float scale = 1.0f / sqrtf((float)HD);
|
||||
|
||||
printf("=== Prefill T=2 Debug Test ===\n");
|
||||
printf("T=%d N=%d HD=%d NOPE=%d ROPE=%d scale=%.6f\n", T, N, HD, NOPE, ROPE, scale);
|
||||
|
||||
// Generate random data on CPU, then upload
|
||||
srand(42);
|
||||
|
||||
// Q: (T, HD) FP32 → quantize noPE to FP8, keep RoPE as BF16
|
||||
float* h_q = (float*)malloc(T * HD * sizeof(float));
|
||||
for (int i = 0; i < T * HD; i++) h_q[i] = (float)rand() / RAND_MAX * 0.5f - 0.25f;
|
||||
|
||||
// K: (N, HD) FP32 → quantize noPE to FP8, keep RoPE as BF16
|
||||
float* h_k = (float*)malloc(N * HD * sizeof(float));
|
||||
for (int i = 0; i < N * HD; i++) h_k[i] = (float)rand() / RAND_MAX * 0.5f - 0.25f;
|
||||
|
||||
// Q noPE FP8 quantization (per-row scale)
|
||||
uint8_t* h_q_nope_fp8 = (uint8_t*)malloc(T * NOPE);
|
||||
float* h_q_nope_scale = (float*)malloc(T * sizeof(float));
|
||||
for (int r = 0; r < T; r++) {
|
||||
float amax = 0.0f;
|
||||
for (int c = 0; c < NOPE; c++) amax = fmaxf(amax, fabsf(h_q[r * HD + c]));
|
||||
float s = amax / 448.0f;
|
||||
if (s < 1e-12f) s = 1.0f;
|
||||
h_q_nope_scale[r] = s;
|
||||
for (int c = 0; c < NOPE; c++) {
|
||||
float v = h_q[r * HD + c] / s;
|
||||
v = fmaxf(-448.0f, fminf(448.0f, v));
|
||||
__nv_fp8_e4m3 fp8 = __nv_fp8_e4m3(v);
|
||||
h_q_nope_fp8[r * NOPE + c] = fp8.__x;
|
||||
}
|
||||
}
|
||||
|
||||
// Q RoPE BF16
|
||||
bf16_t* h_q_rope_bf16 = (bf16_t*)malloc(T * ROPE * sizeof(bf16_t));
|
||||
for (int r = 0; r < T; r++)
|
||||
for (int c = 0; c < ROPE; c++)
|
||||
h_q_rope_bf16[r * ROPE + c] = f32_to_bf16_host(h_q[r * HD + NOPE + c]);
|
||||
|
||||
// K noPE FP8 quantization
|
||||
uint8_t* h_k_nope_fp8 = (uint8_t*)malloc(N * NOPE);
|
||||
float* h_k_nope_scale = (float*)malloc(N * sizeof(float));
|
||||
for (int r = 0; r < N; r++) {
|
||||
float amax = 0.0f;
|
||||
for (int c = 0; c < NOPE; c++) amax = fmaxf(amax, fabsf(h_k[r * HD + c]));
|
||||
float s = amax / 448.0f;
|
||||
if (s < 1e-12f) s = 1.0f;
|
||||
h_k_nope_scale[r] = s;
|
||||
for (int c = 0; c < NOPE; c++) {
|
||||
float v = h_k[r * HD + c] / s;
|
||||
v = fmaxf(-448.0f, fminf(448.0f, v));
|
||||
__nv_fp8_e4m3 fp8 = __nv_fp8_e4m3(v);
|
||||
h_k_nope_fp8[r * NOPE + c] = fp8.__x;
|
||||
}
|
||||
}
|
||||
|
||||
// K RoPE BF16
|
||||
bf16_t* h_k_rope_bf16 = (bf16_t*)malloc(N * ROPE * sizeof(bf16_t));
|
||||
for (int r = 0; r < N; r++)
|
||||
for (int c = 0; c < ROPE; c++)
|
||||
h_k_rope_bf16[r * ROPE + c] = f32_to_bf16_host(h_k[r * HD + NOPE + c]);
|
||||
|
||||
// Upload to GPU
|
||||
uint8_t *d_q_nope_fp8, *d_k_nope_fp8;
|
||||
float *d_q_nope_scale, *d_k_nope_scale;
|
||||
bf16_t *d_q_rope_bf16, *d_k_rope_bf16;
|
||||
|
||||
cudaMalloc(&d_q_nope_fp8, T * NOPE);
|
||||
cudaMalloc(&d_q_nope_scale, T * sizeof(float));
|
||||
cudaMalloc(&d_q_rope_bf16, T * ROPE * sizeof(bf16_t));
|
||||
cudaMalloc(&d_k_nope_fp8, N * NOPE);
|
||||
cudaMalloc(&d_k_nope_scale, N * sizeof(float));
|
||||
cudaMalloc(&d_k_rope_bf16, N * ROPE * sizeof(bf16_t));
|
||||
|
||||
cudaMemcpy(d_q_nope_fp8, h_q_nope_fp8, T * NOPE, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_q_nope_scale, h_q_nope_scale, T * sizeof(float), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_q_rope_bf16, h_q_rope_bf16, T * ROPE * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k_nope_fp8, h_k_nope_fp8, N * NOPE, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k_nope_scale, h_k_nope_scale, N * sizeof(float), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k_rope_bf16, h_k_rope_bf16, N * ROPE * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Compute CPU reference QK
|
||||
printf("\n=== CPU Reference QK ===\n");
|
||||
float ref_qk[2][128] = {};
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int c = 0; c < N; c++) {
|
||||
float dot = 0.0f;
|
||||
// noPE: FP8 dequant dot product
|
||||
for (int d = 0; d < NOPE; d++) {
|
||||
float qv = fp8_to_f32(h_q_nope_fp8[r * NOPE + d]) * h_q_nope_scale[r];
|
||||
float kv = fp8_to_f32(h_k_nope_fp8[c * NOPE + d]) * h_k_nope_scale[c];
|
||||
dot += qv * kv;
|
||||
}
|
||||
// RoPE: BF16 dot product
|
||||
for (int d = 0; d < ROPE; d++) {
|
||||
float qv = bf16_to_f32_host(h_q_rope_bf16[r * ROPE + d]);
|
||||
float kv = bf16_to_f32_host(h_k_rope_bf16[c * ROPE + d]);
|
||||
dot += qv * kv;
|
||||
}
|
||||
ref_qk[r][c] = dot * scale;
|
||||
}
|
||||
}
|
||||
printf("CPU ref QK (row 0, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", ref_qk[0][c]);
|
||||
printf("\n");
|
||||
printf("CPU ref QK (row 1, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", ref_qk[1][c]);
|
||||
printf("\n");
|
||||
|
||||
// Compute CPU reference softmax
|
||||
printf("\n=== CPU Reference Softmax + Attention ===\n");
|
||||
float ref_softmax[2][128] = {};
|
||||
for (int r = 0; r < T; r++) {
|
||||
float mx = ref_qk[r][0];
|
||||
for (int c = 1; c < N; c++) mx = fmaxf(mx, ref_qk[r][c]);
|
||||
float sm = 0.0f;
|
||||
for (int c = 0; c < N; c++) {
|
||||
ref_softmax[r][c] = expf(ref_qk[r][c] - mx);
|
||||
sm += ref_softmax[r][c];
|
||||
}
|
||||
for (int c = 0; c < N; c++) ref_softmax[r][c] /= sm;
|
||||
}
|
||||
printf("CPU ref softmax (row 0, first 8): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.6f ", ref_softmax[0][c]);
|
||||
printf("\n");
|
||||
|
||||
// Compute CPU reference attention output
|
||||
float ref_out[2][512] = {};
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int d = 0; d < HD; d++) {
|
||||
float val = 0.0f;
|
||||
for (int c = 0; c < N; c++) {
|
||||
float kv;
|
||||
if (d < NOPE) {
|
||||
kv = fp8_to_f32(h_k_nope_fp8[c * NOPE + d]) * h_k_nope_scale[c];
|
||||
} else {
|
||||
kv = bf16_to_f32_host(h_k_rope_bf16[c * ROPE + (d - NOPE)]);
|
||||
}
|
||||
val += ref_softmax[r][c] * kv;
|
||||
}
|
||||
ref_out[r][d] = val;
|
||||
}
|
||||
}
|
||||
printf("CPU ref output (row 0, first 8): ");
|
||||
for (int d = 0; d < 8; d++) printf("%.6f ", ref_out[0][d]);
|
||||
printf("\n");
|
||||
printf("CPU ref output (row 1, first 8): ");
|
||||
for (int d = 0; d < 8; d++) printf("%.6f ", ref_out[1][d]);
|
||||
printf("\n");
|
||||
|
||||
// Launch debug kernel
|
||||
printf("\n=== GPU Kernel Execution ===\n");
|
||||
int smem_size = 200 * 1024; // ~149KB needed, stay under 232KB limit
|
||||
cudaFuncSetAttribute(prefill_t2_debug_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
prefill_t2_debug_kernel<<<dim3(1,1,1), 192, smem_size>>>(
|
||||
d_q_nope_fp8, d_q_nope_scale, d_q_rope_bf16,
|
||||
d_k_nope_fp8, d_k_nope_scale, d_k_rope_bf16,
|
||||
T, N, HD, NOPE, ROPE, scale);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("Kernel launch FAILED: %s\n", cudaGetErrorString(err));
|
||||
} else {
|
||||
printf("Kernel completed successfully.\n");
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
cudaFree(d_q_nope_fp8); cudaFree(d_q_nope_scale); cudaFree(d_q_rope_bf16);
|
||||
cudaFree(d_k_nope_fp8); cudaFree(d_k_nope_scale); cudaFree(d_k_rope_bf16);
|
||||
free(h_q); free(h_k);
|
||||
free(h_q_nope_fp8); free(h_q_nope_scale); free(h_q_rope_bf16);
|
||||
free(h_k_nope_fp8); free(h_k_nope_scale); free(h_k_rope_bf16);
|
||||
|
||||
printf("\n=== Done ===\n");
|
||||
return 0;
|
||||
}
|
||||
348
tests/unit/test_production_fmha_layer.py
Normal file
348
tests/unit/test_production_fmha_layer.py
Normal file
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Production FMHA layer comparison test — real model weights, real pipeline.
|
||||
|
||||
Strategy:
|
||||
1. Run the full production pipeline (single_shot_inference.py forward_layer)
|
||||
for all prefill tokens through layers 0-4.
|
||||
2. On the LAST prefill token, for each layer, ALSO run the reference FMHA
|
||||
(dequantize KV to BF16, run PyTorch SDPA) on the SAME gathered KV
|
||||
that the production kernel saw.
|
||||
3. Compare raw FMHA output (before inverse RoPE, before output projection).
|
||||
|
||||
This isolates the FMHA kernel's accuracy from the rest of the pipeline.
|
||||
|
||||
Production values: HD=512, NOPE=448, ROPE=64, H=128, 61 layers, 8 GPUs.
|
||||
"""
|
||||
import os, sys, json, math, time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
CHECKPOINT_DIR = os.environ.get(
|
||||
"CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
NUM_GPUS = int(os.environ.get("NUM_GPUS", "8"))
|
||||
DEVICE = "cuda:0"
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
print("=" * 70)
|
||||
print("PRODUCTION FMHA LAYER COMPARISON TEST")
|
||||
print("=" * 70)
|
||||
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
n_layers = cfg["num_hidden_layers"]
|
||||
H = cfg["hidden_size"]
|
||||
hd = cfg["head_dim"]
|
||||
n_h = cfg["num_attention_heads"]
|
||||
rd = cfg.get("qk_rope_head_dim", 64)
|
||||
nope_dim = hd - rd
|
||||
cr = cfg.get("compress_ratios", [128] * n_layers)
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
|
||||
|
||||
from single_shot_inference import (
|
||||
load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
|
||||
rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache,
|
||||
KVCache, Compressor, Indexer, forward_layer, moe_forward,
|
||||
_load_moe_weights_stacked, _load_shared_expert_weights,
|
||||
_cache_layer_weights_no_experts,
|
||||
)
|
||||
from dsv4.layers.mhc import mHCLayer, mHCContext
|
||||
from dsv4.layers.router import Router
|
||||
from dsv4.layers.moe import Nvfp4MoE
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.ops.quantize import (
|
||||
rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
|
||||
print("Loading weights...")
|
||||
all_w = load_all_weights(CHECKPOINT_DIR)
|
||||
|
||||
TEST_LAYERS = 5
|
||||
o_groups = cfg.get("o_groups", 16)
|
||||
o_rank = cfg.get("o_lora_rank", 1024)
|
||||
n_ih = cfg.get("index_n_heads", 64)
|
||||
ihd = cfg.get("index_head_dim", 128)
|
||||
itk = cfg.get("index_topk", 1024)
|
||||
|
||||
rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
|
||||
for g in range(NUM_GPUS)}
|
||||
|
||||
# Build all production components (same as single_shot main())
|
||||
prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
|
||||
attn_norms, ffn_norms = {}, {}
|
||||
compressors, indexers, kv_caches = {}, {}, {}
|
||||
routers, moe_runners, se_runners = {}, {}, {}
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
pfx = f"model.layers.{li}.self_attn"
|
||||
mlp_pfx = f"model.layers.{li}.mlp"
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
|
||||
# Attention linears
|
||||
pl = {}
|
||||
pl['q_a'] = make_nvfp4_linear(H, 1536, dev, all_w, pfx, 'q_a_proj')
|
||||
pl['q_b'] = make_nvfp4_linear(1536, H * hd, dev, all_w, pfx, 'q_b_proj')
|
||||
pl['kv'] = make_nvfp4_linear(H, hd, dev, all_w, pfx, 'kv_proj')
|
||||
hpg = n_h // o_groups
|
||||
wo_a = Nvfp4GroupedLinear(n_local_groups=o_groups, heads_per_group=hpg,
|
||||
head_dim=hd, o_lora_rank=o_rank, max_num_tokens=8192, device=dev)
|
||||
oa_w, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
|
||||
if oa_w is not None and oa_ws is not None:
|
||||
wo_a.load_nvfp4_weight(oa_w.to(dev), oa_ws.to(dev),
|
||||
oa_ws2.to(dev) if oa_ws2 is not None else None,
|
||||
oa_isc.to(dev) if oa_isc is not None else None)
|
||||
else:
|
||||
oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_bf is not None:
|
||||
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
|
||||
pl['o_a'] = wo_a; wo_a._use_runtime_gsa = True
|
||||
pl['o_b'] = make_nvfp4_linear(o_groups * o_rank, H, dev, all_w, pfx, 'o_b_proj')
|
||||
prod_lins[li] = pl
|
||||
|
||||
# mHC
|
||||
for tag, blocks, fn_s, base_s, scale_s in [
|
||||
("attn", attn_mhcs, f"model.layers.{li}.attn_hc.fn",
|
||||
f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"),
|
||||
("ffn", ffn_mhcs, f"model.layers.{li}.ffn_hc.fn",
|
||||
f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"),
|
||||
]:
|
||||
fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s)
|
||||
if fn is not None and base is not None and scale is not None:
|
||||
m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev)
|
||||
n = 4
|
||||
m.load_weights(
|
||||
W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
|
||||
W_comb=fn[2*n:].to(dev, torch.float32),
|
||||
S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
|
||||
S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
|
||||
S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
|
||||
alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item())
|
||||
blocks[li] = m
|
||||
|
||||
an_k = f"model.layers.{li}.input_layernorm.weight"
|
||||
if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32)
|
||||
fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
|
||||
if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)
|
||||
|
||||
max_comp = (8192 + ratio - 1) // ratio if ratio > 0 else 0
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp,
|
||||
device=dev, indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk, rope_dim=rd)
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
# Router
|
||||
is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{mlp_pfx}.gate.tid2eid" in all_w)
|
||||
router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
|
||||
top_k=cfg.get("num_experts_per_tok", 6),
|
||||
routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
|
||||
mode="hash" if is_hash else "dense",
|
||||
vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
|
||||
if is_hash:
|
||||
router.load_weights(hash_lut=all_w[f"{mlp_pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{mlp_pfx}.gate.e_score_correction_bias")
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, mlp_pfx, 'gate')
|
||||
E = cfg["n_routed_experts"]
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)]
|
||||
gate_lin.sf = [gate_ws.to(dev)]
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
else:
|
||||
# BF16 gate weight — quantize to NVFP4
|
||||
gw = all_w.get(f"{mlp_pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
top_k=cfg.get("num_experts_per_tok", 6), device=dev)
|
||||
moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)); moe.set_fused_swiglu(True)
|
||||
_load_moe_weights_stacked(all_w, li, mlp_pfx, dev, moe, cfg)
|
||||
moe._ensure_stacked(); moe._use_runtime_gsa = True; moe_runners[li] = moe
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
|
||||
se.set_fused_swiglu(True)
|
||||
_load_shared_expert_weights(all_w, li, mlp_pfx, dev, se, cfg)
|
||||
se._ensure_initialized(); se._use_runtime_gsa = True; se_runners[li] = se
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
pfx = f"model.layers.{li}.self_attn.compressor"
|
||||
dev = f"cuda:{li % NUM_GPUS}"
|
||||
if li in compressors: compressors[li].load(all_w, pfx, dev=dev)
|
||||
if li in indexers: indexers[li].load(all_w, f"{pfx}.indexer", dev=dev)
|
||||
print("Components built")
|
||||
|
||||
# Embedding + tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
USER_TOKEN, ASSISTANT_TOKEN, THINK_START = 128803, 128804, 128821
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\nThe capital of France is', add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN); input_ids.append(THINK_START)
|
||||
print(f"Input: {len(input_ids)} tokens")
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
|
||||
devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
|
||||
layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
|
||||
del all_w; import gc; gc.collect()
|
||||
for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# ================================================================
|
||||
# PHASE 1: Run full production pipeline to populate KV caches
|
||||
# ================================================================
|
||||
print(f"\nPhase 1: Populating KV caches...")
|
||||
for pi, tid_val in enumerate(input_ids):
|
||||
t1 = time.time()
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
|
||||
pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
|
||||
tid32 = torch.tensor([tid_val], dtype=torch.int32, device=DEVICE)
|
||||
|
||||
X = mHCLayer.init_state(embed(tid))
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
torch.cuda.set_device(gpu)
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
|
||||
if pi % 5 == 0:
|
||||
print(f" Token {pi}/{len(input_ids)}: {time.time()-t1:.2f}s", flush=True)
|
||||
|
||||
# ================================================================
|
||||
# PHASE 2: For each layer, gather KV, run production FMHA, compare vs SDPA
|
||||
# ================================================================
|
||||
print(f"\nPhase 2: FMHA comparison per layer...")
|
||||
results = {}
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
k_cache = kv_caches[li]
|
||||
|
||||
# Gather KV in mixed format (same as production path)
|
||||
if k_cache.n_comp > 0:
|
||||
if ratio > 4:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = k_cache.gather_mixed_all()
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = k_cache.gather_mixed_swa_only()
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = k_cache.gather_mixed_swa_only()
|
||||
|
||||
seq_len = kv_nope_scale.shape[0]
|
||||
if seq_len == 0:
|
||||
print(f" L{li}: SKIPPED (seq_len=0)")
|
||||
continue
|
||||
|
||||
# Generate a test Q (random, on this GPU)
|
||||
q_bf16 = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device=dev) * 0.5
|
||||
|
||||
# 1. Run production mixed FP8 FMHA
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
|
||||
scale_val = 1.0 / math.sqrt(hd)
|
||||
try:
|
||||
o_prod, lse_prod = fmha_mixed_fp8_decode_raw(
|
||||
q_bf16, kv_nope_fp8, kv_nope_scale, kv_rope_bf16, scale_val, rope_dim=rd)
|
||||
except Exception as e:
|
||||
print(f" L{li}: PROD FMHA FAILED: {e}")
|
||||
results[li] = {'cos': -1.0, 'error': str(e)}
|
||||
continue
|
||||
|
||||
# 2. Reference: dequantize KV, run SDPA
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd)
|
||||
k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd)
|
||||
v_4d = k_4d.clone()
|
||||
o_ref = F.scaled_dot_product_attention(q_bf16, k_4d, v_4d, scale=scale_val) # (1, H, 1, hd)
|
||||
|
||||
# 3. Compare
|
||||
cos_val = cosine(o_prod, o_ref)
|
||||
mag_prod = o_prod.float().abs().max().item()
|
||||
mag_ref = o_ref.float().abs().max().item()
|
||||
|
||||
# Per-head cosine
|
||||
o_prod_h = o_prod.float().squeeze(2) # (1, H, hd) → (H, hd) after squeeze
|
||||
o_ref_h = o_ref.float().squeeze(2)
|
||||
if o_prod_h.dim() == 3: o_prod_h = o_prod_h.squeeze(0)
|
||||
if o_ref_h.dim() == 3: o_ref_h = o_ref_h.squeeze(0)
|
||||
per_head_cos = F.cosine_similarity(o_prod_h, o_ref_h, dim=-1)
|
||||
min_head = per_head_cos.min().item()
|
||||
mean_head = per_head_cos.mean().item()
|
||||
|
||||
results[li] = {
|
||||
'cos': cos_val, 'mag_prod': mag_prod, 'mag_ref': mag_ref,
|
||||
'seq_len': seq_len, 'ratio': ratio,
|
||||
'min_head_cos': min_head, 'mean_head_cos': mean_head,
|
||||
}
|
||||
status = "PASS" if cos_val >= 0.999 else "FAIL"
|
||||
print(f" L{li}: {status} cos={cos_val:.6f} min_head={min_head:.6f} mean_head={mean_head:.6f} "
|
||||
f"|prod|={mag_prod:.4f} |ref|={mag_ref:.4f} seq_len={seq_len} ratio={ratio}")
|
||||
|
||||
if cos_val < 0.999:
|
||||
worst = per_head_cos.argsort()[:5]
|
||||
print(f" Worst heads: {worst.tolist()} cos={[f'{c:.4f}' for c in per_head_cos[worst].tolist()]}")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 70)
|
||||
print("SUMMARY")
|
||||
print("=" * 70)
|
||||
all_pass = True
|
||||
for li in sorted(results.keys()):
|
||||
r = results[li]
|
||||
c = r.get('cos', -1.0)
|
||||
status = "PASS" if c >= 0.999 else "FAIL"
|
||||
if c < 0.999: all_pass = False
|
||||
print(f" L{li}: {status} cos={c:.6f} seq={r.get('seq_len','?')} ratio={r.get('ratio','?')}")
|
||||
|
||||
print()
|
||||
if all_pass:
|
||||
print("ALL PASSED (cos >= 0.999)")
|
||||
else:
|
||||
print("SOME FAILED — see per-layer output above")
|
||||
return 0 if all_pass else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user