[Bugfix] Fix expert_ids padding values in moe_align_block_size kernel (#35161)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -172,7 +172,7 @@ __device__ void _moe_align_block_size(
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
// Fill remaining expert_ids with -1
|
||||
const size_t fill_start_idx =
|
||||
cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
|
||||
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
|
||||
@@ -265,7 +265,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
// Fill remaining expert_ids with -1
|
||||
const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
|
||||
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {
|
||||
expert_ids[expert_ids_offset + i] = inactive_expert_id;
|
||||
@@ -332,7 +332,7 @@ __global__ void moe_align_block_size_kernel(
|
||||
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
|
||||
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
|
||||
cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
|
||||
0, 0, topk_num, nullptr, has_expert_map);
|
||||
0, -1, topk_num, nullptr, has_expert_map);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@@ -373,7 +373,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
|
||||
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
|
||||
num_experts, block_size, numel, max_num_tokens_padded,
|
||||
CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr,
|
||||
CEILDIV(max_num_tokens_padded, block_size), -1, 0, topk_num, nullptr,
|
||||
has_expert_map);
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
batched_moe_align_block_size,
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
NUM_TOKENS = [1, 3, 256, 2256, 4096]
|
||||
@@ -142,7 +142,9 @@ def torch_moe_align_block_size(
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
|
||||
expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
|
||||
expert_ids = torch.full(
|
||||
(max_num_blocks,), -1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
current_pos = 0
|
||||
current_block = 0
|
||||
@@ -234,9 +236,10 @@ def test_moe_align_block_size(
|
||||
assert len(valid_tokens) == total_tokens, (
|
||||
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
|
||||
)
|
||||
assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
|
||||
"expert_ids should contain valid expert indices"
|
||||
)
|
||||
actual_num_blocks = cdiv(int(actual_num_tokens.item()), block_size)
|
||||
assert (actual_expert_ids[:actual_num_blocks] >= 0).all() and (
|
||||
actual_expert_ids[:actual_num_blocks] < num_experts
|
||||
).all(), "expert_ids should contain valid expert indices"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [16, 32, 2048])
|
||||
|
||||
Reference in New Issue
Block a user