fix: convert global expert IDs to local before GEMM

vLLM's symm_buffer stores topk_ids as GLOBAL expert IDs (0..383).
Our weight tensors are indexed by LOCAL IDs (0..47 per rank).
Each rank r handles experts [r*48, r*48+47]. Without conversion,
topk_ids like 137, 222, 378 would index way out of bounds in the
weight tensor (shape (48, N, K)), producing garbage.

Derive experts_start_idx from the topk_ids and subtract to get
local IDs. This was why all ranks except rank 0 produced zero
expert matches → zero output → garbage text.
This commit is contained in:
2026-05-14 17:43:58 +00:00
parent 128ff84358
commit 2687d1fc53

View File

@@ -279,15 +279,27 @@ def nvfp4_mega_moe_full(
topk_ids = symm_buffer.topk_idx[:num_tokens]
topk_weights = symm_buffer.topk_weights[:num_tokens]
# Convert global expert IDs to local expert IDs.
# vLLM's symm_buffer stores global IDs (0..383) but our weight tensors
# are indexed by local ID (0..47). Each rank handles a contiguous chunk:
# rank r gets experts [r*E_per_rank, (r+1)*E_per_rank).
# We derive the start index from the first global ID that maps to local 0.
num_experts_per_rank = l1_w.shape[0]
# Find experts_start_idx: the minimum global ID that this rank handles.
# All topk_ids in the buffer should fall within this rank's range.
experts_start_idx = (topk_ids.min().item() // num_experts_per_rank) * num_experts_per_rank
topk_ids_local = topk_ids - experts_start_idx
if MEGA_MOE_DEBUG:
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}")
f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
# Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
num_experts_per_rank = l1_w.shape[0]
l1_output = nvfp4_mega_moe_l1(
x_fp4, x_sf, l1_w, l1_sf,
topk_ids, topk_weights, num_experts_per_rank,
topk_ids_local, topk_weights, num_experts_per_rank,
)
# Step 3: SiLU + Mul
@@ -302,7 +314,7 @@ def nvfp4_mega_moe_full(
# Step 5: L2 GEMM (native NVFP4 block-scaled MMA)
l2_output = nvfp4_mega_moe_l2(
l1_fp4, l1_sf_out, l2_w, l2_sf,
topk_ids, topk_weights, num_experts_per_rank,
topk_ids_local, topk_weights, num_experts_per_rank,
)
# Step 6: Write to output (caller handles cross-rank all-reduce)