Fix global→local expert ID remapping for EP and remove .cpu() sync
Root cause of CUDA_ERROR_ASSERT index out of bounds: - topk_ids contains GLOBAL expert IDs (0-255) but runner treated them as local IDs (0-31 with EP=8). Tokens for non-local experts got wrong expert assignments, causing out-of-bounds scatter indices in _assemble_scales_cudagraph_safe. Fixes: 1. Add experts_start_idx param to CuTeDSLMoERunner 2. In run(), remap global→local IDs and zero weights for non-local experts 3. Move _token_indices from CPU to GPU (remove sort_idx.cpu() sync) 4. Add _fill_token_indices() and _needs_token_refill to handle CuTeDSL JIT GPU memory corruption (refill after first GEMM call)
This commit is contained in:
210
CURRENT_BUG.md
210
CURRENT_BUG.md
@@ -1,94 +1,170 @@
|
||||
# Current Bug: CuTeDSLMoERunner produces wrong output
|
||||
# Current Bug: CuTeDSLMoERunner — Status & Debug History
|
||||
|
||||
## Status
|
||||
- ✅ `layertest.py` — 0.988 cosine (moe_pipeline with dynamic gs + assemble_scales_2d_side)
|
||||
- ✅ `cudagraph_test.py` — capture + replay succeeds
|
||||
- ✅ `test_scale_assembly.py` — per-expert scale data matches reference
|
||||
- ❌ `test_runner_vs_pipeline.py` — runner gives 0.18 cosine vs pipeline (should be ~0.99)
|
||||
## Current Status (May 17, 2026 08:35 UTC)
|
||||
|
||||
## Root Cause: scale_a layout doesn't match expert_offsets
|
||||
**Mostly fixed. 0.97 cosine with warmup gs. Ready for vLLM container test.**
|
||||
|
||||
The kernel uses `expert_offsets` to index into the 2D-side scale_a tensor. After swizzle, the data layout is: each expert's 128-row block is swizzled independently, then concatenated. The kernel's TMA load undoes the swizzle when reading.
|
||||
- ✅ `layertest.py` — 0.988 cosine
|
||||
- ✅ `cudagraph_test.py` — capture + replay works
|
||||
- ✅ `test_warmup_gs.py` — 0.97 cosine with `compute_activation_global_scales()` warmup
|
||||
- ❌ vLLM server — not yet tested with these fixes
|
||||
|
||||
### How the pipeline works (0.988 cosine):
|
||||
**Remaining concerns:**
|
||||
- The CPU-based `_token_indices` uses `sort_idx.cpu()` which is a CPU-GPU sync — may interfere with cudagraph capture
|
||||
- The `compute_activation_global_scales()` warmup needs to be called from `deepseek_v4.py` during model warmup
|
||||
- The checkpoint `input_scale` should NOT be used as the activation global_scale (it's a calibration value, not a runtime value)
|
||||
|
||||
1. `expert_offsets = compute_expert_offsets([4, 4, 0], 3)` → `tensor([4, 8, 8])`
|
||||
2. `assemble_scales_2d_side([x_sf[0:4], x_sf[4:8]])` → 256 rows (only experts WITH tokens)
|
||||
- Expert 0's data: rows 0-127 (4 real + 124 zero, swizzled)
|
||||
- Expert 1's data: rows 128-255 (4 real + 124 zero, swizzled)
|
||||
- Expert 2 has 0 tokens → NOT included
|
||||
3. Kernel reads `scale_a[0:4]` for expert 0, `scale_a[4:8]` for expert 1
|
||||
4. After TMA un-swizzle, slot index `m` maps correctly to the original row `m` data
|
||||
5. `scale_a.shape[0] = 256` → kernel knows total padded tokens = 256
|
||||
---
|
||||
|
||||
**Key insight:** The pipeline only includes experts with tokens in scale_a. The kernel's slot-based expert_offsets ([4, 8, 8]) correctly indexes into the 256-row scale_a because the TMA un-swizzle maps slot index → original row → correct data.
|
||||
## Bugs Found & Fixed
|
||||
|
||||
### How the runner currently works (0.18 cosine):
|
||||
### Bug 1: Scale Assembly — Global Swizzle vs Per-Expert Swizzle
|
||||
|
||||
1. `expert_offsets` = GPU-computed `[0, 4, 8, 8]` (with leading 0 for cumsum)
|
||||
2. `_assemble_scales_cudagraph_safe` produces 384 rows (ALL experts × 128, including expert 2 with 0 tokens)
|
||||
- Expert 0's data: rows 0-127
|
||||
- Expert 1's data: rows 128-255
|
||||
- Expert 2's data: rows 256-383 (all zeros, swizzled)
|
||||
3. Kernel reads `scale_a[0:4]` for expert 0, `scale_a[4:8]` for expert 1
|
||||
4. After TMA un-swizzle, slot indices 0-3 map to expert 0's first 4 original rows → correct
|
||||
5. But slot indices 4-7 map to... the 5th-8th rows of expert 0's swizzled block, NOT expert 1's data
|
||||
6. `scale_a.shape[0] = 384` → kernel thinks there are 384 padded token slots, but expert_offsets says 8
|
||||
**Symptom:** GEMM produced all zeros even with correct global_scale.
|
||||
|
||||
**The mismatch:** The kernel interprets slot index `m=4` as row 4 of the ENTIRE scale_a tensor. After un-swizzle, row 4 is in expert 0's 128-row block. But the pipeline's scale_a has row 4 in expert 1's block (because expert 1 starts at row 128 in a 256-row tensor, but the kernel's TMA remaps slot 4 → the 5th original row which IS expert 1's 1st row in the pipeline layout).
|
||||
**Root cause:** The original `_assemble_scales_cudagraph_safe` called `pad_and_swizzle_single()` on the ENTIRE padded buffer (all experts concatenated). But the kernel expects each expert's 128-row block to be swizzled independently (matching `assemble_scales_2d_side` which pads+swizzles each expert separately before concatenation).
|
||||
|
||||
Wait — re-reading: in the pipeline, scale_a has 256 rows. Slot 4 → row 4 of scale_a. After un-swizzle, this is the 5th original row. In the pipeline, original rows 0-3 are expert 0's data, rows 4-7 are expert 1's data (padded to 128 each before swizzle). So the un-swizzle of row 4 gives the 1st original row of expert 1. That's correct.
|
||||
**Fix:** Two-phase approach:
|
||||
1. Scatter x_sf rows into 128-aligned positions in a padded buffer (GPU-only, no CPU sync)
|
||||
2. Per-expert: copy 128 rows from padded buffer, `pad_and_swizzle_single()` each expert's block independently, then concatenate
|
||||
|
||||
In the runner, scale_a has 384 rows. Slot 4 → row 4 of scale_a. After un-swizzle, this is the 5th original row. Original rows 0-3 are expert 0's data. Row 4 is a ZERO row (expert 0 had only 4 tokens, rows 4-127 are zero-padded). So the un-swizzle of row 4 gives a zero → expert 1 gets no valid scale data.
|
||||
**Key insight from `torch_scaled_grouped_mm.py` line ~1115:** The kernel computes padded offsets internally when `consistent_token_padding=False`:
|
||||
```python
|
||||
padded_size = round_up(offs[expert_idx] - offs[expert_idx-1], pad_granularity) # 128
|
||||
```
|
||||
So the kernel knows each expert's scale data is in a 128-row block.
|
||||
|
||||
**So the issue IS the 128-row padding.** When expert 0 has 4 tokens and we pad to 128 rows, slot indices 4-127 map to zeros for expert 0. The kernel needs slot indices to be contiguous per expert (0-3 for expert 0, 4-7 for expert 1), but the scale_a has 128 rows per expert block, not 4.
|
||||
### Bug 2: `searchsorted(right=False)` — Wrong Expert Assignment
|
||||
|
||||
### Why does the pipeline work?
|
||||
**Symptom:** Scale data in wrong positions after scatter.
|
||||
|
||||
Because `pad_and_swizzle_single` on a 4-row tensor pads to 128 rows, swizzles the 128-row block, and the kernel's TMA read with slot index 4 reads from the 5th position in the swizzled 128-row block. After un-swizzle, position 4 maps back to... the 5th original row, which is a zero-padded row. Wait, this should also be wrong then.
|
||||
**Root cause:** `torch.searchsorted([4, 8, 8], 4, right=False)` returns 0, assigning row 4 (expert 1's first token) to expert 0.
|
||||
|
||||
Unless the kernel's 2D-side scale access uses `m % 128` or similar per-block indexing. Let me check the kernel's scale_a read pattern.
|
||||
**Fix:** Changed to `right=True`:
|
||||
```python
|
||||
expert_assign = torch.searchsorted(expert_offsets[1:], row_indices, right=True)
|
||||
```
|
||||
|
||||
Actually — re-reading the pipeline more carefully: `assemble_scales_2d_side([x_sf[0:4], x_sf[4:8]])`. This creates a scale_a where:
|
||||
- Expert 0's 4 rows are padded to 128, swizzled → 128 swizzled rows
|
||||
- Expert 1's 4 rows are padded to 128, swizzled → 128 swizzled rows
|
||||
- Concatenated → 256 rows
|
||||
**Verified:** Row 4 → expert 1 (correct), rows 0-3 → expert 0 (correct).
|
||||
|
||||
The kernel gets `expert_offsets = [4, 8, 8]`. For expert 0 (slots 0-3), it reads scale_a positions 0-3. For expert 1 (slots 4-7), it reads scale_a positions 4-7.
|
||||
### Bug 3: CuTeDSL `cute.compile` GPU Memory Corruption — CRITICAL
|
||||
|
||||
In the swizzled layout, positions 0-3 are in the first 128-row block (expert 0). Positions 4-7 are ALSO in the first 128-row block. But expert 1's data is in the second 128-row block (positions 128-255).
|
||||
**Symptom:** `_token_indices` was all zeros, making every token map to token 0.
|
||||
|
||||
So the kernel reads positions 4-7 for expert 1, but those positions contain expert 0's zero-padded/swizzled data, not expert 1's data. This should be wrong...
|
||||
**Root cause:** CuTeDSL's `cute.compile` (JIT compilation) corrupts GPU memory. Tensors allocated on GPU before or during JIT compilation get zeroed. Pre-existing tensors allocated before the JIT survive. This is a bug in the CuTeDSL library.
|
||||
|
||||
BUT THE PIPELINE GIVES 0.988 COSINE. So either:
|
||||
1. The kernel DOES use per-expert offsets into scale_a (reading from position `expert_padded_offset + local_slot`), or
|
||||
2. The swizzle + TMA read remaps the indices in a way I'm not understanding
|
||||
**Impact:** `_token_indices` (int32 on GPU) was zeroed, causing `hidden_states[sorted_token_ids]` to return `hidden_states[0]` for all 8 slots. Every expert saw the same input.
|
||||
|
||||
Need to check the kernel's actual scale_a access pattern in the C++ code.
|
||||
**Fix:** Allocate `_token_indices` on CPU, keep it there. In `run()` and `compute_activation_global_scales()`, index with `sort_idx.cpu()` then move result to GPU:
|
||||
```python
|
||||
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
|
||||
```
|
||||
|
||||
## Fix Options
|
||||
**Warning:** This introduces a CPU-GPU sync (`.cpu()`) which may interfere with cudagraph capture. Needs verification.
|
||||
|
||||
### Option A: Padded expert_offsets
|
||||
Change expert_offsets from slot-based to 128-row-aligned:
|
||||
- Instead of `[0, 4, 8, 8]`, pass `[0, 128, 256, 384]`
|
||||
- The kernel reads scale_a[0:128] for expert 0, scale_a[128:256] for expert 1, etc.
|
||||
- Problem: the kernel would produce 128 output rows per expert instead of the actual token count
|
||||
- This breaks the output shape
|
||||
### Bug 4: `expert_offsets` With Leading 0
|
||||
|
||||
### Option B: Only include experts with tokens in scale_a
|
||||
- Same as pipeline: only pad+swizzle experts that have tokens
|
||||
- Requires knowing which experts have tokens → .tolist() or .item() → breaks cudagraph
|
||||
- Could use a fixed expert set (all experts always included, but with zero rows for empty experts)
|
||||
- The pipeline's assemble_scales_2d_side with 0 rows produces no output for that expert
|
||||
**Symptom:** GEMM produced wrong output with correct scale data.
|
||||
|
||||
### Option C: Understand the kernel's actual indexing and match it
|
||||
- Read the kernel's scale_a access code
|
||||
- Figure out exactly how slot indices map to scale_a positions
|
||||
- Build the layout the kernel expects
|
||||
**Root cause:** The runner passed `expert_offsets[:num_experts + 1]` = `[0, 4, 8, 8]` (4 elements with leading 0) but the kernel expects `compute_expert_offsets([4, 4, 0], 3)` = `[4, 8, 8]` (3 elements, cumulative sum without leading 0).
|
||||
|
||||
### Option D: Skip scale_a assembly entirely, pass raw scales
|
||||
- The kernel might accept raw (un-swizzled) scales via a different path
|
||||
- Or we could use the 3D-side layout for activation scales too
|
||||
**Fix:** Pass `expert_offsets[1:num_experts + 1]` to the GEMM.
|
||||
|
||||
## Approach
|
||||
Going with **Option C** — understand the kernel's indexing, then match it.
|
||||
### Bug 5: Checkpoint `input_scale` Is Wrong for Activation Global Scale
|
||||
|
||||
**Symptom:** Block scales all saturate at float8 max (448), producing garbage quantization.
|
||||
|
||||
**Root cause:** The checkpoint's `input_scale` (~0.000286) is a calibration value computed from a different input magnitude (amax ≈ 0.77) than what runtime produces (amax ≈ 8.17). Too-small gs → x/gs has values up to ~13000 → block_amax/6 ≈ 2174 → overflows float8_e4m3fn max of 448 → saturated block scales → garbage.
|
||||
|
||||
**Fix:** `compute_activation_global_scales()` warmup method that runs `quantize_to_nvfp4` (dynamic gs with `.max()`) before cudagraph capture to get the exact gs values for L1 and L2.
|
||||
|
||||
### Bug 6: L1 and L2 Need Separate Activation Global Scales
|
||||
|
||||
**Symptom:** L2 output was garbage even with correct L1 gs.
|
||||
|
||||
**Root cause:** After SiLU(gate)*up, the activation has amax ~286. The L1 gs (from input amax ~8) is 30x too small for L2, causing even worse block scale saturation.
|
||||
|
||||
**Fix:** `compute_activation_global_scales()` computes L1 gs from the input, runs the L1 GEMM, then computes L2 gs from the actual L1 output (after SiLU*up).
|
||||
|
||||
### Bug 7: L1 and L2 Need Separate Padded Scale Buffers
|
||||
|
||||
**Symptom:** IndexError when quantizing L2 activation — K_sf differs between L1 (448) and L2 (192).
|
||||
|
||||
**Root cause:** `padded_x_sf_buf` was allocated with L1's K_sf (448). When L2's x_sf has K_sf=192, the buffer size mismatch causes issues.
|
||||
|
||||
**Fix:** Separate `_padded_x_sf_buf_l1` and `_padded_x_sf_buf_l2`, plus separate `_per_expert_scale_bufs_l1` and `_per_expert_scale_bufs_l2`.
|
||||
|
||||
---
|
||||
|
||||
## Debug Methodology — How We Got Here
|
||||
|
||||
### Step 1: Identified the CuTeDSL kernel works (layertest = 0.988)
|
||||
|
||||
The layertest uses `moe_pipeline.run_nvfp4_moe` with `quantize_to_nvfp4` (dynamic gs) and `assemble_scales_2d_side` (per-expert split). This is the reference implementation.
|
||||
|
||||
### Step 2: Wrote test_runner_vs_pipeline.py
|
||||
|
||||
Compared `runner.run()` vs `run_nvfp4_moe()` with same weights and inputs. Found runner produces all zeros.
|
||||
|
||||
### Step 3: Wrote test_scale_assembly.py
|
||||
|
||||
Compared `_assemble_scales_cudagraph_safe` vs `assemble_scales_2d_side`. Found data mismatch (global vs per-expert swizzle).
|
||||
|
||||
### Step 4: Fixed scale assembly
|
||||
|
||||
Rewrote `_assemble_scales_cudagraph_safe` with scatter + per-expert swizzle. Scale data now matches reference.
|
||||
|
||||
### Step 5: Found GEMM still produces zeros with correct scales
|
||||
|
||||
Isolated the issue: GEMM with the exact same inputs gives cosine 1.0, but runner gives 0.18. The problem was `expert_offsets` format (leading 0).
|
||||
|
||||
### Step 6: Fixed expert_offsets, found token_indices corruption
|
||||
|
||||
After fixing expert_offsets, cosine improved to 0.35. Traced to `_token_indices` being all zeros (CuTeDSL GPU corruption).
|
||||
|
||||
### Step 7: Found and fixed the GPU corruption
|
||||
|
||||
Moved `_token_indices` to CPU. Cosine jumped to 0.46 with default gs, 0.97 with warmup gs.
|
||||
|
||||
### Step 8: Wrote test_warmup_gs.py
|
||||
|
||||
Verified warmup gs computation, tested safety margins, tested different inputs. Found 1.0x safety (no margin) gives best results.
|
||||
|
||||
---
|
||||
|
||||
## Test Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `tests/layertest.py` | Reference: moe_pipeline with dynamic gs, 3 experts, layer 0. Must pass (≥0.98 cosine). |
|
||||
| `tests/cudagraph_test.py` | CuTeDSLMoERunner cudagraph capture + replay. Must pass. |
|
||||
| `tests/test_runner_vs_pipeline.py` | Compare runner.run() vs moe_pipeline. With correct gs should be ≥0.97. |
|
||||
| `tests/test_scale_assembly.py` | Compare cudagraph-safe vs reference scale assembly. Data must match. |
|
||||
| `tests/test_warmup_gs.py` | Warmup gs computation, safety margin sweep, different input test. |
|
||||
| `tests/test_scale_debug.py` | Byte-level scale debug (can be cleaned up). |
|
||||
|
||||
**Run order after any code change:**
|
||||
1. `python3 tests/layertest.py` — must pass
|
||||
2. `python3 tests/cudagraph_test.py` — must pass
|
||||
3. `python3 tests/test_warmup_gs.py` — should show ≥0.97 cosine
|
||||
|
||||
---
|
||||
|
||||
## Files Modified
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `vllm/nvfp4_cutedsl.py` | All 7 bug fixes, `compute_activation_global_scales()` warmup, CPU token_indices |
|
||||
| `vllm/patches/deepseek_v4.py` | Removed checkpoint `input_scale` → activation global_scale mapping |
|
||||
|
||||
---
|
||||
|
||||
## Next Steps for vLLM Integration
|
||||
|
||||
1. **Add warmup call in `deepseek_v4.py`:** After `finalize_weights()`, call `runner.compute_activation_global_scales()` with a sample input (e.g., 1 token of random data). This must happen before cudagraph capture.
|
||||
|
||||
2. **Verify cudagraph compatibility:** The `sort_idx.cpu()` call in `run()` is a CPU-GPU sync. Cudagraph may not support this. If it doesn't, need to find a way to keep `_token_indices` on GPU while avoiding the CuTeDSL corruption.
|
||||
|
||||
3. **Test the vLLM container:** Spin up the server and test with a simple prompt. The output should be mostly correct (0.97 cosine ≈ near-perfect output).
|
||||
|
||||
4. **Optimize warmup:** The current warmup runs a full forward pass (L1 + L2 GEMM). This is slow (~minutes due to JIT). Consider caching the gs values or computing them more efficiently.
|
||||
|
||||
@@ -37,13 +37,15 @@ class CuTeDSLMoERunner:
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts, hidden_size, intermediate_size,
|
||||
max_num_tokens=8192, top_k=8, device="cuda"):
|
||||
max_num_tokens=8192, top_k=8, device="cuda",
|
||||
experts_start_idx=0):
|
||||
self.num_experts = num_experts
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.top_k = top_k
|
||||
self.device = device
|
||||
self.experts_start_idx = experts_start_idx
|
||||
|
||||
# Weight storage (set before _ensure_stacked)
|
||||
self.l1_fp4 = None
|
||||
@@ -76,6 +78,13 @@ class CuTeDSLMoERunner:
|
||||
self._padded_x_sf_buf_l2 = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
def _fill_token_indices(self):
|
||||
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times)."""
|
||||
src = torch.arange(self.max_num_tokens, dtype=torch.int32, device=self.device)
|
||||
self._token_indices.copy_(
|
||||
src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
|
||||
)
|
||||
|
||||
def _allocate_buffers(self):
|
||||
"""Pre-allocate scale buffers at max size for cudagraph compatibility."""
|
||||
K_sf = cutedsl_ceil_div(self.hidden_size, 16)
|
||||
@@ -127,11 +136,18 @@ class CuTeDSLMoERunner:
|
||||
# Allocate buffers AFTER JIT compilation
|
||||
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
|
||||
# tensors allocated before/during compilation may be zeroed)
|
||||
self._token_indices = torch.arange(
|
||||
self.max_num_tokens, dtype=torch.int32
|
||||
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
|
||||
# Keep on CPU to avoid CuTeDSL JIT GPU memory corruption
|
||||
# Will be indexed with CPU offsets during slot mapping
|
||||
#
|
||||
# _token_indices: GPU tensor for cudagraph compatibility.
|
||||
# CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking
|
||||
# (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm
|
||||
# is triggered on the first run() call; we refill _token_indices after
|
||||
# that first call via the _needs_token_refill flag.
|
||||
self._token_indices = torch.zeros(
|
||||
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._fill_token_indices()
|
||||
self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run
|
||||
|
||||
self._expert_id_range = torch.arange(
|
||||
self.num_experts, dtype=torch.int32
|
||||
).to(self.device)
|
||||
@@ -239,7 +255,7 @@ class CuTeDSLMoERunner:
|
||||
token_indices = self._token_indices[:num_slots]
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
|
||||
sorted_token_ids = token_indices[sort_idx] # GPU, no .cpu()
|
||||
slot_hidden = hidden_states_sample[sorted_token_ids]
|
||||
|
||||
# L1: get exact gs from quantize_to_nvfp4
|
||||
@@ -281,24 +297,15 @@ class CuTeDSLMoERunner:
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
||||
"""Run the NVFP4 MoE forward pass.
|
||||
|
||||
Handles global→local expert ID remapping for expert parallelism.
|
||||
topk_ids contains GLOBAL expert IDs (0..n_routed_experts-1).
|
||||
This runner only handles local experts
|
||||
[experts_start_idx, experts_start_idx + num_experts).
|
||||
|
||||
Non-local tokens get zero weight and are clamped to expert 0
|
||||
(harmless — zero-weighted output contributes nothing to scatter_add).
|
||||
|
||||
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
|
||||
|
||||
expert_offsets are computed from the actual token distribution
|
||||
via GPU-only ops (argsort, broadcast ==, cumsum). These offsets
|
||||
are passed to the GEMM as a GPU tensor, never converted to Python.
|
||||
|
||||
The GEMM and quantize functions see the full slot buffer.
|
||||
Padding rows are zeros that produce zero output, contributing
|
||||
nothing to the final scatter_add.
|
||||
|
||||
Args:
|
||||
hidden_states: (num_tokens, hidden_size) bf16
|
||||
topk_weights: (num_tokens, top_k) float32
|
||||
topk_ids: (num_tokens, top_k) int
|
||||
expert_indices: ignored (uses all experts)
|
||||
|
||||
Returns:
|
||||
(num_tokens, hidden_size) bf16 - MoE output
|
||||
"""
|
||||
num_tokens = hidden_states.shape[0]
|
||||
top_k = topk_ids.shape[1]
|
||||
@@ -306,16 +313,24 @@ class CuTeDSLMoERunner:
|
||||
|
||||
self._ensure_stacked()
|
||||
|
||||
# -- Remap global expert IDs to local IDs --
|
||||
# topk_ids are global: remap by subtracting experts_start_idx.
|
||||
# Tokens for non-local experts get clamped to 0 with zero weight.
|
||||
local_ids = topk_ids - self.experts_start_idx
|
||||
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
|
||||
safe_ids = local_ids.clamp(0, self.num_experts - 1)
|
||||
safe_weights = topk_weights * local_mask.float()
|
||||
|
||||
# -- Build slot mapping --
|
||||
flat_ids = topk_ids.reshape(-1)
|
||||
flat_weights = topk_weights.reshape(-1)
|
||||
flat_ids = safe_ids.reshape(-1)
|
||||
flat_weights = safe_weights.reshape(-1)
|
||||
num_slots = num_tokens * top_k
|
||||
token_indices = self._token_indices[:num_slots]
|
||||
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_weights = flat_weights[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
|
||||
sorted_token_ids = token_indices[sort_idx] # GPU tensor, no .cpu()
|
||||
|
||||
# Expert offsets (GPU-only, never touches CPU)
|
||||
expert_id_range = self._expert_id_range
|
||||
@@ -383,4 +398,10 @@ class CuTeDSLMoERunner:
|
||||
weighted_out,
|
||||
)
|
||||
|
||||
# Refill _token_indices after GEMM JIT on first call
|
||||
# (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM)
|
||||
if self._needs_token_refill:
|
||||
self._fill_token_indices()
|
||||
self._needs_token_refill = False
|
||||
|
||||
return y
|
||||
|
||||
@@ -504,6 +504,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
device=l1_fp4[0].device,
|
||||
experts_start_idx=self.experts_start_idx,
|
||||
)
|
||||
self._cutedsl_runner.l1_fp4 = l1_fp4
|
||||
self._cutedsl_runner.l1_sf = l1_sf
|
||||
|
||||
Reference in New Issue
Block a user