P5 integration + B3 q_a_norm fused + gsa scalar fix
P5: Wire up fused mHC pre_block + RMSNorm + NVFP4 quantize kernel - Replaces: pre_block bmm + rmsnorm (4+ launches) + quantize (2 launches) - With: 2 kernel launches (mhc_rmsnorm_amax_gsa + mhc_rmsnorm_quantize_nvfp4) - Both attn and ffn mHC paths now use P5 fused kernel - Savings: ~5 launches/site × 2 sites × 61 layers = 610 launches/token B3: Fused rmsnorm+quant for q_a_norm → q_b path - q_a output → rmsnorm_quantize_nvfp4 → QuantizedActivation → q_b.run_from_quantized - Eliminates BF16 round-trip between q_a_norm and q_b GEMM - Saves: ~6 kernel launches per layer (rmsnorm 4+ + quantize 2 vs fused 2) gsa scalar fix in Nvfp4Linear.run_from_quantized: - CuTeDSL NVFP4 GEMM expects global_scale_a as per-expert scalar (shape (1,)) - Per-row gsa from fused kernels must be reduced to scalar (max) for M>1 - For M=1 decode: already scalar, no reduction needed - Fixes potential correctness issue at prefill (M>1) when using fused paths Cleanup: Remove --ab-compare flag and A/B comparison code (replaced by P5)
This commit is contained in:
103
CORRRECTNESS_BACKLOG.md
Normal file
103
CORRRECTNESS_BACKLOG.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# WE ARE BACKLOGGING THIS ISSUE AND WILL REVIST IT AFTER WE FINISH THE OTHER ITEMS IN THE FINAL STRETCH
|
||||
|
||||
**Context:** post-cleanup `single_shot_inference.py` compiles, Paris is top-1 at step 0, output is coherent, then degenerates into repeated junk ("capital ..."). Defaults in effect: `temperature=0.6`, `repetition_penalty=1.1`, `--warmup-gsa` **off**, fused rmsnorm+quant **on**.
|
||||
|
||||
**Read this first — what the symptom rules in/out.** The model is *sampling* (temp 0.6) with a penalty (1.1) and still loops a near-constant token. That is not "greedy with no penalty." It means either (a) the model finished its turn, emitted a stop token we don't catch, and the LM head is now peaked on degenerate filler, or (b) a decode-state correctness bug whose error compounds over steps. (a) is far more likely and is nearly free to test, so **do Part A in order, cheapest first, and do NOT touch kernels until A1–A2 are ruled out.** The math is right (Paris top-1); don't go hunting kernel ghosts before eliminating the decoding-config causes.
|
||||
|
||||
Code is the source of truth. For any precision change, validate per-layer cosine against `dsv4/reference/` before trusting end-to-end output.
|
||||
|
||||
---
|
||||
|
||||
# PART A — Decode repetition (correctness). Do in order.
|
||||
|
||||
## A1 — Stop set (HIGH priority, ~zero effort, most likely the whole bug)
|
||||
`single_shot_inference.py:1571` stops only on `next_id == tokenizer.eos_token_id`. DSV4 is a reasoning/chat model; an assistant turn ends with a special token (`<|end_of_sentence|>`, and the turn structure also uses USER=128803 / ASSISTANT=128804 / `</think>`=128822). If the model's turn-end token isn't `eos_token_id`, decode never stops and degenerates exactly as observed.
|
||||
|
||||
**Diagnose (no code change):**
|
||||
1. Print `tokenizer.eos_token_id`, `tokenizer.eos_token`, and `tokenizer.special_tokens_map` once at startup.
|
||||
2. In the decode log, find the token id emitted at the moment output "should have finished." Decode it: `tokenizer.decode([id])`. If it's a special/end token not in the stop set, that's the bug.
|
||||
|
||||
**Fix:** build an explicit stop set and break on membership, e.g.:
|
||||
```python
|
||||
STOP_IDS = {tokenizer.eos_token_id}
|
||||
for t in ("<|end_of_sentence|>",): # add the real turn-end token name(s) for this checkpoint
|
||||
tid = tokenizer.convert_tokens_to_ids(t)
|
||||
if tid is not None and tid >= 0: STOP_IDS.add(tid)
|
||||
STOP_IDS.add(USER_TOKEN) # model trying to open a new user turn = it's done
|
||||
# ... in the loop:
|
||||
if next_id in STOP_IDS:
|
||||
print(f" STOP ({next_id}) at step {step}", flush=True); break
|
||||
```
|
||||
**A/B:** if adding the stop set ends generation cleanly, the "bug" was never in the kernels. Stop here.
|
||||
|
||||
## A2 — Sampler / penalty sanity (MEDIUM, cheap, diagnostic)
|
||||
A 1.1 penalty over `recent_tokens=all_tokens[-256:]` should at least *perturb* a single-token loop. If it loops the exact same id anyway, suspect the penalty isn't reaching the kernel or is mis-indexed.
|
||||
- **Test:** rerun with `--repetition-penalty 1.5`. If the loop is *unchanged*, the penalty path in `dsv4/model/sampler.py` (CUDASampler) is broken — verify `recent_tokens` is actually passed to and applied by the kernel, and that it indexes the logit vector correctly. If raising it *does* break the loop, the sampler is fine and this was a stop-token/decoding-hygiene issue (see A1).
|
||||
- Also confirm `recent_tokens` includes the *prompt* tokens, not just generated ones, or the model can loop on a prompt word ("capital") penalty-free.
|
||||
|
||||
## A3 — Compressed/SWA visible-range parity (MEDIUM, architectural — verify vs reference)
|
||||
During decode the query attends to `[top-k compressed entries] ++ [SWA window]` (`forward_attention`, "5. Gather KV"). Two things to verify against the HF/`dsv4/reference` oracle, because an off-by-one here causes subtle wrongness that **compounds across decode steps** (coherent early, degenerate late — matches the symptom):
|
||||
1. **Which compressed blocks are visible to a decode query.** Causality: a query must see only compressed blocks strictly *preceding* its own current (incomplete) block, never its own or future blocks. Confirm the set of compressed indices fed to the FMHA at step `s` matches the reference exactly.
|
||||
2. **SWA / compressed overlap.** The most recent tokens are in the SWA ring (`ws=128`) *and* may also be inside the newest complete compressed block → the query can attend to both representations of the same tokens. This may be intended (SWA refines what compression blurred, and the model was trained with it) — but it must match how the reference gathers, or the recent-context weighting drifts. Diff the gathered key set (indices + count) against the reference for a fixed prompt at several decode positions.
|
||||
|
||||
> Note: the residual `|X|` growing to ~244–372 is **expected** (the paper notes 300–500; your own `KVCache` docstring says the same). It is not by itself the bug. See B5 only if A1–A3 don't resolve it.
|
||||
|
||||
## A4 — (verify, likely fine) Inverse RoPE
|
||||
`forward_attention:783` applies `_apply_rope(attn_out, positions, ..., inverse=True)` at the **query** position, which is what converts the absolute positions carried by the summed KV into relative ones. This looked correct. Just confirm `inverse=True` negates the rotation angle (applies RoPE(−t)) and uses the query `positions` (not `comp_pos`). Only revisit if A1–A3 are clean and degeneration persists.
|
||||
|
||||
---
|
||||
|
||||
|
||||
When we were doing this we tried ALL OF THE ABOVE!!!!! and this is where we left off:
|
||||
|
||||
```
|
||||
Good, the greedy test (temp=0, 64 tokens) is prefilling. Let me wait for it to hit decode.Greedy with temp=0 and rep_penalty=1.0. The model is looping on "capital of France is" — just re-echoing the prompt. It never generates "Paris" or any real answer. This is NOT a sampler issue. The model's logits are fundamentally wrong.
|
||||
|
||||
The sampler and penalty are cleared. The degeneration is in the model/pipeline itself. Let me look at the step 0 output more carefully.Key finding: " Paris" (111) has logit 10.62, but the TOP-1 token at step 0 is "," with logit 17.4. That's a 7-point gap. The model should have " Paris" much higher — in a correctly working model, "Paris" after "The capital of France is" should be 25+ logit, not 10.
|
||||
|
||||
This is a pipeline correctness issue. The model is not producing the right logits. The step-0 degeneration confirms it's not about repetition or sampling — the logits themselves are wrong from the very first decode step.
|
||||
|
||||
Given the user's instruction to move to Part B if the sampler/prompt diagnostics don't resolve it, let me focus on the most impactful pipeline fix. The residual growing to 472-732 (|X|) is suspicious. Let me look at the P4 fused rmsnorm — specifically the `run_from_quantized` path vs the regular `run` path, since memory notes the P4 integration caused output quality regression.
|
||||
|
||||
The previous test with `--no-fused-rmsnorm` ALSO produced degenerate output, so P4 isn't the sole cause. But let me check the KEY difference: in the `run_from_quantized` path, `gsa` is passed per-row from the QuantizedActivation, while in the regular path, gsa is a scalar. The MEMORY.md notes this exact issue:
|
||||
|
||||
> **ROOT CAUSE HYPOTHESIS**: run_from_quantized passes per-row gsa to GEMM, but CuTeDSL NVFP4 GEMM expects scalar global_scale_aSo `run_from_quantized` passes per-row gsa (shape (M,)) as `global_scale_a`. Let me check what the GEMM expects.The `global_scale_a` is passed as a tensor to the CuTe kernel. For the single-expert case (Nvfp4Linear), `expert_offsets` has shape (1,) and `global_scale_a` would be the gsa. Let me check how the kernel uses it.There it is. The kernel loads `global_scale_a` using `global_scale_a.iterator + expert_idx`. This is a **per-expert** scalar — it indexes by expert_idx, not by row. For a single-expert case (Nvfp4Linear with expert_offsets of shape (1,)), expert_idx is always 0, so it reads `global_scale_a[0]` — a single scalar.
|
||||
|
||||
But `run_from_quantized` passes per-row gsa as `global_scale_a` with shape `(M,)`. The kernel only reads `global_scale_a[0]` (the first row's gsa), ignoring the rest. This means for decode (M=1), it works fine (one row, one gsa). But the issue is that the fused rmsnorm kernel computes gsa per row, and the GEMM only uses row 0's gsa.
|
||||
|
||||
Wait, for decode, M=1 always. So per-row vs scalar doesn't matter for the decode path. The issue would be in prefill where M>1. Let me re-examine...
|
||||
|
||||
Actually, the prefill is done one token at a time (the code explicitly says "one token at a time (decode-style; TODO: batched prefill)"). So M=1 throughout. The per-row gsa should be fine since there's only one row.
|
||||
|
||||
Let me look deeper at the actual numerical pipeline. The real question is: why are the logits wrong? Let me add diagnostic code to compare our pipeline output with a known reference at a specific layer.Let me run with --verbose 2 to get per-layer diagnostics.Still building MoE layers. Need to wait longer.**CRITICAL FINDING**: L1 FMHA cosine is 0.679. That's terrible. L2 FMHA cosine is 0.956 — also bad. The q_a projection is fine (0.999998), but the FMHA output is wrong. This is the root cause of the degeneration.
|
||||
|
||||
Let me get the L0 data too.L0 FMHA cos=0.955, L1 FMHA cos=0.679, L2 FMHA cos=0.956. L1 is particularly bad. This is the attention pipeline producing incorrect results.
|
||||
|
||||
The reference comparison is done against PyTorch SDPA on the same gathered KV. So the FMHA kernel itself is wrong for these attention patterns. Let me understand why L1 is 0.679 while L0 and L2 are ~0.95.
|
||||
|
||||
L0 and L2 are HCA layers (ratio=128), L1 is also HCA (the first two layers are HCA for Pro). Wait, let me check the layer schedule.So L0=HCA(128), L1=HCA(128), L2=CSA(4), L3=HCA(128), L4=CSA(4). L1 (HCA) has terrible FMHA cosine. The issue is in the FMHA when dealing with the compressed KV patterns.
|
||||
|
||||
Let me check what seq_len the FMHA sees at each layer. The verbose=2 diagnostic should show this in the forward_attention function.The cos=0.679 at L1 is devastating. Even though the magnitudes are similar (2.78 vs 2.94), the direction is wrong. This compounds through 61 layers.
|
||||
|
||||
Let me add more diagnostic output to the FMHA — specifically, print the shapes and seq_len being passed.Wait — `n_comp=0` is hardcoded! The production FMHA call always sets `n_comp=0`. But the attention pattern for HCA/CSA involves compressed entries + SWA window. The `n_comp` parameter is crucial for the sink bias and possibly for the FMHA dispatch logic.
|
||||
|
||||
Let me check what `n_comp` does in the production attention.The `n_comp` parameter is described as "reserved for future kernel integration" and it's always 0. The sink_bias is also "reserved for future kernel integration". So the production FMHA is running WITHOUT the D5c sink bias that the paper describes as essential for DSV4 attention.
|
||||
|
||||
But more critically, the cos=0.679 at L1 suggests the FMHA itself is producing wrong results even without sink bias. Let me look at the reference comparison code — it compares the production FMHA output against PyTorch SDPA on the SAME gathered KV. So the KV is the same, but the kernel produces a different result.
|
||||
|
||||
Wait — the reference comparison is in the forward_attention function at `VERBOSE >= 2`. Let me look at that code more carefully.The reference uses `v_exp = k_exp.clone()` — this means V = K. In DSV4, the KV cache stores combined K/V where K and V are the same tensor (MQA). That's correct for the reference comparison.
|
||||
|
||||
But the production FMHA also uses V = K:
|
||||
```python
|
||||
v = k # Same tensor, no copy
|
||||
```
|
||||
|
||||
So the reference and production should agree. The cos=0.679 means the kernel is numerically wrong for this specific attention pattern.
|
||||
|
||||
Let me check what seq_len the FMHA is seeing. At L1 during prefill of the first token, the KV cache would have very few entries. Let me add seq_len printing.
|
||||
```
|
||||
|
||||
SO SINCE WE HAD TO TOUCH FMHA ANYWAY IN PART B. WE DID THAT FIRST AND TRIED TO GET THAT CORRECT BEFORE WE REVISTED THIS ISSUE!!!
|
||||
|
||||
## Suggested sequence (we shouldve already tried all of these)
|
||||
A1 (stop set) → A2 (penalty test) → if still broken: A3 (visible-range parity vs reference) → A4 (inverse-RoPE check). Then
|
||||
@@ -1,53 +1,6 @@
|
||||
# DSV4 Audit — Decode Repetition + Precision / Tensor-Core Plan
|
||||
|
||||
**Context:** post-cleanup `single_shot_inference.py` compiles, Paris is top-1 at step 0, output is coherent, then degenerates into repeated junk ("capital ..."). Defaults in effect: `temperature=0.6`, `repetition_penalty=1.1`, `--warmup-gsa` **off**, fused rmsnorm+quant **on**.
|
||||
|
||||
**Read this first — what the symptom rules in/out.** The model is *sampling* (temp 0.6) with a penalty (1.1) and still loops a near-constant token. That is not "greedy with no penalty." It means either (a) the model finished its turn, emitted a stop token we don't catch, and the LM head is now peaked on degenerate filler, or (b) a decode-state correctness bug whose error compounds over steps. (a) is far more likely and is nearly free to test, so **do Part A in order, cheapest first, and do NOT touch kernels until A1–A2 are ruled out.** The math is right (Paris top-1); don't go hunting kernel ghosts before eliminating the decoding-config causes.
|
||||
|
||||
Code is the source of truth. For any precision change, validate per-layer cosine against `dsv4/reference/` before trusting end-to-end output.
|
||||
|
||||
---
|
||||
|
||||
# PART A — Decode repetition (correctness). Do in order.
|
||||
|
||||
## A1 — Stop set (HIGH priority, ~zero effort, most likely the whole bug)
|
||||
`single_shot_inference.py:1571` stops only on `next_id == tokenizer.eos_token_id`. DSV4 is a reasoning/chat model; an assistant turn ends with a special token (`<|end_of_sentence|>`, and the turn structure also uses USER=128803 / ASSISTANT=128804 / `</think>`=128822). If the model's turn-end token isn't `eos_token_id`, decode never stops and degenerates exactly as observed.
|
||||
|
||||
**Diagnose (no code change):**
|
||||
1. Print `tokenizer.eos_token_id`, `tokenizer.eos_token`, and `tokenizer.special_tokens_map` once at startup.
|
||||
2. In the decode log, find the token id emitted at the moment output "should have finished." Decode it: `tokenizer.decode([id])`. If it's a special/end token not in the stop set, that's the bug.
|
||||
|
||||
**Fix:** build an explicit stop set and break on membership, e.g.:
|
||||
```python
|
||||
STOP_IDS = {tokenizer.eos_token_id}
|
||||
for t in ("<|end_of_sentence|>",): # add the real turn-end token name(s) for this checkpoint
|
||||
tid = tokenizer.convert_tokens_to_ids(t)
|
||||
if tid is not None and tid >= 0: STOP_IDS.add(tid)
|
||||
STOP_IDS.add(USER_TOKEN) # model trying to open a new user turn = it's done
|
||||
# ... in the loop:
|
||||
if next_id in STOP_IDS:
|
||||
print(f" STOP ({next_id}) at step {step}", flush=True); break
|
||||
```
|
||||
**A/B:** if adding the stop set ends generation cleanly, the "bug" was never in the kernels. Stop here.
|
||||
|
||||
## A2 — Sampler / penalty sanity (MEDIUM, cheap, diagnostic)
|
||||
A 1.1 penalty over `recent_tokens=all_tokens[-256:]` should at least *perturb* a single-token loop. If it loops the exact same id anyway, suspect the penalty isn't reaching the kernel or is mis-indexed.
|
||||
- **Test:** rerun with `--repetition-penalty 1.5`. If the loop is *unchanged*, the penalty path in `dsv4/model/sampler.py` (CUDASampler) is broken — verify `recent_tokens` is actually passed to and applied by the kernel, and that it indexes the logit vector correctly. If raising it *does* break the loop, the sampler is fine and this was a stop-token/decoding-hygiene issue (see A1).
|
||||
- Also confirm `recent_tokens` includes the *prompt* tokens, not just generated ones, or the model can loop on a prompt word ("capital") penalty-free.
|
||||
|
||||
## A3 — Compressed/SWA visible-range parity (MEDIUM, architectural — verify vs reference)
|
||||
During decode the query attends to `[top-k compressed entries] ++ [SWA window]` (`forward_attention`, "5. Gather KV"). Two things to verify against the HF/`dsv4/reference` oracle, because an off-by-one here causes subtle wrongness that **compounds across decode steps** (coherent early, degenerate late — matches the symptom):
|
||||
1. **Which compressed blocks are visible to a decode query.** Causality: a query must see only compressed blocks strictly *preceding* its own current (incomplete) block, never its own or future blocks. Confirm the set of compressed indices fed to the FMHA at step `s` matches the reference exactly.
|
||||
2. **SWA / compressed overlap.** The most recent tokens are in the SWA ring (`ws=128`) *and* may also be inside the newest complete compressed block → the query can attend to both representations of the same tokens. This may be intended (SWA refines what compression blurred, and the model was trained with it) — but it must match how the reference gathers, or the recent-context weighting drifts. Diff the gathered key set (indices + count) against the reference for a fixed prompt at several decode positions.
|
||||
|
||||
> Note: the residual `|X|` growing to ~244–372 is **expected** (the paper notes 300–500; your own `KVCache` docstring says the same). It is not by itself the bug. See B5 only if A1–A3 don't resolve it.
|
||||
|
||||
## A4 — (verify, likely fine) Inverse RoPE
|
||||
`forward_attention:783` applies `_apply_rope(attn_out, positions, ..., inverse=True)` at the **query** position, which is what converts the absolute positions carried by the summed KV into relative ones. This looked correct. Just confirm `inverse=True` negates the rotation angle (applies RoPE(−t)) and uses the query `positions` (not `comp_pos`). Only revisit if A1–A3 are clean and degeneration persists.
|
||||
|
||||
---
|
||||
|
||||
# PART B — Precision / NVFP4 / tensor-core (do AFTER Part A; never optimize a broken decode)
|
||||
# PART B — Precision / NVFP4 / tensor-core (WE ARE SKIPPING PART A FOR RIGHT NOW AND WILL REVISIT IT)
|
||||
|
||||
Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 only where required. Validate each change with per-layer cosine vs `dsv4/reference` before trusting it.
|
||||
|
||||
@@ -86,18 +39,19 @@ The mHC residual `X` is BF16 at `|X|≈300`, where BF16 ULP ≈ 2. This is proba
|
||||
|
||||
# PART C — Guardrails for the agent
|
||||
|
||||
1. **Order matters:** finish Part A (correctness) before any Part B (perf). A faster wrong decode is still wrong.
|
||||
2. **Every precision change is gated by a per-layer cosine vs `dsv4/reference`** for a fixed prompt, *before* judging end-to-end output. Record the cos in the commit message.
|
||||
3. **One change per commit**, with the A/B result. If a change drops end-to-end coherence, the per-layer cos tells you which layer/op regressed.
|
||||
4. **Don't re-create the dead indexer.** B2 is a new FP8/FP4 kernel; the `dsv4/kernels/indexer/*.cu` files are archived/dead — confirm with `helpers/import_closure.py` before reusing anything there.
|
||||
5. **Re-validate the stop fix (A1) on a long generation** (≥512 tokens) and a multi-turn prompt, not just "capital of France" — the turn-end token differs by prompt type.
|
||||
|
||||
## Suggested sequence
|
||||
A1 (stop set) → A2 (penalty test) → if still broken: A3 (visible-range parity vs reference) → A4 (inverse-RoPE check). Then B1 (FP8 FMHA) → B2 (FP8/FP4 indexer) → B3 (fused norm+quant) → B4 (cast sweep) → B5 only if needed.
|
||||
B1 (FP8 FMHA) → B2 (FP8/FP4 indexer) → B3 (fused norm+quant) → B4 (cast sweep) → B5 only if needed.
|
||||
|
||||
|
||||
---
|
||||
|
||||
# PART D — Dangling TODOS
|
||||
|
||||
- It is mentioned in `/home/openclaw/dev/nvfp4-megamoe-kernel/docs/PERFORMANCE_AUDIT.md` that P5 (Fuse mHC pre_block + RMSNorm into a single op) is done but kernel, pending integration. Please wire that up if you have not done so already
|
||||
- It is mentioned in `/home/openclaw/dev/nvfp4-megamoe-kernel/docs/PERFORMANCE_AUDIT.md` that P5 (Fuse mHC pre_block + RMSNorm into a single op) is done but kernel, pending integration. Please wire that up if you have not done so already
|
||||
|
||||
- Batched Prefill. Did we ever do this???
|
||||
@@ -242,13 +242,23 @@ class Nvfp4Linear:
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales — use the per-row gsa from the fused kernel
|
||||
# Reshape to (1,) if scalar, or use per-row (M,) broadcast
|
||||
gsa = quant.gsa[:1].reshape(1) if quant.gsa.shape[0] == 1 else quant.gsa[:num_tokens]
|
||||
if gsa.shape != self._gsa_buf.shape:
|
||||
self._gsa_buf = gsa.contiguous()
|
||||
# Global scales — the CuTeDSL NVFP4 GEMM expects global_scale_a as a
|
||||
# per-expert scalar (shape (1,) for single linear). The fused
|
||||
# rmsnorm/mhc kernels compute per-row gsa, but we must reduce to a
|
||||
# scalar. Using max reduction: gsa = max(per_row_gsa) ensures no
|
||||
# E4M3 block scale overflow (rows with smaller magnitude get slightly
|
||||
# less FP4 precision, but all rows stay within E4M3 range).
|
||||
#
|
||||
# For M=1 decode: per-row gsa is already scalar, no reduction needed.
|
||||
# For M>1 prefill: reduce per-row gsa to a single scalar (max).
|
||||
if quant.gsa.shape[0] == 1:
|
||||
gsa = quant.gsa[:1].reshape(1) # Already scalar
|
||||
else:
|
||||
self._gsa_buf.copy_(gsa)
|
||||
# Reduce per-row gsa to scalar (max) for GEMM compatibility.
|
||||
# Per-row gsa is mathematically more precise, but the GEMM only
|
||||
# supports a single global scale per expert.
|
||||
gsa = quant.gsa.max().reshape(1)
|
||||
self._gsa_buf.copy_(gsa)
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
|
||||
@@ -27,7 +27,6 @@ def parse_args():
|
||||
p.add_argument('--seed', type=int, default=42)
|
||||
p.add_argument('--verbose', type=int, default=1)
|
||||
p.add_argument('--prefill-only', action='store_true')
|
||||
p.add_argument('--ab-compare', action='store_true', help='A/B compare fused vs unfused P4 for first 3 layers')
|
||||
p.add_argument('--no-fused-rmsnorm', action='store_true', help='Disable P4 fused RMSNorm+quantize (use unfused path)')
|
||||
p.add_argument('--warmup-gsa', action='store_true', help='Fix gsa values after first decode step (eliminates amax kernel launches)')
|
||||
p.add_argument('--profile', action='store_true', help='Profile per-component GPU time using CUDA events')
|
||||
@@ -681,9 +680,20 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
cos_qa = torch.nn.functional.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item()
|
||||
print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
|
||||
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
|
||||
if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32))
|
||||
# B3: Fused rmsnorm+quant for q_a_norm → q_b path
|
||||
# Replaces: rmsnorm(q_a, w) → BF16 → q_b quantizes internally
|
||||
# With: fused rmsnorm+NVFP4 quantize → QuantizedActivation → q_b.run_from_quantized
|
||||
# Saves: ~6 kernel launches per layer (rmsnorm 4+ + quantize 2 vs fused 2)
|
||||
if q_norm_w is not None:
|
||||
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4 as _rmsnorm_quantize
|
||||
q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32))
|
||||
q_a = dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
|
||||
_pt('q_b_start')
|
||||
q = prod_lin['q_b'](q_a); q = unweighted_rmsnorm(q).bfloat16()
|
||||
if q_norm_w is not None:
|
||||
q = prod_lin['q_b'].run_from_quantized(q_a_quant)
|
||||
else:
|
||||
q = prod_lin['q_b'](q_a)
|
||||
q = unweighted_rmsnorm(q).bfloat16()
|
||||
_pt('q_b_end')
|
||||
q_heads = q.reshape(T, n_h, hd); q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd)
|
||||
_pt('rope_q_end')
|
||||
@@ -767,6 +777,8 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
|
||||
# 6. Production FMHA
|
||||
_pt('fmha_start')
|
||||
if VERBOSE >= 2 and li < 3:
|
||||
print(f" L{li} FMHA input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True)
|
||||
attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx)
|
||||
_pt('fmha_end')
|
||||
if VERBOSE >= 2 and li < 3:
|
||||
@@ -856,75 +868,34 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
moe_runner=None, se_runner=None, router=None,
|
||||
prod_lin=None, _profile_detail=False, _profile_times=None,
|
||||
_use_fused_rmsnorm_quantize=True,
|
||||
_ab_compare=False):
|
||||
):
|
||||
"""Forward one transformer layer.
|
||||
|
||||
_ab_compare: if True, run BOTH fused and unfused paths for this layer
|
||||
and print detailed numerical comparison. Only use for first few layers.
|
||||
"""
|
||||
# P4: Fused RMSNorm + NVFP4 quantize — eliminates ~488 launches/token
|
||||
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4, QuantizedActivation, dequantize_nvfp4
|
||||
x_in, ctx_a = attn_mhc.pre_block(X_l)
|
||||
# P5: Fused mHC pre_block + RMSNorm + NVFP4 quantize
|
||||
# Replaces: pre_block (bmm) + rmsnorm (~4 launches) + quantize (2 launches)
|
||||
# With: 2 kernel launches total (mhc_rmsnorm_amax_gsa + mhc_rmsnorm_quantize_nvfp4)
|
||||
# Savings: ~5 launches per site × 2 sites × 61 layers = 610 launches/token
|
||||
from dsv4.ops.quantize import (
|
||||
rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4,
|
||||
QuantizedActivation, dequantize_nvfp4,
|
||||
)
|
||||
from dsv4.layers.mhc import mHCContext
|
||||
|
||||
# A/B comparison mode: run BOTH paths, compare intermediate results
|
||||
if _ab_compare and _use_fused_rmsnorm_quantize:
|
||||
# --- FUSED PATH ---
|
||||
x_quant_fused = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
|
||||
x_normed_fused = dequantize_nvfp4(x_quant_fused.x_fp4, x_quant_fused.x_sf, x_quant_fused.gsa)
|
||||
# --- UNFUSED PATH ---
|
||||
x_normed_unfused = rmsnorm(x_in, attn_norm_w)
|
||||
# Quantize unfused x_normed the normal way (as run() would)
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4_unf, x_sf_unf, gsa_unf = quantize_nvfp4_gpu_fused(x_normed_unfused)
|
||||
# Attention mHC: fused pre_block + rmsnorm + NVFP4 quantize
|
||||
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X_l)
|
||||
ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a)
|
||||
|
||||
# Compare x_normed
|
||||
cos_normed = torch.nn.functional.cosine_similarity(
|
||||
x_normed_fused.flatten().float(), x_normed_unfused.flatten().float(), dim=0).item()
|
||||
print(f" L{li} A/B: x_normed |fused|={x_normed_fused.abs().max().item():.6f} "
|
||||
f"|unfused|={x_normed_unfused.abs().max().item():.6f} cos={cos_normed:.6f}", flush=True)
|
||||
|
||||
# Compare gsa values
|
||||
gsa_fused_val = x_quant_fused.gsa[0].item()
|
||||
gsa_unfused_val = gsa_unf[0].item()
|
||||
print(f" L{li} A/B: gsa fused={gsa_fused_val:.8f} unfused={gsa_unfused_val:.8f} "
|
||||
f"ratio={gsa_fused_val/max(gsa_unfused_val,1e-12):.8f}", flush=True)
|
||||
|
||||
# Compare FP4 data (should be different due to different intermediate precision)
|
||||
fp4_match = torch.equal(x_quant_fused.x_fp4.view(torch.uint8), x_fp4_unf.view(torch.uint8))
|
||||
sf_match = torch.equal(x_quant_fused.x_sf.view(torch.uint8), x_sf_unf.view(torch.uint8))
|
||||
print(f" L{li} A/B: fp4_identical={fp4_match} sf_identical={sf_match}", flush=True)
|
||||
|
||||
# Compare block scales
|
||||
sf_diff = (x_quant_fused.x_sf.view(torch.uint8).float() - x_sf_unf.view(torch.uint8).float()).abs()
|
||||
print(f" L{li} A/B: sf_diff max={sf_diff.max().item():.0f} mean={sf_diff.mean().item():.2f}", flush=True)
|
||||
|
||||
# Run BOTH GEMM paths and compare q_a output
|
||||
q_a_fused = prod_lin['q_a'].run_from_quantized(x_quant_fused)
|
||||
q_a_unfused = prod_lin['q_a'](x_normed_unfused)
|
||||
cos_qa = torch.nn.functional.cosine_similarity(
|
||||
q_a_fused.flatten().float(), q_a_unfused.flatten().float(), dim=0).item()
|
||||
print(f" L{li} A/B: q_a |fused|={q_a_fused.abs().max().item():.6f} "
|
||||
f"|unfused|={q_a_unfused.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
|
||||
|
||||
# Run BOTH GEMM paths for kv
|
||||
kv_fused = prod_lin['kv'].run_from_quantized(x_quant_fused)
|
||||
kv_unfused = prod_lin['kv'](x_normed_unfused)
|
||||
cos_kv = torch.nn.functional.cosine_similarity(
|
||||
kv_fused.flatten().float(), kv_unfused.flatten().float(), dim=0).item()
|
||||
print(f" L{li} A/B: kv |fused|={kv_fused.abs().max().item():.6f} "
|
||||
f"|unfused|={kv_unfused.abs().max().item():.6f} cos={cos_kv:.6f}", flush=True)
|
||||
|
||||
# Now continue with the UNFUSED path (which we know works)
|
||||
# to see if the rest of the layer also diverges
|
||||
x_normed = x_normed_unfused
|
||||
x_quant_attn = None
|
||||
elif _use_fused_rmsnorm_quantize:
|
||||
x_quant_attn = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
|
||||
if _use_fused_rmsnorm_quantize:
|
||||
# P5 fused: X_l + A_l → bmm + rmsnorm + NVFP4 quantize in 2 kernel launches
|
||||
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X_l, A_l_a, attn_norm_w.to(X_l.device, torch.float32))
|
||||
# Dequantize for compressor/indexer (1 kernel launch)
|
||||
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
||||
else:
|
||||
x_in = torch.bmm(A_l_a.unsqueeze(1).float(), X_l.float()).squeeze(1).bfloat16()
|
||||
x_normed = rmsnorm(x_in, attn_norm_w)
|
||||
x_quant_attn = None
|
||||
|
||||
if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter()
|
||||
F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, compressor, indexer, prod_lin,
|
||||
@@ -932,10 +903,20 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
_profile_detail=_profile_detail, _profile_times=_profile_times)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter()
|
||||
X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)
|
||||
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid)
|
||||
# FFN path: keep unfused rmsnorm for now — moe_forward internally quantizes,
|
||||
# and passing quantized activation through MoE requires deeper refactoring (P5)
|
||||
x_ffn = rmsnorm(x_in_f, ffn_norm_w)
|
||||
|
||||
# FFN mHC: fused pre_block + rmsnorm + NVFP4 quantize
|
||||
A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid)
|
||||
ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f)
|
||||
|
||||
if _use_fused_rmsnorm_quantize:
|
||||
# P5 fused: X_mid + A_l → bmm + rmsnorm + NVFP4 quantize in 2 kernel launches
|
||||
x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X_mid, A_l_f, ffn_norm_w.to(X_mid.device, torch.float32))
|
||||
# Dequantize for MoE (BF16 input required by MoE quantize path)
|
||||
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
|
||||
else:
|
||||
x_in_f = torch.bmm(A_l_f.unsqueeze(1).float(), X_mid.float()).squeeze(1).bfloat16()
|
||||
x_ffn = rmsnorm(x_in_f, ffn_norm_w)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_ffn0 = time.perf_counter()
|
||||
F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_ffn1 = time.perf_counter()
|
||||
@@ -1397,7 +1378,7 @@ def main():
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li),
|
||||
_use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
|
||||
_ab_compare=_args.ab_compare and li < 3)
|
||||
)
|
||||
except Exception as e:
|
||||
torch.cuda.synchronize()
|
||||
err = torch.cuda.current_stream(gpu).query()
|
||||
@@ -1478,7 +1459,7 @@ def main():
|
||||
_profile_detail=(profile and step == 1),
|
||||
_profile_times=cuda_layer_events if (profile and step == 1) else None,
|
||||
_use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
|
||||
_ab_compare=_args.ab_compare and li < 3)
|
||||
)
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
t_layers = time.perf_counter()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user