diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 81e0ea75..1d5c04ff 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -279,15 +279,27 @@ def nvfp4_mega_moe_full( topk_ids = symm_buffer.topk_idx[:num_tokens] topk_weights = symm_buffer.topk_weights[:num_tokens] + # Convert global expert IDs to local expert IDs. + # vLLM's symm_buffer stores global IDs (0..383) but our weight tensors + # are indexed by local ID (0..47). Each rank handles a contiguous chunk: + # rank r gets experts [r*E_per_rank, (r+1)*E_per_rank). + # We derive the start index from the first global ID that maps to local 0. + num_experts_per_rank = l1_w.shape[0] + # Find experts_start_idx: the minimum global ID that this rank handles. + # All topk_ids in the buffer should fall within this rank's range. + experts_start_idx = (topk_ids.min().item() // num_experts_per_rank) * num_experts_per_rank + topk_ids_local = topk_ids - experts_start_idx + if MEGA_MOE_DEBUG: print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} " - f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}") + f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} " + f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} " + f"l1_w={l1_w.shape} l2_w={l2_w.shape}") # Step 2: L1 GEMM (native NVFP4 block-scaled MMA) - num_experts_per_rank = l1_w.shape[0] l1_output = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf, - topk_ids, topk_weights, num_experts_per_rank, + topk_ids_local, topk_weights, num_experts_per_rank, ) # Step 3: SiLU + Mul @@ -302,7 +314,7 @@ def nvfp4_mega_moe_full( # Step 5: L2 GEMM (native NVFP4 block-scaled MMA) l2_output = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, - topk_ids, topk_weights, num_experts_per_rank, + topk_ids_local, topk_weights, num_experts_per_rank, ) # Step 6: Write to output (caller handles cross-rank all-reduce)