fix: convert global expert IDs to local before GEMM
vLLM's symm_buffer stores topk_ids as GLOBAL expert IDs (0..383). Our weight tensors are indexed by LOCAL IDs (0..47 per rank). Each rank r handles experts [r*48, r*48+47]. Without conversion, topk_ids like 137, 222, 378 would index way out of bounds in the weight tensor (shape (48, N, K)), producing garbage. Derive experts_start_idx from the topk_ids and subtract to get local IDs. This was why all ranks except rank 0 produced zero expert matches → zero output → garbage text.
This commit is contained in:
@@ -279,15 +279,27 @@ def nvfp4_mega_moe_full(
|
||||
topk_ids = symm_buffer.topk_idx[:num_tokens]
|
||||
topk_weights = symm_buffer.topk_weights[:num_tokens]
|
||||
|
||||
# Convert global expert IDs to local expert IDs.
|
||||
# vLLM's symm_buffer stores global IDs (0..383) but our weight tensors
|
||||
# are indexed by local ID (0..47). Each rank handles a contiguous chunk:
|
||||
# rank r gets experts [r*E_per_rank, (r+1)*E_per_rank).
|
||||
# We derive the start index from the first global ID that maps to local 0.
|
||||
num_experts_per_rank = l1_w.shape[0]
|
||||
# Find experts_start_idx: the minimum global ID that this rank handles.
|
||||
# All topk_ids in the buffer should fall within this rank's range.
|
||||
experts_start_idx = (topk_ids.min().item() // num_experts_per_rank) * num_experts_per_rank
|
||||
topk_ids_local = topk_ids - experts_start_idx
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
|
||||
f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}")
|
||||
f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
|
||||
f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
|
||||
f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
|
||||
|
||||
# Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
|
||||
num_experts_per_rank = l1_w.shape[0]
|
||||
l1_output = nvfp4_mega_moe_l1(
|
||||
x_fp4, x_sf, l1_w, l1_sf,
|
||||
topk_ids, topk_weights, num_experts_per_rank,
|
||||
topk_ids_local, topk_weights, num_experts_per_rank,
|
||||
)
|
||||
|
||||
# Step 3: SiLU + Mul
|
||||
@@ -302,7 +314,7 @@ def nvfp4_mega_moe_full(
|
||||
# Step 5: L2 GEMM (native NVFP4 block-scaled MMA)
|
||||
l2_output = nvfp4_mega_moe_l2(
|
||||
l1_fp4, l1_sf_out, l2_w, l2_sf,
|
||||
topk_ids, topk_weights, num_experts_per_rank,
|
||||
topk_ids_local, topk_weights, num_experts_per_rank,
|
||||
)
|
||||
|
||||
# Step 6: Write to output (caller handles cross-rank all-reduce)
|
||||
|
||||
Reference in New Issue
Block a user