From 8e4ca4d14e34305c0ab5640a9143871b25c91811 Mon Sep 17 00:00:00 2001
From: gnovack
Date: Wed, 22 Oct 2025 05:23:57 -0700
Subject: [PATCH] Bugfix - pass 'max_num_tokens_padded' into
 'moe_lora_align_block_size' (#27311)

Signed-off-by: gnovack
Co-authored-by: Jee Jee Li
---
 csrc/moe/moe_lora_align_sum_kernels.cu   | 8 ++------
 csrc/moe/moe_ops.h                       | 3 ++-
 csrc/moe/torch_bindings.cpp              | 2 ++
 tests/lora/test_fused_moe_lora_kernel.py | 2 ++
 tests/lora/test_gptoss.py                | 2 +-
 tests/lora/test_moe_lora_align_sum.py    | 2 ++
 vllm/_custom_ops.py                      | 4 ++++
 vllm/lora/punica_wrapper/punica_gpu.py   | 2 ++
 8 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu
index 1d25844bd..e76d1c366 100644
--- a/csrc/moe/moe_lora_align_sum_kernels.cu
+++ b/csrc/moe/moe_lora_align_sum_kernels.cu
@@ -124,18 +124,14 @@ __global__ void moe_lora_align_sum_kernel(
 void moe_lora_align_block_size(torch::Tensor topk_ids,
                                torch::Tensor token_lora_mapping,
                                int64_t num_experts, int64_t block_size,
-                               int64_t max_loras,
+                               int64_t max_loras, int64_t max_num_tokens_padded,
+                               int64_t max_num_m_blocks,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
                                torch::Tensor num_tokens_post_pad) {
   const int topk_num = topk_ids.size(1);

-  int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
-  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
-  max_num_tokens_padded = round_to_next_multiple_of(
-      max_num_tokens_padded, static_cast<int>(block_size));
-  int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);


   int device_max_shared_mem;
   auto dev = topk_ids.get_device();
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index 45dd1824d..e4bf0aa99 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -23,7 +23,8 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
 void moe_lora_align_block_size(torch::Tensor topk_ids,
                                torch::Tensor token_lora_mapping,
                                int64_t num_experts, int64_t block_size,
-                               int64_t max_loras,
+                               int64_t max_loras, int64_t max_num_tokens_padded,
+                               int64_t max_num_m_blocks,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
                                torch::Tensor num_tokens_post_pad);
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index f110683af..c08a54390 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -40,6 +40,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " Tensor token_lora_mapping,"
       " int num_experts,"
       " int block_size, int max_loras, "
+      " int max_num_tokens_padded, "
+      " int max_num_m_blocks, "
      " Tensor !sorted_token_ids,"
      " Tensor !experts_ids,"
      " Tensor !num_tokens_post_pad) -> () ");
diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index 052e52c7b..f9a66d4d0 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -142,6 +142,8 @@ def use_fused_moe_lora_kernel(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
diff --git a/tests/lora/test_gptoss.py b/tests/lora/test_gptoss.py
index cdd0304af..f5c9a5cf2 100644
--- a/tests/lora/test_gptoss.py
+++ b/tests/lora/test_gptoss.py
@@ -36,7 +36,7 @@ def test_gptoss20b_lora(gptoss20b_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         enable_lora=True,
-        max_loras=1,
+        max_loras=4,
         trust_remote_code=True,
     )

diff --git a/tests/lora/test_moe_lora_align_sum.py b/tests/lora/test_moe_lora_align_sum.py
index e65dd40bd..6cd1281c3 100644
--- a/tests/lora/test_moe_lora_align_sum.py
+++ b/tests/lora/test_moe_lora_align_sum.py
@@ -68,6 +68,8 @@ def test_moe_lora_align_block_size(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 7efd6aa44..eccb9a1ef 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1801,6 +1801,8 @@ def moe_lora_align_block_size(
     num_experts: int,
     block_size: int,
     max_loras: int,
+    max_num_tokens_padded: int,
+    max_num_m_blocks: int,
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
@@ -1811,6 +1813,8 @@
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index daf89cd97..c2c26a01e 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -341,6 +341,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             num_experts,
             block_size,
             max_loras,
+            max_num_tokens_padded,
+            max_num_m_blocks,
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,