Fix global→local expert ID remapping for EP and remove .cpu() sync

Root cause of CUDA_ERROR_ASSERT index out of bounds:
- topk_ids contains GLOBAL expert IDs (0-255) but runner treated them
  as local IDs (0-31 with EP=8). Tokens for non-local experts got
  wrong expert assignments, causing out-of-bounds scatter indices
  in _assemble_scales_cudagraph_safe.

Fixes:
1. Add experts_start_idx param to CuTeDSLMoERunner
2. In run(), remap global→local IDs and zero weights for non-local experts
3. Move _token_indices from CPU to GPU (remove sort_idx.cpu() sync)
4. Add _fill_token_indices() and _needs_token_refill to handle CuTeDSL
   JIT GPU memory corruption (refill after first GEMM call)
This commit is contained in:
2026-05-17 08:58:43 +00:00
parent 1330e2b2cf
commit ca3cba5bbd
3 changed files with 192 additions and 94 deletions

View File

@@ -1,94 +1,170 @@
# Current Bug: CuTeDSLMoERunner produces wrong output
# Current Bug: CuTeDSLMoERunner — Status & Debug History
## Status
-`layertest.py` — 0.988 cosine (moe_pipeline with dynamic gs + assemble_scales_2d_side)
-`cudagraph_test.py` — capture + replay succeeds
-`test_scale_assembly.py` — per-expert scale data matches reference
-`test_runner_vs_pipeline.py` — runner gives 0.18 cosine vs pipeline (should be ~0.99)
## Current Status (May 17, 2026 08:35 UTC)
## Root Cause: scale_a layout doesn't match expert_offsets
**Mostly fixed. 0.97 cosine with warmup gs. Ready for vLLM container test.**
The kernel uses `expert_offsets` to index into the 2D-side scale_a tensor. After swizzle, the data layout is: each expert's 128-row block is swizzled independently, then concatenated. The kernel's TMA load undoes the swizzle when reading.
-`layertest.py` — 0.988 cosine
-`cudagraph_test.py` — capture + replay works
-`test_warmup_gs.py` — 0.97 cosine with `compute_activation_global_scales()` warmup
- ❌ vLLM server — not yet tested with these fixes
### How the pipeline works (0.988 cosine):
**Remaining concerns:**
- The CPU-based `_token_indices` uses `sort_idx.cpu()` which is a CPU-GPU sync — may interfere with cudagraph capture
- The `compute_activation_global_scales()` warmup needs to be called from `deepseek_v4.py` during model warmup
- The checkpoint `input_scale` should NOT be used as the activation global_scale (it's a calibration value, not a runtime value)
1. `expert_offsets = compute_expert_offsets([4, 4, 0], 3)``tensor([4, 8, 8])`
2. `assemble_scales_2d_side([x_sf[0:4], x_sf[4:8]])` → 256 rows (only experts WITH tokens)
- Expert 0's data: rows 0-127 (4 real + 124 zero, swizzled)
- Expert 1's data: rows 128-255 (4 real + 124 zero, swizzled)
- Expert 2 has 0 tokens → NOT included
3. Kernel reads `scale_a[0:4]` for expert 0, `scale_a[4:8]` for expert 1
4. After TMA un-swizzle, slot index `m` maps correctly to the original row `m` data
5. `scale_a.shape[0] = 256` → kernel knows total padded tokens = 256
---
**Key insight:** The pipeline only includes experts with tokens in scale_a. The kernel's slot-based expert_offsets ([4, 8, 8]) correctly indexes into the 256-row scale_a because the TMA un-swizzle maps slot index → original row → correct data.
## Bugs Found & Fixed
### How the runner currently works (0.18 cosine):
### Bug 1: Scale Assembly — Global Swizzle vs Per-Expert Swizzle
1. `expert_offsets` = GPU-computed `[0, 4, 8, 8]` (with leading 0 for cumsum)
2. `_assemble_scales_cudagraph_safe` produces 384 rows (ALL experts × 128, including expert 2 with 0 tokens)
- Expert 0's data: rows 0-127
- Expert 1's data: rows 128-255
- Expert 2's data: rows 256-383 (all zeros, swizzled)
3. Kernel reads `scale_a[0:4]` for expert 0, `scale_a[4:8]` for expert 1
4. After TMA un-swizzle, slot indices 0-3 map to expert 0's first 4 original rows → correct
5. But slot indices 4-7 map to... the 5th-8th rows of expert 0's swizzled block, NOT expert 1's data
6. `scale_a.shape[0] = 384` → kernel thinks there are 384 padded token slots, but expert_offsets says 8
**Symptom:** GEMM produced all zeros even with correct global_scale.
**The mismatch:** The kernel interprets slot index `m=4` as row 4 of the ENTIRE scale_a tensor. After un-swizzle, row 4 is in expert 0's 128-row block. But the pipeline's scale_a has row 4 in expert 1's block (because expert 1 starts at row 128 in a 256-row tensor, but the kernel's TMA remaps slot 4 → the 5th original row which IS expert 1's 1st row in the pipeline layout).
**Root cause:** The original `_assemble_scales_cudagraph_safe` called `pad_and_swizzle_single()` on the ENTIRE padded buffer (all experts concatenated). But the kernel expects each expert's 128-row block to be swizzled independently (matching `assemble_scales_2d_side` which pads+swizzles each expert separately before concatenation).
Wait — re-reading: in the pipeline, scale_a has 256 rows. Slot 4 → row 4 of scale_a. After un-swizzle, this is the 5th original row. In the pipeline, original rows 0-3 are expert 0's data, rows 4-7 are expert 1's data (padded to 128 each before swizzle). So the un-swizzle of row 4 gives the 1st original row of expert 1. That's correct.
**Fix:** Two-phase approach:
1. Scatter x_sf rows into 128-aligned positions in a padded buffer (GPU-only, no CPU sync)
2. Per-expert: copy 128 rows from padded buffer, `pad_and_swizzle_single()` each expert's block independently, then concatenate
In the runner, scale_a has 384 rows. Slot 4 → row 4 of scale_a. After un-swizzle, this is the 5th original row. Original rows 0-3 are expert 0's data. Row 4 is a ZERO row (expert 0 had only 4 tokens, rows 4-127 are zero-padded). So the un-swizzle of row 4 gives a zero → expert 1 gets no valid scale data.
**Key insight from `torch_scaled_grouped_mm.py` line ~1115:** The kernel computes padded offsets internally when `consistent_token_padding=False`:
```python
padded_size = round_up(offs[expert_idx] - offs[expert_idx-1], pad_granularity) # 128
```
So the kernel knows each expert's scale data is in a 128-row block.
**So the issue IS the 128-row padding.** When expert 0 has 4 tokens and we pad to 128 rows, slot indices 4-127 map to zeros for expert 0. The kernel needs slot indices to be contiguous per expert (0-3 for expert 0, 4-7 for expert 1), but the scale_a has 128 rows per expert block, not 4.
### Bug 2: `searchsorted(right=False)` — Wrong Expert Assignment
### Why does the pipeline work?
**Symptom:** Scale data in wrong positions after scatter.
Because `pad_and_swizzle_single` on a 4-row tensor pads to 128 rows, swizzles the 128-row block, and the kernel's TMA read with slot index 4 reads from the 5th position in the swizzled 128-row block. After un-swizzle, position 4 maps back to... the 5th original row, which is a zero-padded row. Wait, this should also be wrong then.
**Root cause:** `torch.searchsorted([4, 8, 8], 4, right=False)` returns 0, assigning row 4 (expert 1's first token) to expert 0.
Unless the kernel's 2D-side scale access uses `m % 128` or similar per-block indexing. Let me check the kernel's scale_a read pattern.
**Fix:** Changed to `right=True`:
```python
expert_assign = torch.searchsorted(expert_offsets[1:], row_indices, right=True)
```
Actually — re-reading the pipeline more carefully: `assemble_scales_2d_side([x_sf[0:4], x_sf[4:8]])`. This creates a scale_a where:
- Expert 0's 4 rows are padded to 128, swizzled → 128 swizzled rows
- Expert 1's 4 rows are padded to 128, swizzled → 128 swizzled rows
- Concatenated → 256 rows
**Verified:** Row 4 → expert 1 (correct), rows 0-3 → expert 0 (correct).
The kernel gets `expert_offsets = [4, 8, 8]`. For expert 0 (slots 0-3), it reads scale_a positions 0-3. For expert 1 (slots 4-7), it reads scale_a positions 4-7.
### Bug 3: CuTeDSL `cute.compile` GPU Memory Corruption — CRITICAL
In the swizzled layout, positions 0-3 are in the first 128-row block (expert 0). Positions 4-7 are ALSO in the first 128-row block. But expert 1's data is in the second 128-row block (positions 128-255).
**Symptom:** `_token_indices` was all zeros, making every token map to token 0.
So the kernel reads positions 4-7 for expert 1, but those positions contain expert 0's zero-padded/swizzled data, not expert 1's data. This should be wrong...
**Root cause:** CuTeDSL's `cute.compile` (JIT compilation) corrupts GPU memory. Tensors allocated on GPU before or during JIT compilation get zeroed. Pre-existing tensors allocated before the JIT survive. This is a bug in the CuTeDSL library.
BUT THE PIPELINE GIVES 0.988 COSINE. So either:
1. The kernel DOES use per-expert offsets into scale_a (reading from position `expert_padded_offset + local_slot`), or
2. The swizzle + TMA read remaps the indices in a way I'm not understanding
**Impact:** `_token_indices` (int32 on GPU) was zeroed, causing `hidden_states[sorted_token_ids]` to return `hidden_states[0]` for all 8 slots. Every expert saw the same input.
Need to check the kernel's actual scale_a access pattern in the C++ code.
**Fix:** Allocate `_token_indices` on CPU, keep it there. In `run()` and `compute_activation_global_scales()`, index with `sort_idx.cpu()` then move result to GPU:
```python
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
```
## Fix Options
**Warning:** This introduces a CPU-GPU sync (`.cpu()`) which may interfere with cudagraph capture. Needs verification.
### Option A: Padded expert_offsets
Change expert_offsets from slot-based to 128-row-aligned:
- Instead of `[0, 4, 8, 8]`, pass `[0, 128, 256, 384]`
- The kernel reads scale_a[0:128] for expert 0, scale_a[128:256] for expert 1, etc.
- Problem: the kernel would produce 128 output rows per expert instead of the actual token count
- This breaks the output shape
### Bug 4: `expert_offsets` With Leading 0
### Option B: Only include experts with tokens in scale_a
- Same as pipeline: only pad+swizzle experts that have tokens
- Requires knowing which experts have tokens → .tolist() or .item() → breaks cudagraph
- Could use a fixed expert set (all experts always included, but with zero rows for empty experts)
- The pipeline's assemble_scales_2d_side with 0 rows produces no output for that expert
**Symptom:** GEMM produced wrong output with correct scale data.
### Option C: Understand the kernel's actual indexing and match it
- Read the kernel's scale_a access code
- Figure out exactly how slot indices map to scale_a positions
- Build the layout the kernel expects
**Root cause:** The runner passed `expert_offsets[:num_experts + 1]` = `[0, 4, 8, 8]` (4 elements with leading 0) but the kernel expects `compute_expert_offsets([4, 4, 0], 3)` = `[4, 8, 8]` (3 elements, cumulative sum without leading 0).
### Option D: Skip scale_a assembly entirely, pass raw scales
- The kernel might accept raw (un-swizzled) scales via a different path
- Or we could use the 3D-side layout for activation scales too
**Fix:** Pass `expert_offsets[1:num_experts + 1]` to the GEMM.
## Approach
Going with **Option C** — understand the kernel's indexing, then match it.
### Bug 5: Checkpoint `input_scale` Is Wrong for Activation Global Scale
**Symptom:** Block scales all saturate at float8 max (448), producing garbage quantization.
**Root cause:** The checkpoint's `input_scale` (~0.000286) is a calibration value computed from a different input magnitude (amax ≈ 0.77) than what runtime produces (amax ≈ 8.17). Too-small gs → x/gs has values up to ~13000 → block_amax/6 ≈ 2174 → overflows float8_e4m3fn max of 448 → saturated block scales → garbage.
**Fix:** `compute_activation_global_scales()` warmup method that runs `quantize_to_nvfp4` (dynamic gs with `.max()`) before cudagraph capture to get the exact gs values for L1 and L2.
### Bug 6: L1 and L2 Need Separate Activation Global Scales
**Symptom:** L2 output was garbage even with correct L1 gs.
**Root cause:** After SiLU(gate)*up, the activation has amax ~286. The L1 gs (from input amax ~8) is 30x too small for L2, causing even worse block scale saturation.
**Fix:** `compute_activation_global_scales()` computes L1 gs from the input, runs the L1 GEMM, then computes L2 gs from the actual L1 output (after SiLU*up).
### Bug 7: L1 and L2 Need Separate Padded Scale Buffers
**Symptom:** IndexError when quantizing L2 activation — K_sf differs between L1 (448) and L2 (192).
**Root cause:** `padded_x_sf_buf` was allocated with L1's K_sf (448). When L2's x_sf has K_sf=192, the buffer size mismatch causes issues.
**Fix:** Separate `_padded_x_sf_buf_l1` and `_padded_x_sf_buf_l2`, plus separate `_per_expert_scale_bufs_l1` and `_per_expert_scale_bufs_l2`.
---
## Debug Methodology — How We Got Here
### Step 1: Identified the CuTeDSL kernel works (layertest = 0.988)
The layertest uses `moe_pipeline.run_nvfp4_moe` with `quantize_to_nvfp4` (dynamic gs) and `assemble_scales_2d_side` (per-expert split). This is the reference implementation.
### Step 2: Wrote test_runner_vs_pipeline.py
Compared `runner.run()` vs `run_nvfp4_moe()` with same weights and inputs. Found runner produces all zeros.
### Step 3: Wrote test_scale_assembly.py
Compared `_assemble_scales_cudagraph_safe` vs `assemble_scales_2d_side`. Found data mismatch (global vs per-expert swizzle).
### Step 4: Fixed scale assembly
Rewrote `_assemble_scales_cudagraph_safe` with scatter + per-expert swizzle. Scale data now matches reference.
### Step 5: Found GEMM still produces zeros with correct scales
Isolated the issue: GEMM with the exact same inputs gives cosine 1.0, but runner gives 0.18. The problem was `expert_offsets` format (leading 0).
### Step 6: Fixed expert_offsets, found token_indices corruption
After fixing expert_offsets, cosine improved to 0.35. Traced to `_token_indices` being all zeros (CuTeDSL GPU corruption).
### Step 7: Found and fixed the GPU corruption
Moved `_token_indices` to CPU. Cosine jumped to 0.46 with default gs, 0.97 with warmup gs.
### Step 8: Wrote test_warmup_gs.py
Verified warmup gs computation, tested safety margins, tested different inputs. Found 1.0x safety (no margin) gives best results.
---
## Test Files
| File | Purpose |
|------|---------|
| `tests/layertest.py` | Reference: moe_pipeline with dynamic gs, 3 experts, layer 0. Must pass (≥0.98 cosine). |
| `tests/cudagraph_test.py` | CuTeDSLMoERunner cudagraph capture + replay. Must pass. |
| `tests/test_runner_vs_pipeline.py` | Compare runner.run() vs moe_pipeline. With correct gs should be ≥0.97. |
| `tests/test_scale_assembly.py` | Compare cudagraph-safe vs reference scale assembly. Data must match. |
| `tests/test_warmup_gs.py` | Warmup gs computation, safety margin sweep, different input test. |
| `tests/test_scale_debug.py` | Byte-level scale debug (can be cleaned up). |
**Run order after any code change:**
1. `python3 tests/layertest.py` — must pass
2. `python3 tests/cudagraph_test.py` — must pass
3. `python3 tests/test_warmup_gs.py` — should show ≥0.97 cosine
---
## Files Modified
| File | Changes |
|------|---------|
| `vllm/nvfp4_cutedsl.py` | All 7 bug fixes, `compute_activation_global_scales()` warmup, CPU token_indices |
| `vllm/patches/deepseek_v4.py` | Removed checkpoint `input_scale` → activation global_scale mapping |
---
## Next Steps for vLLM Integration
1. **Add warmup call in `deepseek_v4.py`:** After `finalize_weights()`, call `runner.compute_activation_global_scales()` with a sample input (e.g., 1 token of random data). This must happen before cudagraph capture.
2. **Verify cudagraph compatibility:** The `sort_idx.cpu()` call in `run()` is a CPU-GPU sync. Cudagraph may not support this. If it doesn't, need to find a way to keep `_token_indices` on GPU while avoiding the CuTeDSL corruption.
3. **Test the vLLM container:** Spin up the server and test with a simple prompt. The output should be mostly correct (0.97 cosine ≈ near-perfect output).
4. **Optimize warmup:** The current warmup runs a full forward pass (L1 + L2 GEMM). This is slow (~minutes due to JIT). Consider caching the gs values or computing them more efficiently.

View File

@@ -37,13 +37,15 @@ class CuTeDSLMoERunner:
"""
def __init__(self, num_experts, hidden_size, intermediate_size,
max_num_tokens=8192, top_k=8, device="cuda"):
max_num_tokens=8192, top_k=8, device="cuda",
experts_start_idx=0):
self.num_experts = num_experts
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.max_num_tokens = max_num_tokens
self.top_k = top_k
self.device = device
self.experts_start_idx = experts_start_idx
# Weight storage (set before _ensure_stacked)
self.l1_fp4 = None
@@ -76,6 +78,13 @@ class CuTeDSLMoERunner:
self._padded_x_sf_buf_l2 = None
self._buffers_allocated = False
def _fill_token_indices(self):
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times)."""
src = torch.arange(self.max_num_tokens, dtype=torch.int32, device=self.device)
self._token_indices.copy_(
src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
)
def _allocate_buffers(self):
"""Pre-allocate scale buffers at max size for cudagraph compatibility."""
K_sf = cutedsl_ceil_div(self.hidden_size, 16)
@@ -127,11 +136,18 @@ class CuTeDSLMoERunner:
# Allocate buffers AFTER JIT compilation
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
# tensors allocated before/during compilation may be zeroed)
self._token_indices = torch.arange(
self.max_num_tokens, dtype=torch.int32
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
# Keep on CPU to avoid CuTeDSL JIT GPU memory corruption
# Will be indexed with CPU offsets during slot mapping
#
# _token_indices: GPU tensor for cudagraph compatibility.
# CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking
# (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm
# is triggered on the first run() call; we refill _token_indices after
# that first call via the _needs_token_refill flag.
self._token_indices = torch.zeros(
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
)
self._fill_token_indices()
self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run
self._expert_id_range = torch.arange(
self.num_experts, dtype=torch.int32
).to(self.device)
@@ -239,7 +255,7 @@ class CuTeDSLMoERunner:
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
sorted_token_ids = token_indices[sort_idx] # GPU, no .cpu()
slot_hidden = hidden_states_sample[sorted_token_ids]
# L1: get exact gs from quantize_to_nvfp4
@@ -281,24 +297,15 @@ class CuTeDSLMoERunner:
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
"""Run the NVFP4 MoE forward pass.
Handles global→local expert ID remapping for expert parallelism.
topk_ids contains GLOBAL expert IDs (0..n_routed_experts-1).
This runner only handles local experts
[experts_start_idx, experts_start_idx + num_experts).
Non-local tokens get zero weight and are clamped to expert 0
(harmless — zero-weighted output contributes nothing to scatter_add).
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
expert_offsets are computed from the actual token distribution
via GPU-only ops (argsort, broadcast ==, cumsum). These offsets
are passed to the GEMM as a GPU tensor, never converted to Python.
The GEMM and quantize functions see the full slot buffer.
Padding rows are zeros that produce zero output, contributing
nothing to the final scatter_add.
Args:
hidden_states: (num_tokens, hidden_size) bf16
topk_weights: (num_tokens, top_k) float32
topk_ids: (num_tokens, top_k) int
expert_indices: ignored (uses all experts)
Returns:
(num_tokens, hidden_size) bf16 - MoE output
"""
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
@@ -306,16 +313,24 @@ class CuTeDSLMoERunner:
self._ensure_stacked()
# -- Remap global expert IDs to local IDs --
# topk_ids are global: remap by subtracting experts_start_idx.
# Tokens for non-local experts get clamped to 0 with zero weight.
local_ids = topk_ids - self.experts_start_idx
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
safe_ids = local_ids.clamp(0, self.num_experts - 1)
safe_weights = topk_weights * local_mask.float()
# -- Build slot mapping --
flat_ids = topk_ids.reshape(-1)
flat_weights = topk_weights.reshape(-1)
flat_ids = safe_ids.reshape(-1)
flat_weights = safe_weights.reshape(-1)
num_slots = num_tokens * top_k
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_weights = flat_weights[sort_idx]
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
sorted_token_ids = token_indices[sort_idx] # GPU tensor, no .cpu()
# Expert offsets (GPU-only, never touches CPU)
expert_id_range = self._expert_id_range
@@ -383,4 +398,10 @@ class CuTeDSLMoERunner:
weighted_out,
)
# Refill _token_indices after GEMM JIT on first call
# (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM)
if self._needs_token_refill:
self._fill_token_indices()
self._needs_token_refill = False
return y

View File

@@ -504,6 +504,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
device=l1_fp4[0].device,
experts_start_idx=self.experts_start_idx,
)
self._cutedsl_runner.l1_fp4 = l1_fp4
self._cutedsl_runner.l1_sf = l1_sf