From 2687d1fc53f6995ce537cb55152d9b943d5f08e4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 17:43:58 +0000 Subject: [PATCH] fix: convert global expert IDs to local before GEMM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit vLLM's symm_buffer stores topk_ids as GLOBAL expert IDs (0..383). Our weight tensors are indexed by LOCAL IDs (0..47 per rank). Each rank r handles experts [r*48, r*48+47]. Without conversion, topk_ids like 137, 222, 378 would index way out of bounds in the weight tensor (shape (48, N, K)), producing garbage. Derive experts_start_idx from the topk_ids and subtract to get local IDs. This was why all ranks except rank 0 produced zero expert matches → zero output → garbage text. --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 81e0ea75..1d5c04ff 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -279,15 +279,27 @@ def nvfp4_mega_moe_full( topk_ids = symm_buffer.topk_idx[:num_tokens] topk_weights = symm_buffer.topk_weights[:num_tokens] + # Convert global expert IDs to local expert IDs. + # vLLM's symm_buffer stores global IDs (0..383) but our weight tensors + # are indexed by local ID (0..47). Each rank handles a contiguous chunk: + # rank r gets experts [r*E_per_rank, (r+1)*E_per_rank). + # We derive the start index from the first global ID that maps to local 0. + num_experts_per_rank = l1_w.shape[0] + # Find experts_start_idx: the minimum global ID that this rank handles. + # All topk_ids in the buffer should fall within this rank's range. + experts_start_idx = (topk_ids.min().item() // num_experts_per_rank) * num_experts_per_rank + topk_ids_local = topk_ids - experts_start_idx + if MEGA_MOE_DEBUG: print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} " - f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}") + f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} " + f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} " + f"l1_w={l1_w.shape} l2_w={l2_w.shape}") # Step 2: L1 GEMM (native NVFP4 block-scaled MMA) - num_experts_per_rank = l1_w.shape[0] l1_output = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf, - topk_ids, topk_weights, num_experts_per_rank, + topk_ids_local, topk_weights, num_experts_per_rank, ) # Step 3: SiLU + Mul @@ -302,7 +314,7 @@ def nvfp4_mega_moe_full( # Step 5: L2 GEMM (native NVFP4 block-scaled MMA) l2_output = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, - topk_ids, topk_weights, num_experts_per_rank, + topk_ids_local, topk_weights, num_experts_per_rank, ) # Step 6: Write to output (caller handles cross-rank all-reduce)