diff --git a/tests/lora/test_moe_lora_align_sum.py b/tests/lora/test_moe_lora_align_sum.py
index 3a17f3eba..bb46b4d86 100644
--- a/tests/lora/test_moe_lora_align_sum.py
+++ b/tests/lora/test_moe_lora_align_sum.py
@@ -47,6 +47,8 @@ def test_moe_lora_align_block_size(
     # compute paddings
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+    if topk_ids.numel() < num_experts:
+        max_num_tokens_padded = topk_ids.numel() * block_size
     max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
 
     # init output tensors
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index b75d297ba..5f2604892 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -351,6 +351,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
         if pad_sorted_ids:
             max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+        if topk_ids.numel() < num_experts:
+            max_num_tokens_padded = topk_ids.numel() * block_size
         sorted_ids = torch.empty(
             (max_loras * max_num_tokens_padded,),
             dtype=torch.int32,
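
Below is a minimal standalone sketch (not part of the patch) of why the added guard tightens the allocation bound. The helper name padded_token_bound and the example values are made up for illustration; topk_numel, num_experts, and block_size mirror the variables the diff touches, and the round_up step applied in the patched files is omitted here for brevity.

def padded_token_bound(topk_numel: int, num_experts: int, block_size: int) -> int:
    # Default worst case: every expert bucket may be non-empty and need up to
    # (block_size - 1) padding slots.
    bound = topk_numel + num_experts * (block_size - 1)
    # With fewer assignments than experts, at most topk_numel buckets are
    # non-empty, so each assignment occupies at most one full block.
    if topk_numel < num_experts:
        bound = topk_numel * block_size
    return bound

# Example: 4 top-k assignments, 64 experts, block size 16.
# Default bound: 4 + 64 * 15 = 964; tightened bound: 4 * 16 = 64.
print(padded_token_bound(4, 64, 16))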