diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index 5c9e47402..e3539ff40 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -172,7 +172,7 @@ __device__ void _moe_align_block_size(
     }
   }
 
-  // Fill remaining expert_ids with 0
+  // Fill remaining expert_ids with -1
   const size_t fill_start_idx =
       cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
   for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
@@ -265,7 +265,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
     }
   }
 
-  // Fill remaining expert_ids with 0
+  // Fill remaining expert_ids with -1
   const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
   for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {
     expert_ids[expert_ids_offset + i] = inactive_expert_id;
@@ -332,7 +332,7 @@ __global__ void moe_align_block_size_kernel(
       topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
       expert_map, num_experts, padded_num_experts, experts_per_warp,
       block_size, numel, cumsum, max_num_tokens_padded,
       CEILDIV(max_num_tokens_padded, block_size),
-      0, 0, topk_num, nullptr, has_expert_map);
+      0, -1, topk_num, nullptr, has_expert_map);
 }
 template <typename scalar_t>
@@ -373,7 +373,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   _moe_align_block_size_small_batch_expert(
       topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
       expert_map, num_experts, block_size, numel, max_num_tokens_padded,
-      CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr,
+      CEILDIV(max_num_tokens_padded, block_size), -1, 0, topk_num, nullptr,
       has_expert_map);
 }
 
diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py
index 4165df37c..9096d0ab8 100644
--- a/tests/kernels/moe/test_moe_align_block_size.py
+++ b/tests/kernels/moe/test_moe_align_block_size.py
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     batched_moe_align_block_size,
     moe_align_block_size,
 )
-from vllm.utils.math_utils import round_up
+from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import set_random_seed
 
 NUM_TOKENS = [1, 3, 256, 2256, 4096]
@@ -142,7 +142,9 @@ def torch_moe_align_block_size(
         device=topk_ids.device,
     )
     max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
-    expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
+    expert_ids = torch.full(
+        (max_num_blocks,), -1, dtype=torch.int32, device=topk_ids.device
+    )
 
     current_pos = 0
     current_block = 0
@@ -234,9 +236,10 @@ def test_moe_align_block_size(
     assert len(valid_tokens) == total_tokens, (
         f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
     )
-    assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
-        "expert_ids should contain valid expert indices"
-    )
+    actual_num_blocks = cdiv(int(actual_num_tokens.item()), block_size)
+    assert (actual_expert_ids[:actual_num_blocks] >= 0).all() and (
+        actual_expert_ids[:actual_num_blocks] < num_experts
+    ).all(), "expert_ids should contain valid expert indices"
 
 
 @pytest.mark.parametrize("m", [16, 32, 2048])