nvfp4-megamoe-kernel/CORRRECTNESS_BACKLOG.md


WE ARE BACKLOGGING THIS ISSUE AND WILL REVISIT IT AFTER WE FINISH THE OTHER ITEMS IN THE FINAL STRETCH.

Context: post-cleanup single_shot_inference.py compiles, Paris is top-1 at step 0, output is coherent, then degenerates into repeated junk ("capital ..."). Defaults in effect: temperature=0.6, repetition_penalty=1.1, --warmup-gsa off, fused rmsnorm+quant on.

Read this first: what the symptom rules in and out. The model is sampling (temp 0.6) with a penalty (1.1) and still loops a near-constant token. That is not "greedy with no penalty." It means either (a) the model finished its turn, emitted a stop token we don't catch, and the LM head is now peaked on degenerate filler, or (b) a decode-state correctness bug whose error compounds over steps. (a) is far more likely and nearly free to test, so do Part A in order, cheapest first, and do NOT touch kernels until A1-A2 are ruled out. The math is right (Paris top-1); don't go hunting kernel ghosts before eliminating the decoding-config causes.

Code is the source of truth. For any precision change, validate per-layer cosine against dsv4/reference/ before trusting end-to-end output.
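A minimal sketch of that per-layer check, assuming the activations at each checkpoint (per layer, or per sub-block such as the q_a projection or the FMHA output) have been dumped from both our pipeline and dsv4/reference/ as flat float lists keyed by layer index. `first_divergent_layer` and the 0.99 threshold are hypothetical placeholders, not values taken from the reference.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return dot / (na * nb)


def first_divergent_layer(
    ours: Dict[int, List[float]],
    ref: Dict[int, List[float]],
    threshold: float = 0.99,  # hypothetical cut-off; tune per precision mode
) -> Tuple[Optional[int], Dict[int, float]]:
    """Print every layer's cosine; return the first layer below threshold (or None)."""
    cos_by_layer = {layer: cosine(ours[layer], ref[layer]) for layer in sorted(ref)}
    first_bad = None
    for layer, c in cos_by_layer.items():
        flag = "  <-- below threshold" if c < threshold else ""
        print(f"L{layer}: cos={c:.6f}{flag}")
        if first_bad is None and c < threshold:
            first_bad = layer
    return first_bad, cos_by_layer
```

Reporting the first layer that drops below threshold, rather than judging only end-to-end text, localizes a precision regression to one layer and sub-block.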


PART A — Decode repetition (correctness). Do in order.

A1 — Stop set (HIGH priority, ~zero effort, most likely the whole bug)

single_shot_inference.py:1571 stops only on next_id == tokenizer.eos_token_id. DSV4 is a reasoning/chat model; an assistant turn ends with a special token (<|end_of_sentence|>, and the turn structure also uses USER=128803 / ASSISTANT=128804 / </think>=128822). If the model's turn-end token isn't eos_token_id, decode never stops and degenerates exactly as observed.

Diagnose (no code change):

  1. Print tokenizer.eos_token_id, tokenizer.eos_token, and tokenizer.special_tokens_map once at startup.
  2. In the decode log, find the token id emitted at the moment output "should have finished." Decode it: tokenizer.decode([id]). If it's a special/end token not in the stop set, that's the bug.
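Step 2 can be automated once the generated ids are kept in a list. A minimal sketch, assuming the set of special-token ids comes from the tokenizer (`tokenizer.all_special_ids` on a Hugging Face tokenizer) plus the turn-structure ids listed above, which may not be registered as special tokens. The helper name is hypothetical. A </think> (128822) hit is expected mid-turn and is not itself a stop; the suspect is a turn-end-like token right before the output degenerates.

```python
from typing import Iterable, List, Set, Tuple


def unhandled_specials(
    generated_ids: Iterable[int],
    special_ids: Set[int],
    stop_ids: Set[int],
) -> List[Tuple[int, int]]:
    """(step, token_id) for every special token the decode loop did not stop on."""
    return [
        (step, tid)
        for step, tid in enumerate(generated_ids)
        if tid in special_ids and tid not in stop_ids
    ]


# Hypothetical wiring inside single_shot_inference.py (variable names assumed):
# special = set(tokenizer.all_special_ids) | {128803, 128804, 128822}
# for step, tid in unhandled_specials(generated, special, {tokenizer.eos_token_id}):
#     print(step, tid, repr(tokenizer.decode([tid])))
```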

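Fix: build an explicit stop set and break on membership, e.g.:

```python
# Build the stop set once, before the decode loop.
STOP_IDS = set()
if tokenizer.eos_token_id is not None:
    STOP_IDS.add(tokenizer.eos_token_id)
for t in ("<|end_of_sentence|>",):  # add the real turn-end token name(s) for this checkpoint
    tid = tokenizer.convert_tokens_to_ids(t)
    # Unknown tokens come back as None or as unk_token_id, depending on the tokenizer.
    if tid is not None and tid >= 0 and tid != tokenizer.unk_token_id:
        STOP_IDS.add(tid)
STOP_IDS.add(USER_TOKEN)  # USER=128803: model trying to open a new user turn = it's done
# ... in the loop:
if next_id in STOP_IDS:
    print(f"  STOP ({next_id}) at step {step}", flush=True)
    break
```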

A/B: if adding the stop set ends generation cleanly, the "bug" was never in the kernels. Stop here.

A2 — Sampler / penalty sanity (MEDIUM, cheap, diagnostic)

A 1.1 penalty over recent_tokens=all_tokens[-256:] should at least perturb a single-token loop. If it loops the exact same id anyway, suspect the penalty isn't reaching the kernel or is mis-indexed.

  • Test: rerun with --repetition-penalty 1.5. If the loop is unchanged, the penalty path in dsv4/model/sampler.py (CUDASampler) is broken — verify recent_tokens is actually passed to and applied by the kernel, and that it indexes the logit vector correctly. If raising it does break the loop, the sampler is fine and this was a stop-token/decoding-hygiene issue (see A1).
  • Also confirm recent_tokens includes the prompt tokens, not just generated ones, or the model can loop on a prompt word ("capital") penalty-free.
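To tell "penalty not applied" from "penalty applied but too weak", compare CUDASampler's penalized logits against a plain reference. A minimal sketch, assuming the CTRL-style rule used by Hugging Face's `RepetitionPenaltyLogitsProcessor` (positive logits divided by the penalty, negative logits multiplied); whether CUDASampler in dsv4/model/sampler.py uses the same rule is an assumption to confirm. The logit values in the test are made up.

```python
from typing import Iterable, List


def apply_repetition_penalty(
    logits: List[float], recent_tokens: Iterable[int], penalty: float
) -> List[float]:
    """CTRL-style penalty; each distinct recent id is penalized once, regardless of count.

    recent_tokens should include prompt ids as well as generated ids; otherwise a
    prompt word (e.g. "capital") can be repeated penalty-free.
    """
    out = list(logits)
    for tid in set(recent_tokens):
        if not 0 <= tid < len(out):
            # A mis-indexed id would silently penalize the wrong logit in a kernel.
            raise IndexError(f"token id {tid} outside vocab of size {len(out)}")
        out[tid] = out[tid] / penalty if out[tid] > 0 else out[tid] * penalty
    return out


def argmax(xs: List[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(xs)), key=xs.__getitem__)
```

With a looping token at logit 12.0 and a runner-up at 10.0, a 1.1 penalty lowers the leader to about 10.9 and it still wins, while 1.5 drops it to 8.0 and the loop breaks. That is the behavior the 1.5 rerun above relies on.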

A3 — Compressed/SWA visible-range parity (MEDIUM, architectural — verify vs reference)

During decode the query attends to [top-k compressed entries] ++ [SWA window] (forward_attention, "5. Gather KV"). Two things to verify against the HF/dsv4/reference oracle, because an off-by-one here causes subtle wrongness that compounds across decode steps (coherent early, degenerate late — matches the symptom):

  1. Which compressed blocks are visible to a decode query. Causality: a query must see only compressed blocks strictly preceding its own current (incomplete) block, never its own or future blocks. Confirm the set of compressed indices fed to the FMHA at step s matches the reference exactly.
  2. SWA / compressed overlap. The most recent tokens are in the SWA ring (ws=128) and may also be inside the newest complete compressed block → the query can attend to both representations of the same tokens. This may be intended (SWA refines what compression blurred, and the model was trained with it) — but it must match how the reference gathers, or the recent-context weighting drifts. Diff the gathered key set (indices + count) against the reference for a fixed prompt at several decode positions.
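A minimal sketch of the index bookkeeping to diff against the reference. It assumes 0-indexed positions, compressed block i covering tokens [i*ratio, (i+1)*ratio), and an SWA window of the last ws tokens including the query itself. These conventions and the function names are assumptions: the reference defines the real boundaries, and CSA's top-k selection further subsets the visible blocks.

```python
from typing import List, Tuple


def visible_compressed_blocks(pos: int, ratio: int) -> List[int]:
    """Compressed blocks strictly before the query's own (incomplete) block."""
    return list(range(pos // ratio))


def swa_window(pos: int, ws: int = 128) -> Tuple[int, int]:
    """Inclusive token range [start, pos] held in the SWA ring for a query at pos."""
    return max(0, pos - ws + 1), pos


def swa_compressed_overlap(pos: int, ratio: int, ws: int = 128) -> int:
    """How many SWA tokens are also summarized by a visible compressed block."""
    start, end = swa_window(pos, ws)
    covered_end = (pos // ratio) * ratio  # first token NOT covered by a visible block
    return max(0, min(end + 1, covered_end) - start)
```

For example, an HCA query (ratio=128) at position 300 sees blocks 0 and 1, its SWA window is tokens 173..300, and 83 of those tokens (173..255) are also inside block 1. Printing these for a few decode positions and diffing them against the reference's gathered indices and counts is the check described in item 2.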

Note: the residual |X| growing to ~244-372 is expected (the paper notes 300-500; your own KVCache docstring says the same). It is not by itself the bug. See B5 only if A1-A3 don't resolve it.

A4 — (verify, likely fine) Inverse RoPE

forward_attention:783 applies _apply_rope(attn_out, positions, ..., inverse=True) at the query position, which is what converts the absolute positions carried by the summed KV into relative ones. This looked correct. Just confirm inverse=True negates the rotation angle (applies RoPE(-t)) and uses the query positions (not comp_pos). Only revisit if A1-A3 are clean and degeneration persists.
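The property to confirm can be checked on a single rotary channel pair treated as a complex number. Rotating by the KV position p and then inverse-rotating by the query position q must equal a forward rotation by p - q. A minimal sketch, assuming one hypothetical frequency theta; `rope_pair` is a stand-in for `_apply_rope`, not its real signature.

```python
import cmath


def rope_pair(x: complex, pos: int, theta: float, inverse: bool = False) -> complex:
    """Rotate one rotary channel pair by pos * theta; inverse=True negates the angle."""
    angle = -pos * theta if inverse else pos * theta
    return x * cmath.exp(1j * angle)


x = complex(0.3, -1.2)  # arbitrary channel pair
theta = 0.01            # hypothetical rotary frequency
p, q = 37, 250          # KV position and query position

# Absolute position baked into the KV entry, then removed at the query position.
relative = rope_pair(rope_pair(x, p, theta), q, theta, inverse=True)
assert cmath.isclose(relative, rope_pair(x, p - q, theta), rel_tol=1e-12)

# Sign bug: without negating the angle, the result is a rotation by p + q instead.
wrong = rope_pair(rope_pair(x, p, theta), q, theta)
assert cmath.isclose(wrong, rope_pair(x, p + q, theta), rel_tol=1e-12)
```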


When we were doing this we tried ALL OF THE ABOVE!!!!! and this is where we left off:

Good, the greedy test (temp=0, 64 tokens) is prefilling. Let me wait for it to hit decode.

Greedy with temp=0 and rep_penalty=1.0: the model is looping on "capital of France is", just re-echoing the prompt. It never generates "Paris" or any real answer. This is NOT a sampler issue. The model's logits are fundamentally wrong.

The sampler and penalty are cleared. The degeneration is in the model/pipeline itself. Let me look at the step 0 output more carefully.

Key finding: " Paris" (111) has logit 10.62, but the TOP-1 token at step 0 is "," with logit 17.4. That's a 7-point gap. The model should have " Paris" much higher; in a correctly working model, "Paris" after "The capital of France is" should be 25+ logit, not 10.

This is a pipeline correctness issue. The model is not producing the right logits. The step-0 degeneration confirms it's not about repetition or sampling — the logits themselves are wrong from the very first decode step.
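The size of that gap reads more clearly as a probability ratio. Under softmax with temperature T, the ratio of two tokens' probabilities depends only on their logit difference: p_a / p_b = exp((z_a - z_b) / T). A short check with the step-0 logits from the log above (`prob_ratio` is a hypothetical helper):

```python
import math


def prob_ratio(z_a: float, z_b: float, temperature: float = 1.0) -> float:
    """p(a) / p(b) under softmax(z / T); independent of every other logit."""
    return math.exp((z_a - z_b) / temperature)


paris, comma = 10.62, 17.4  # step-0 logits from the log
for t in (1.0, 0.6):
    print(f"T={t}: ',' is {1 / prob_ratio(paris, comma, t):.0f}x more likely than ' Paris'")
```

At T=1 that is roughly 880x; at the default T=0.6 it is roughly 81,000x, so no sampler setting recovers " Paris" from these logits.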

Given the user's instruction to move to Part B if the sampler/prompt diagnostics don't resolve it, let me focus on the most impactful pipeline fix. The residual growing to 472-732 (|X|) is suspicious. Let me look at the P4 fused rmsnorm — specifically the `run_from_quantized` path vs the regular `run` path, since memory notes the P4 integration caused output quality regression.

The previous test with `--no-fused-rmsnorm` ALSO produced degenerate output, so P4 isn't the sole cause. But let me check the KEY difference: in the `run_from_quantized` path, `gsa` is passed per-row from the QuantizedActivation, while in the regular path, gsa is a scalar. The MEMORY.md notes this exact issue:

> **ROOT CAUSE HYPOTHESIS**: run_from_quantized passes per-row gsa to GEMM, but CuTeDSL NVFP4 GEMM expects scalar global_scale_a.

So `run_from_quantized` passes per-row gsa (shape (M,)) as `global_scale_a`. Let me check what the GEMM expects.

The `global_scale_a` is passed as a tensor to the CuTe kernel. For the single-expert case (Nvfp4Linear), `expert_offsets` has shape (1,) and `global_scale_a` would be the gsa. Let me check how the kernel uses it.

There it is. The kernel loads `global_scale_a` using `global_scale_a.iterator + expert_idx`. This is a **per-expert** scalar: it indexes by expert_idx, not by row. For a single-expert case (Nvfp4Linear with expert_offsets of shape (1,)), expert_idx is always 0, so it reads `global_scale_a[0]`, a single scalar.

But `run_from_quantized` passes per-row gsa as `global_scale_a` with shape `(M,)`. The kernel only reads `global_scale_a[0]` (the first row's gsa), ignoring the rest. This means for decode (M=1), it works fine (one row, one gsa). But the issue is that the fused rmsnorm kernel computes gsa per row, and the GEMM only uses row 0's gsa.

Wait, for decode, M=1 always. So per-row vs scalar doesn't matter for the decode path. The issue would be in prefill where M>1. Let me re-examine...

Actually, the prefill is done one token at a time (the code explicitly says "one token at a time (decode-style; TODO: batched prefill)"). So M=1 throughout. The per-row gsa should be fine since there's only one row.
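The indexing argument can be emulated on the host. A minimal sketch, assuming `expert_offsets` holds each expert's exclusive end row (so the single-expert Nvfp4Linear case is `[M]`) and that the kernel reads `global_scale_a[expert_idx]` as described above; both helper names are hypothetical. Rows past 0 silently reuse row 0's scale when a per-row gsa is passed, which is harmless only while M=1.

```python
import bisect
from typing import List, Sequence


def kernel_scale_per_row(
    global_scale_a: Sequence[float], expert_offsets: Sequence[int]
) -> List[float]:
    """Scale each row actually receives if the kernel indexes global_scale_a by expert."""
    m = expert_offsets[-1]
    scales = []
    for row in range(m):
        expert_idx = bisect.bisect_right(expert_offsets, row)  # first expert whose end > row
        scales.append(global_scale_a[expert_idx])
    return scales


def check_gsa_contract(per_row_gsa: Sequence[float], expert_offsets: Sequence[int]) -> bool:
    """True if every row is dequantized with the gsa it was quantized with."""
    return list(per_row_gsa) == kernel_scale_per_row(per_row_gsa, expert_offsets)
```

This is why the M=1 conclusion holds for the current token-at-a-time prefill; batched prefill (the TODO) would likely need a scalar or per-expert gsa, or a kernel that indexes by row.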

Let me look deeper at the actual numerical pipeline. The real question is: why are the logits wrong? Let me add diagnostic code to compare our pipeline output with a known reference at a specific layer.

Let me run with --verbose 2 to get per-layer diagnostics.

Still building MoE layers. Need to wait longer.

**CRITICAL FINDING**: L1 FMHA cosine is 0.679. That's terrible. L2 FMHA cosine is 0.956, also bad. The q_a projection is fine (0.999998), but the FMHA output is wrong. This is the root cause of the degeneration.

Let me get the L0 data too.

L0 FMHA cos=0.955, L1 FMHA cos=0.679, L2 FMHA cos=0.956. L1 is particularly bad. This is the attention pipeline producing incorrect results.

The reference comparison is done against PyTorch SDPA on the same gathered KV. So the FMHA kernel itself is wrong for these attention patterns. Let me understand why L1 is 0.679 while L0 and L2 are ~0.95.
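For one query head, the comparison reduces to the computation below. A minimal pure-Python sketch, assuming scaled dot-product attention with scale 1/sqrt(d), no mask and no sink over an already-gathered KV, and V = K (one shared MQA tensor). `reference_attention` is a hypothetical name, not the harness's function; it is only meant for tiny hand-built cases fed to both the production FMHA and the SDPA reference.

```python
import math
from typing import List, Sequence


def reference_attention(q: Sequence[float], kv: Sequence[Sequence[float]]) -> List[float]:
    """softmax(q . K^T / sqrt(d)) @ V with V = K (one shared KV tensor)."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in kv]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * k[j] for w, k in zip(weights, kv)) / z for j in range(d)]
```

If a tiny case disagrees between this and SDPA, the harness itself is suspect; if it agrees with SDPA but not with the FMHA, the kernel is.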

L0 and L2 are HCA layers (ratio=128), L1 is also HCA (the first two layers are HCA for Pro). Wait, let me check the layer schedule.

So L0=HCA(128), L1=HCA(128), L2=CSA(4), L3=HCA(128), L4=CSA(4). L1 (HCA) has terrible FMHA cosine. The issue is in the FMHA when dealing with the compressed KV patterns.

Let me check what seq_len the FMHA sees at each layer. The verbose=2 diagnostic should show this in the forward_attention function.

The cos=0.679 at L1 is devastating. Even though the magnitudes are similar (2.78 vs 2.94), the direction is wrong. This compounds through 61 layers.

Let me add more diagnostic output to the FMHA, specifically printing the shapes and seq_len being passed.

Wait: `n_comp=0` is hardcoded! The production FMHA call always sets `n_comp=0`. But the attention pattern for HCA/CSA involves compressed entries + SWA window. The `n_comp` parameter is crucial for the sink bias and possibly for the FMHA dispatch logic.

Let me check what `n_comp` does in the production attention.

The `n_comp` parameter is described as "reserved for future kernel integration" and it's always 0. The sink_bias is also "reserved for future kernel integration". So the production FMHA is running WITHOUT the D5c sink bias that the paper describes as essential for DSV4 attention.
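It helps to separate what a missing sink can and cannot explain. Assume the sink is a single per-head learned logit added to the softmax denominator with no value vector (a common attention-sink formulation; whether DSV4's D5c sink matches it is an assumption). It then scales every attention weight of that head by the same factor, so the head's output shrinks but keeps its direction. A minimal sketch with made-up numbers; the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def attention_with_sink(
    q: Sequence[float], kv: Sequence[Sequence[float]], sink_logit: float
) -> List[float]:
    """Softmax over [sink, scores]; the sink absorbs mass but contributes no value (V = K)."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in kv]
    m = max(max(scores), sink_logit)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights) + math.exp(sink_logit - m)  # exp(-inf) == 0.0 disables the sink
    return [sum(w * k[j] for w, k in zip(weights, kv)) / z for j in range(d)]


def norm_ratio_and_cos(a: Sequence[float], b: Sequence[float]) -> Tuple[float, float]:
    """(|a| / |b|, cosine(a, b))."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return na / nb, sum(x * y for x, y in zip(a, b)) / (na * nb)


q = [0.4, -0.1, 0.7]
kv = [[1.0, 0.2, -0.3], [0.1, 0.9, 0.5], [-0.6, 0.3, 1.1]]
no_sink = attention_with_sink(q, kv, sink_logit=-math.inf)
with_sink = attention_with_sink(q, kv, sink_logit=1.0)
ratio, cos = norm_ratio_and_cos(with_sink, no_sink)
print(f"|with| / |without| = {ratio:.4f}, cos = {cos:.12f}")
```

Under that assumption, a missing sink changes per-head output magnitude but not per-head direction, which is consistent with suspecting the FMHA itself for the low cosine. One caveat: heads with different sink values are rescaled by different factors, so the direction of the concatenated multi-head output can still shift.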

But more critically, the cos=0.679 at L1 suggests the FMHA itself is producing wrong results even without sink bias. Let me look at the reference comparison code — it compares the production FMHA output against PyTorch SDPA on the SAME gathered KV. So the KV is the same, but the kernel produces a different result.

Wait: the reference comparison is in the forward_attention function at `VERBOSE >= 2`. Let me look at that code more carefully.

The reference uses `v_exp = k_exp.clone()`, which means V = K. In DSV4, the KV cache stores combined K/V where K and V are the same tensor (MQA). That's correct for the reference comparison.

But the production FMHA also uses V = K:
```python
v = k  # Same tensor, no copy
```

So the reference and production should agree. The cos=0.679 means the kernel is numerically wrong for this specific attention pattern.

Let me check what seq_len the FMHA is seeing. At L1 during prefill of the first token, the KV cache would have very few entries. Let me add seq_len printing.


SO SINCE WE HAD TO TOUCH FMHA ANYWAY IN PART B, WE DID THAT FIRST AND TRIED TO GET IT CORRECT BEFORE REVISITING THIS ISSUE!!!