# NVFP4 MegaMoE Debug Log
## Current State (May 16, 2026 — 01:15 UTC)
**Status:** GEMM verified correct, SF remap verified correct, B layout verified correct, L2 `slot_token` cleaned up. vLLM still produces garbage. The checkpoint `input_scale` red herring is documented below. The bug remains unidentified.

**What's verified correct (DO NOT re-investigate):**
1. SF remap: roundtrip verifier = 0 errors, forward mapping with `layout_sf(make_coord(mn, k_elem, 0))`
2. GEMM math: uniform FP4 + uniform SF → exact output (72.0 = 1.5² × K)
3. B matrix layout: byte transpose correct, column-dependent weight test passes
4. L2 GEMM input: slot-major, no gather needed (cleaned up dead `slot_token` param)
**What's still broken:**
- "The capital of France is" → garbage tokens (varies, e.g. `-W'MSG173`, `( z tractor`)
- All magnitudes reasonable, no NaN anywhere
## Red Herring: Checkpoint input_scale (May 16, 00:10 UTC)
Mike's code review suggested using the checkpoint's `input_scale` as the activation global scale instead of the dynamic `amax/(6*448)`. This was **wrong** and has been reverted (commit `79b9bec`).

**What happened:**
- Applied `stage_activation(hidden_states, input_global_scale=w13_input_scale)`
- `w13_input_scale = 2.86e-4` (from checkpoint, same for all experts in a layer)
- Dynamic `amax/(6*448)` = 3.70e-3 (13x larger)
- Using 2.86e-4 for normalization: `x / 2.86e-4` maps an activation of magnitude 5 (typical hidden-state amax is ~5-10) to ~1.7e4, far above the 6 × 448 = 2688 ceiling that the dynamic scale guarantees
- ALL block scales saturated to 448.0 (max float8_e4m3 value)
- Output: still garbage, but now with quantized-to-death activations

**Why it's wrong:**
The checkpoint `input_scale` is NOT the `input_global_scale = amax/(6*448)` normalization constant. They are different quantities:
- `input_global_scale` rescales the data before FP4 quantization so that the tensor amax lands at 6 × 448 = 2688 (max E2M1 × max float8_e4m3)
- `input_scale` is a calibration constant from the Quark quantization tool

The calibration amax for `input_scale = 2.86e-4` would be `2.86e-4 * 6 * 448 = 0.77`. Runtime hidden states have amax ~5-10. The `input_scale` was probably computed on a different data distribution (calibration data rather than actual inference inputs).

**The correct use of `input_scale` is still unknown.** The Quark path computes `alpha = input_scale * weight_scale_2`, but this may assume BF16 activations (not FP4-quantized). Our CUTLASS kernel requires FP4 input, so we must quantize with the dynamic scale.

**Preserved for future use:** `_w13_input_scale` and `_w2_input_scale` are now saved in `finalize_weights` (not dropped) in case we need them for alpha computation later.

**Checkpoint input_scale values (layer 0, all experts identical):**
- `gate_proj.input_scale` = `up_proj.input_scale` = 2.862840e-04
- `down_proj.input_scale` = 3.069196e-02
- `weight_scale_2` = 4.650298e-05 (all projections)
- Scales vary by layer: layer 0 = 2.86e-4, layer 60 = 1.07e-2
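
The arithmetic in this section can be reproduced with a small host-side sketch of the two-level scaling. It assumes the per-block FP8 scale is computed as `block_amax / igs / 6` and clamped at 448; `global_scale`, `implied_amax`, and `block_scale` are hypothetical names, not the kernel's API. One consequence falls out directly: a block saturates exactly when its amax exceeds `igs * 6 * 448`, so with the layer-0 checkpoint value every block whose amax is above ~0.77 clamps to 448.

```cpp
#include <algorithm>
#include <cassert>

// Limits of the NVFP4 two-level scaling scheme.
constexpr float kFp4Max = 6.0f;    // largest E2M1 magnitude
constexpr float kFp8Max = 448.0f;  // largest float8_e4m3 magnitude

// Dynamic per-tensor scale: amax / (6 * 448).
float global_scale(float amax) { return amax / (kFp4Max * kFp8Max); }

// Inverse: the amax that a given per-tensor scale corresponds to.
float implied_amax(float igs) { return igs * kFp4Max * kFp8Max; }

// Hypothetical FP8 block scale for one 16-element block after dividing by
// igs, clamped to the float8_e4m3 maximum (the saturation observed above).
float block_scale(float block_amax, float igs) {
  assert(igs > 0.0f);
  return std::min(block_amax / igs / kFp4Max, kFp8Max);
}
```

With the layer-0 value, `implied_amax(2.862840e-4f)` is about 0.7695, while the observed dynamic scale 3.70e-3 corresponds to an amax of about 9.95. For a block with amax 5, `block_scale(5.0f, 2.862840e-4f)` clamps to 448, whereas `block_scale(5.0f, 3.70e-3f)` is about 225.
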
## SF Remap — Final Correct Implementation (commit `6626b75`)
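```cpp
// One thread per scale factor: tid enumerates (mn, k_sf) in row-major order.
int mn = tid / K_sf;
int k_sf = tid % K_sf;
int k_elem = k_sf * 16;  // first K element covered by this SF (16 elements per SF)
// layout_sf takes (mn, k, batch) with k in elements; its stride across the 16
// elements of one block is 0, so any k inside the block maps to the same slot.
int dst_idx = layout_sf(cute::make_coord(mn, k_elem, 0));
dst[dst_idx] = src[mn * src_stride_mn + k_sf * src_stride_ksf];
```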
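- Source strides `(src_stride_mn, src_stride_ksf)`: SFA = `(K_sf, 1)` (row-major `[M][K_sf]`), SFB = `(1, N)` (stored as `[K_sf][N]`, N contiguous)
- Allocation: `cute::size(cute::filter_zeros(layout))`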
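
The roundtrip verifier can be sketched on the host without CuTe, but only under a loud assumption: `sf_offset` hand-codes what we understand to be the 128×4 scale-factor atom of the SM100 block-scaled layouts (512 entries per tile, rows interleaved 32×4, K tiles contiguous). The real kernel should keep calling `layout_sf`. The decode side uses the `idx2crd` formula from Mike's correction (`mn = f0 + 32*f1 + 128*f2`, `k_sf = f4 + 4*f5`), with f3, presumably the stride-0 16-element mode, omitted. All helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// ASSUMED model of the 128x4 scale-factor atom: 512-entry tiles of 128 rows
// x 4 SF columns, rows interleaved as (row % 32) * 16 + (row / 32) * 4,
// K tiles contiguous. Padding for M % 128 != 0 or K_sf % 4 != 0 is ignored.
int sf_offset(int mn, int k_sf, int K_sf) {
  int tile = (mn / 128) * (K_sf / 4) + k_sf / 4;
  int r = mn % 128, c = k_sf % 4;
  return tile * 512 + (r % 32) * 16 + (r / 32) * 4 + c;
}

// Forward remap, one loop iteration per scale factor (mirrors the kernel).
std::vector<int> remap(const std::vector<int>& src, int M, int K_sf,
                       int stride_mn, int stride_ksf) {
  assert(M % 128 == 0 && K_sf % 4 == 0);
  std::vector<int> dst(static_cast<std::size_t>(M) * K_sf, -1);
  for (int tid = 0; tid < M * K_sf; ++tid) {
    int mn = tid / K_sf, k_sf = tid % K_sf;
    dst[sf_offset(mn, k_sf, K_sf)] = src[mn * stride_mn + k_sf * stride_ksf];
  }
  return dst;
}

// Roundtrip check: decode every physical offset back to (mn, k_sf) using
// mn = f0 + 32*f1 + 128*f2 and k_sf = f4 + 4*f5, then compare with src.
int roundtrip_errors(const std::vector<int>& dst, const std::vector<int>& src,
                     int K_sf, int stride_mn, int stride_ksf) {
  int errors = 0;
  for (int off = 0; off < static_cast<int>(dst.size()); ++off) {
    int tile = off / 512, inner = off % 512;
    int f2 = tile / (K_sf / 4), f5 = tile % (K_sf / 4);
    int f0 = inner / 16, f1 = (inner % 16) / 4, f4 = inner % 4;
    int mn = f0 + 32 * f1 + 128 * f2, k_sf = f4 + 4 * f5;
    if (dst[off] != src[mn * stride_mn + k_sf * stride_ksf]) ++errors;
  }
  return errors;
}
```

Feeding `src[i] = i` through `remap` with SFA strides `(K_sf, 1)` or SFB strides `(1, N)` gives 0 roundtrip errors. Verifying SFA output against SFB strides reports mismatches, which is the kind of stride mix-up this check exists to catch.
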
## Previous Bugs Fixed
### BF16 reference comparison (May 15, 23:37 UTC)
The 0.2 cosine against the Python BF16 dequantization reference was a RED HERRING. The reference is wrong, not the GEMM. 8+ iterations of SF remap changes all produced the same 0.2 cosine because it was never about the remap. **A wrong reference is worse than no reference.**
### `cute::size` vs `cute::cosize` (commit `c384198`)
Iteration bound used `size` (logical) instead of `cosize` (physical). Fixed but insufficient alone.
### M/K coordinate extraction in `idx2crd` (commits `deb6b32` → `30b6c89`)
Original had M/K swapped. Mike's correction: `mn = f0 + 32*f1 + 128*f2`, `k_sf = f4 + 4*f5`.
### `if/else if` fallthrough (commit `6626b75`)
`dst_idx` kept its initial value 0 when no branch matched. Fix: branchless `layout_sf(make_coord(...))`.
### `col_major_src` ambiguity (commit `7285331`)
Boolean flag → explicit `src_stride_mn, src_stride_ksf`.
### Allocation size (commit `6626b75`)
`cosize` → `size(filter_zeros(layout))`.
### L2 slot_token cleanup (commit `bb5a1ba`)
`nvfp4_moe_l2` accepted `slot_token` but never passed it to GEMM. Removed dead parameter.
## Mike's Code Review Answers
1. **E2M1 packing:** Confirmed correct — element 2j in low nibble, 2j+1 in high nibble. Suggested hardware oracle with `__nv_cvt_bfloat16raw2_to_fp4x2`.
2. **A RowMajor:** Confirmed correct — no micro-tiling for A.
3. **B ColumnMajor:** Byte transpose confirmed correct by test. Mike flagged theoretical concern about nibble-level transpose but our test passed.
4. **Alpha/global scale:** Mike suggested `alpha = input_scale * weight_scale_2` (from checkpoint). We tried it — wrong for activation normalization. The correct use of `input_scale` in our pipeline is still TBD.
5. **Gate/up correction:** Mathematically valid. `up_half *= up_weight_gs / gate_weight_gs` is equivalent to per-column alpha.
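
For the E2M1 packing point, a host-side encoder makes the nibble order concrete. This is a sketch: the value table is the standard E2M1 set (±{0, 0.5, 1, 1.5, 2, 3, 4, 6}), the rounding is simplified, and `e2m1_encode`/`pack_fp4x2` are hypothetical names. For bit-exact comparisons, the hardware oracle `__nv_cvt_bfloat16raw2_to_fp4x2` that Mike suggested remains the better reference.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// E2M1 magnitudes indexed by the 3-bit exponent/mantissa code; bit 3 = sign.
constexpr float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Nearest E2M1 code, saturating at 6.0. Ties go to the lower code; this is a
// simplification and does not implement round-to-nearest-even.
std::uint8_t e2m1_encode(float x) {
  std::uint8_t sign = std::signbit(x) ? 0x8 : 0x0;
  float mag = std::fmin(std::fabs(x), 6.0f);
  std::uint8_t best = 0;
  for (std::uint8_t c = 1; c < 8; ++c)
    if (std::fabs(kE2M1[c] - mag) < std::fabs(kE2M1[best] - mag)) best = c;
  return static_cast<std::uint8_t>(sign | best);
}

// Packing convention confirmed in the review: element 2j in the low nibble,
// element 2j+1 in the high nibble of byte j.
std::vector<std::uint8_t> pack_fp4x2(const std::vector<float>& v) {
  assert(v.size() % 2 == 0);
  std::vector<std::uint8_t> out(v.size() / 2);
  for (std::size_t j = 0; j < out.size(); ++j)
    out[j] = static_cast<std::uint8_t>(e2m1_encode(v[2 * j]) |
                                       (e2m1_encode(v[2 * j + 1]) << 4));
  return out;
}
```

For example, `pack_fp4x2({1.5f, -2.0f})` yields the single byte `0xC3` (low nibble `0x3` = 1.5, high nibble `0xC` = -2.0). The uniform 1.5 input used in the GEMM check packs to `0x33` in every byte.
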
## Architecture Notes
### NVFP4 MoE Pipeline
```
stage_activation(hidden_states) → x_fp4, x_sf, input_global_scale
L1 GEMM: (x_fp4, x_sf) @ (l1_w, l1_sf) with alpha=igs*l1_global_sf → gate_up
SiLU(gate) * up → activated
stage_activation(activated) → l1_fp4, l1_sf, l1_igs
L2 GEMM: (l1_fp4, l1_sf) @ (l2_w, l2_sf) with alpha=l1_igs*l2_global_sf → output
scatter with routing weights → y
```
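
The L1 step and the gate/up correction from Mike's review can be written out as a scalar reference. The sketch makes two assumptions: the L1 output row is laid out as the gate half followed by the up half, and `l1_global_sf` in `alpha=igs*l1_global_sf` is the gate projection's global scale. `l1_epilogue` and `silu` are hypothetical names, not the kernel's entry points.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

float silu(float x) { return x / (1.0f + std::exp(-x)); }

// HYPOTHETICAL L1 epilogue for one token row. acc holds raw block-scaled
// accumulators laid out as [gate_0 .. gate_{I-1}, up_0 .. up_{I-1}].
// The GEMM applies a single alpha = igs * gate_gs to every column; the up
// half is then multiplied by up_gs / gate_gs, so it effectively carries
// igs * up_gs, i.e. a per-column alpha.
std::vector<float> l1_epilogue(const std::vector<float>& acc, float igs,
                               float gate_gs, float up_gs) {
  assert(acc.size() % 2 == 0 && gate_gs != 0.0f);
  const std::size_t I = acc.size() / 2;
  const float alpha = igs * gate_gs;
  std::vector<float> activated(I);
  for (std::size_t i = 0; i < I; ++i) {
    float gate = alpha * acc[i];
    float up = alpha * acc[I + i] * (up_gs / gate_gs);  // gate/up correction
    activated[i] = silu(gate) * up;                      // SiLU(gate) * up
  }
  return activated;
}
```

Because `up` enters `SiLU(gate) * up` linearly, rescaling the up half after the GEMM gives the same result as running the GEMM with per-column alphas `igs * gate_gs` and `igs * up_gs`. If the two global scales are the checkpoint's `weight_scale_2` values, the correction factor is exactly 1 in layer 0, where all projections share 4.650298e-05.
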
### Per-element multiply order
`res += A_fp4 * SFA_fp8 * B_fp4 * SFB_fp8`
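
A scalar reference for this order is useful when spot-checking individual output elements. The sketch assumes operands are already decoded to float and that `alpha` is applied once after accumulation; `blockscaled_dot` is a hypothetical name. With unit scale factors, the uniform-input check reduces to `1.5² × K`, so the 72.0 reported above corresponds to K = 32. The hardware accumulation order may differ, so compare real kernel outputs against this with a tolerance.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kSfVec = 16;  // NVFP4: one FP8 scale factor per 16 K-elements

// Scalar reference for one output element, accumulating
// res += A_fp4 * SFA_fp8 * B_fp4 * SFB_fp8 over K. Inputs are already decoded
// to float; alpha is the per-tensor factor applied once at the end.
float blockscaled_dot(const std::vector<float>& a, const std::vector<float>& sfa,
                      const std::vector<float>& b, const std::vector<float>& sfb,
                      float alpha) {
  assert(a.size() == b.size() && a.size() % kSfVec == 0);
  assert(sfa.size() == a.size() / kSfVec && sfb.size() == sfa.size());
  float res = 0.0f;
  for (std::size_t k = 0; k < a.size(); ++k)
    res += a[k] * sfa[k / kSfVec] * b[k] * sfb[k / kSfVec];
  return alpha * res;
}
```
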
## Next Steps
1. **Compare against BF16 model** — run the same prompt on a known-good implementation to see if the attention layers are working and only MoE is broken
2. **Check the vLLM model integration** — how does the MoE output get mixed with the residual? Is `hc_post` correct?
3. **Understand the Quark input_scale contract** — maybe we need to NOT quantize activations to FP4 and instead use BF16 input
4. **Add per-layer token output logging** — see which layer the tokens go off the rails
5. **Check `o_a_proj` BF16 handling** — it is kept in BF16 in the checkpoint; is it being processed correctly?
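
Steps 1 and 4 combine naturally: dump per-layer hidden states from both runs and look for the first layer where they diverge. This sketch only covers the comparison. How the states are captured from vLLM is left open, and `cosine`, `first_divergent_layer`, and the threshold value are hypothetical. As the BF16 reference red herring above shows, this is only meaningful if the reference run is itself known-good. Given that, the index of the first divergent layer (and whether it is an attention or MoE layer) narrows down where to look.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cosine similarity; returns 0 when either vector has zero norm.
double cosine(const std::vector<float>& x, const std::vector<float>& y) {
  assert(x.size() == y.size());
  double dot = 0.0, nx = 0.0, ny = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    dot += double(x[i]) * y[i];
    nx += double(x[i]) * x[i];
    ny += double(y[i]) * y[i];
  }
  if (nx == 0.0 || ny == 0.0) return 0.0;
  return dot / (std::sqrt(nx) * std::sqrt(ny));
}

// HYPOTHETICAL helper: ref[l] and test[l] are the hidden states after layer l
// for the same prompt position, captured from a known-good BF16 run and from
// the NVFP4 build. Returns the first layer whose cosine drops below
// `threshold`, or -1 if every layer stays at or above it.
int first_divergent_layer(const std::vector<std::vector<float>>& ref,
                          const std::vector<std::vector<float>>& test,
                          double threshold) {
  assert(ref.size() == test.size());
  for (std::size_t l = 0; l < ref.size(); ++l)
    if (cosine(ref[l], test[l]) < threshold) return static_cast<int>(l);
  return -1;
}
```
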