[Bugfix] Fix expert_ids padding values in moe_align_block_size kernel (#35161)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
batched_moe_align_block_size,
|
||||
moe_align_block_size,
|
||||
)
|
||||
-from vllm.utils.math_utils import round_up
|
||||
+from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
NUM_TOKENS = [1, 3, 256, 2256, 4096]
|
||||
@@ -142,7 +142,9 @@ def torch_moe_align_block_size(
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
|
||||
-expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
|
||||
+expert_ids = torch.full(
|
||||
+(max_num_blocks,), -1, dtype=torch.int32, device=topk_ids.device
|
||||
+)
|
||||
|
||||
current_pos = 0
|
||||
current_block = 0
|
||||
@@ -234,9 +236,10 @@ def test_moe_align_block_size(
|
||||
assert len(valid_tokens) == total_tokens, (
|
||||
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
|
||||
)
|
||||
-assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
|
||||
-"expert_ids should contain valid expert indices"
|
||||
-)
|
||||
+actual_num_blocks = cdiv(int(actual_num_tokens.item()), block_size)
|
||||
+assert (actual_expert_ids[:actual_num_blocks] >= 0).all() and (
|
||||
+actual_expert_ids[:actual_num_blocks] < num_experts
|
||||
+).all(), "expert_ids should contain valid expert indices"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [16, 32, 2048])
|
||||
|
||||
Reference in New Issue
Block a user