diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 8c90dd725..d9a1d3303 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -44,10 +44,8 @@ def benchmark_permute( hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) # output_hidden_states = torch.empty_like(hidden_states) if use_fp8_w8a8: - align_block_size = 128 # deepgemm needs 128 m aligned block qhidden_states, scale = _fp8_quantize(hidden_states, None, None) else: - align_block_size = None qhidden_states = hidden_states gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) @@ -67,7 +65,6 @@ def benchmark_permute( topk_ids=topk_ids, n_expert=num_experts, expert_map=None, - align_block_size=align_block_size, ) # JIT compilation & warmup @@ -117,10 +114,8 @@ def benchmark_unpermute( # init_dtype = torch.float16 if use_fp8_w8a8 else dtype hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_fp8_w8a8: - align_block_size = 128 # deepgemm needs 128 m aligned block qhidden_states, scale = _fp8_quantize(hidden_states, None, None) else: - align_block_size = None qhidden_states = hidden_states input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) @@ -142,7 +137,6 @@ def benchmark_unpermute( topk_ids=topk_ids, n_expert=num_experts, expert_map=None, - align_block_size=align_block_size, ) # convert to fp16/bf16 as gemm output return ( diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 3de64eda6..eec8f9854 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -14,12 +14,10 @@ void moe_permute( const torch::Tensor& token_expert_indices, // [n_token, topk] const std::optional& expert_map, // [n_expert] int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, torch::Tensor& permuted_input, // [permuted_size, hidden] torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] torch::Tensor& inv_permuted_idx, // [n_token, topk] - torch::Tensor& permuted_idx, // [permute_size] - torch::Tensor& m_indices) { // [align_expand_m] + torch::Tensor& permuted_idx) { // [permute_size] TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, "expert_first_token_offset must be int64"); TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, @@ -34,8 +32,6 @@ void moe_permute( "token_expert_indices shape must be same as inv_permuted_idx"); auto n_token = input.sizes()[0]; auto n_hidden = input.sizes()[1]; - auto align_block_size_value = - align_block_size.has_value() ? align_block_size.value() : -1; auto stream = at::cuda::getCurrentCUDAStream().stream(); const long sorter_size = CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert); @@ -73,42 +69,15 @@ void moe_permute( get_ptr(expert_first_token_offset), n_token, n_expert, n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); - // DeepGEMM: use getMIndices kernel to compute - // 1) align_expert_first_token_offset (aligned prefix offsets) - // 2) m_indices (expert id for each aligned row) - // eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively - // expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4 - // expert0: 3->4, expert1: 5->8, expert2: 2->4 - // align_expert_first_token_offset = [0, 4, 12, 16] - // so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2] - torch::Tensor align_expert_first_token_offset; - const int64_t* aligned_expert_first_token_offset_ptr = nullptr; - if (align_block_size.has_value()) { - align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); - aligned_expert_first_token_offset_ptr = - get_ptr(align_expert_first_token_offset); - } - // dispatch expandInputRowsKernelLauncher MOE_DISPATCH(input.scalar_type(), [&] { expandInputRowsKernelLauncher( get_ptr(input), get_ptr(permuted_input), get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), get_ptr(inv_permuted_idx), get_ptr(permuted_idx), - get_ptr(expert_first_token_offset), - aligned_expert_first_token_offset_ptr, n_token, valid_num_ptr, n_hidden, - topk, n_local_expert, align_block_size_value, stream); + get_ptr(expert_first_token_offset), n_token, valid_num_ptr, + n_hidden, topk, n_local_expert, stream); }); - - // this is only required for DeepGemm and not required for CUTLASS group gemm - if (align_block_size.has_value()) { - expert_first_token_offset.copy_(align_expert_first_token_offset); - } } void moe_unpermute( @@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor, #else -void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, - torch::Tensor& topk_ids, +void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids, const torch::Tensor& token_expert_indices, const std::optional& expert_map, int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, torch::Tensor& permuted_input, torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { + torch::Tensor& inv_permuted_idx, torch::Tensor& permuted_idx) { TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); } diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index 9499b297f..2cc200321 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size, topk_id_ptr, size, expert_map_ptr, num_experts); } -template -__global__ void getMIndicesKernel(int64_t* expert_first_token_offset, - int64_t* align_expert_first_token_offset, - int* m_indices, const int num_local_expert, - const int align_block_size) { - int eidx = blockIdx.x; - int tidx = threadIdx.x; - extern __shared__ int64_t smem_expert_first_token_offset[]; - for (int i = tidx; i <= num_local_expert; i += blockDim.x) { - smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i); - } - __syncthreads(); - auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; - auto first_token_offset = smem_expert_first_token_offset[eidx]; - int n_token_in_expert = last_token_offset - first_token_offset; - - if constexpr (ALIGN_BLOCK_SIZE) { - n_token_in_expert = (n_token_in_expert + align_block_size - 1) / - align_block_size * align_block_size; - // round up to ALIGN_BLOCK_SIZE - int64_t accumulate_align_offset = 0; - for (int i = 1; i <= eidx + 1; i++) { - int n_token = smem_expert_first_token_offset[i] - - smem_expert_first_token_offset[i - 1]; - accumulate_align_offset = - accumulate_align_offset + (n_token + align_block_size - 1) / - align_block_size * align_block_size; - if (i == eidx) { - first_token_offset = accumulate_align_offset; - } - // last block store align_expert_first_token_offset - if (eidx == num_local_expert - 1 && threadIdx.x == 0) { - align_expert_first_token_offset[i] = accumulate_align_offset; - } - } - } - for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) { - // update m_indice with expert id - m_indices[first_token_offset + idx] = eidx; - } -} - -void getMIndices(int64_t* expert_first_token_offset, - int64_t* align_expert_first_token_offset, int* m_indices, - int num_local_expert, const int align_block_size, - cudaStream_t stream) { - int block = 256; - int grid = num_local_expert; - int smem_size = sizeof(int64_t) * (num_local_expert + 1); - if (align_block_size == -1) { - getMIndicesKernel<<>>( - expert_first_token_offset, align_expert_first_token_offset, m_indices, - num_local_expert, align_block_size); - } else { - getMIndicesKernel<<>>( - expert_first_token_offset, align_expert_first_token_offset, m_indices, - num_local_expert, align_block_size); - } -} - #endif diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 09491ab98..840b47546 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher( T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, - int64_t const* expert_first_token_offset, - int64_t const* aligned_expert_first_token_offset, int64_t const num_rows, + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, - int num_local_experts, const int& align_block_size, cudaStream_t stream); + int num_local_experts, cudaStream_t stream); template void finalizeMoeRoutingKernelLauncher( @@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size, const int* expert_map_ptr, int num_experts, cudaStream_t stream); -void getMIndices(int64_t* expert_first_token_offset, - int64_t* align_expert_first_token_offset, int* m_indices, - int num_local_expert, const int align_block_size, - cudaStream_t stream); - #include "moe_permute_unpermute_kernel.inl" diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index 68f3cc9fa..bcb2f9ca5 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -1,14 +1,13 @@ #pragma once -template +template __global__ void expandInputRowsKernel( T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, - int64_t const* expert_first_token_offset, - int64_t const* aligned_expert_first_token_offset, int64_t const num_rows, + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const* num_dest_rows, int64_t const cols, int64_t k, - int num_local_experts, int align_block_size) { + int num_local_experts) { // Reverse permutation map. // I do this so that later, we can use the source -> dest map to do the k-way // reduction and unpermuting. I need the reverse map for that reduction to @@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel( expanded_dest_row_to_expanded_source_row[expanded_dest_row]; int expert_id = sorted_experts[expanded_dest_row]; - if constexpr (ALIGN_BLOCK_SIZE) { - // convert (unaligned) expanded_dest_row -> aligned expanded_dest_row. - // aligned_expert_first_token_offset[e] provides the aligned prefix start - // for expert e. For non-local experts we map to the end (total aligned M). - int64_t aligned_base = 0; - int64_t token_offset_in_expert = 0; - if (expert_id >= num_local_experts) { - aligned_base = - __ldg(aligned_expert_first_token_offset + num_local_experts); - token_offset_in_expert = 0; - } else { - aligned_base = __ldg(aligned_expert_first_token_offset + expert_id); - token_offset_in_expert = - expanded_dest_row - __ldg(expert_first_token_offset + expert_id); - } - expanded_dest_row = aligned_base + token_offset_in_expert; - } - if (threadIdx.x == 0) { assert(expanded_dest_row <= INT32_MAX); expanded_source_row_to_expanded_dest_row[expanded_source_row] = @@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher( T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, - int64_t const* expert_first_token_offset, - int64_t const* aligned_expert_first_token_offset, int64_t const num_rows, + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, - int num_local_experts, const int& align_block_size, cudaStream_t stream) { + int num_local_experts, cudaStream_t stream) { int64_t const blocks = num_rows * k; int64_t const threads = 256; - using FuncPtr = decltype(&expandInputRowsKernel); - FuncPtr func_map[2][2] = { - {&expandInputRowsKernel, - &expandInputRowsKernel}, - {&expandInputRowsKernel, - &expandInputRowsKernel}, + using FuncPtr = decltype(&expandInputRowsKernel); + FuncPtr func_map[2] = { + &expandInputRowsKernel, + &expandInputRowsKernel, }; bool is_check_skip = num_valid_tokens_ptr != nullptr; - bool is_align_block_size = align_block_size != -1; - auto func = func_map[is_check_skip][is_align_block_size]; + auto func = func_map[is_check_skip]; func<<>>( unpermuted_input, permuted_output, sorted_experts, expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, permuted_idx, - expert_first_token_offset, aligned_expert_first_token_offset, num_rows, - num_valid_tokens_ptr, cols, k, num_local_experts, align_block_size); + expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k, + num_local_experts); } template diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index f8cfe058f..fd9b8945e 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "moe_permute(Tensor input, Tensor topk_ids," "Tensor token_expert_indices, Tensor? expert_map, int n_expert," "int n_local_expert," - "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " + "int topk, Tensor! permuted_input, Tensor! " "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! " - "permuted_idx, Tensor! m_indices)->()"); + "permuted_idx)->()"); m.def( "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 0b3c435aa..92126171a 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -40,10 +40,8 @@ def torch_permute( n_local_expert: int, start_expert: int, expert_map: torch.Tensor | None = None, - align_block_size: int | None = None, - fill_invalid_expert: int = -1, ) -> list[torch.Tensor]: - n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] + n_token = hidden_states.shape[0] if expert_map is not None: is_local_expert = expert_map[topk_ids] != -1 not_local_expert = expert_map[topk_ids] == -1 @@ -70,107 +68,19 @@ def torch_permute( _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) valid_row_idx = [] - if align_block_size is None: - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...] - permuted_row_size = permuted_hidden_states.shape[0] - m_indices = torch.empty( - permuted_row_size, device="cuda", dtype=torch.int32 - ).fill_(fill_invalid_expert) - for i in range(1, n_local_expert + 1): - first_token_offset = expert_first_token_offset[i - 1] - last_token_offset = expert_first_token_offset[i] - m_indices[first_token_offset:last_token_offset] = i - 1 - src_row_id2dst_row_id_map = torch.arange( - 0, n_token * topk, device="cuda", dtype=torch.int32 - )[src2dst_idx].reshape((n_token, topk)) - valid_row_idx += [i for i in range(expert_first_token_offset[-1])] - dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk - return [ - permuted_hidden_states, - expert_first_token_offset, - src_row_id2dst_row_id_map, - dst_row_id2src_row_id_map, - m_indices, - valid_row_idx, - ] - else: - permuted_row_size = ( - (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1) - // align_block_size - * align_block_size - ) - permuted_idx = torch.full( - (permuted_row_size,), - n_token * topk, - dtype=torch.int32, - device=hidden_states.device, - ) - permuted_hidden_states = torch.empty( - (permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype - ) - align_src_row_id2dst_row_id = torch.empty( - n_token * topk, device="cuda", dtype=torch.int32 - ) - align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset) - m_indices = torch.empty( - permuted_row_size, device="cuda", dtype=torch.int32 - ).fill_(fill_invalid_expert) - # get align_permuted_hidden_states, - # valid row_idx and align_expert_first_token_offset - for i in range(1, n_local_expert + 1): - first_token_offset = expert_first_token_offset[i - 1] - last_token_offset = expert_first_token_offset[i] - n_token_in_expert = last_token_offset - first_token_offset - align_expert_first_token_offset[i] = ( - align_expert_first_token_offset[i - 1] - + (n_token_in_expert + align_block_size - 1) - // align_block_size - * align_block_size - ) - align_first_token_offset = align_expert_first_token_offset[i - 1] - align_last_token_offset = align_expert_first_token_offset[i] - dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset : first_token_offset + n_token_in_expert - ] - # store token in current expert with align_first_token_offset - permuted_hidden_states[ - align_first_token_offset : align_first_token_offset + n_token_in_expert, - ..., - ] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...] - permuted_idx[ - align_first_token_offset : align_first_token_offset + n_token_in_expert - ] = dst_row_id2src_row_id_in_expert - # set current expert m_indices - m_indices[align_first_token_offset:align_last_token_offset] = i - 1 - valid_row_idx += [ - i - for i in range( - align_first_token_offset, - align_first_token_offset + n_token_in_expert, - ) - ] - # get align_src_row_id2dst_row_id - for i in range(n_token * topk): - eid = sorted_topk_ids[i] - if eid >= n_local_expert: - # check token not in local expert - align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1] - continue - first_token_offset = expert_first_token_offset[eid] - align_first_token_offset = align_expert_first_token_offset[eid] - token_offset = i - first_token_offset - align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset - align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape( - (n_token, topk) - ) - return [ - permuted_hidden_states, - align_expert_first_token_offset, - align_src_row_id2dst_row_id, - permuted_idx, - m_indices, - valid_row_idx, - ] + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...] + src_row_id2dst_row_id_map = torch.arange( + 0, n_token * topk, device="cuda", dtype=torch.int32 + )[src2dst_idx].reshape((n_token, topk)) + valid_row_idx += [i for i in range(expert_first_token_offset[-1])] + dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk + return [ + permuted_hidden_states, + expert_first_token_offset, + src_row_id2dst_row_id_map, + dst_row_id2src_row_id_map, + valid_row_idx, + ] def torch_unpermute( @@ -207,7 +117,6 @@ def torch_unpermute( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("ep_size", EP_SIZE) -@pytest.mark.parametrize("align_block_size", [None, 128]) def test_moe_permute_unpermute( n_token: int, n_hidden: int, @@ -215,11 +124,9 @@ def test_moe_permute_unpermute( n_expert: int, ep_size: int, dtype: torch.dtype, - align_block_size: int | None, ): if not moe_permute_unpermute_supported(): pytest.skip("moe_permute_unpermute is not supported on this platform.") - fill_invalid_expert = 0 ep_rank = np.random.randint(0, ep_size) expert_map = None n_local_expert = n_expert @@ -238,7 +145,6 @@ def test_moe_permute_unpermute( gold_expert_first_token_offset, gold_inv_permuted_idx, gold_permuted_idx, - gold_m_indices, valid_row_idx, ) = torch_permute( hidden_states, @@ -249,8 +155,6 @@ def test_moe_permute_unpermute( n_local_expert, start_expert, expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert, ) ( @@ -258,7 +162,7 @@ def test_moe_permute_unpermute( _, expert_first_token_offset, inv_permuted_idx, - m_indices, + _, ) = moe_permute( hidden_states=hidden_states, a1q_scale=None, @@ -266,8 +170,6 @@ def test_moe_permute_unpermute( n_expert=n_expert, n_local_expert=n_local_expert, expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert, ) # check expert_first_token_offset @@ -278,11 +180,6 @@ def test_moe_permute_unpermute( torch.testing.assert_close( gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0 ) - # check mindice - # current kernel usage assumes deepgemm requires align_block_size - # when it's not provided then we don't compute m_indices (for cutlass) - if align_block_size is not None: - torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) # check permuted_hidden_states, only valid token torch.testing.assert_close( diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 0c8cbd04b..de2a39295 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -11,8 +11,6 @@ def moe_permute( n_expert: int, n_local_expert: int = -1, expert_map: torch.Tensor | None = None, - align_block_size: int | None = None, - fill_invalid_expert: int = -1, permuted_hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -27,9 +25,6 @@ def moe_permute( - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - align_block_size (Optional[int]): align group gemm block size for deepgemm - - fill_invalid_expert(int): fill expert id in m_indices for invalid expert - to workaround DeepGemm unsupported -1 in m_indices - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor. If None, the output tensor will be created in this function. Returns: @@ -37,12 +32,9 @@ def moe_permute( - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states if original scale not per-tensor scaling - expert_first_token_offset (torch.Tensor): offset of the first token - of each expert for standard grouped gemm. if enable 'align_block_size' - expert_first_token_offset will align up to 'align_block_size'. + of each expert for standard grouped gemm. - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute. - permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden. - - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records - the group which the j-th row of the LHS belong to.` """ n_token, n_hidden = hidden_states.size() topk = topk_ids.size(1) @@ -50,17 +42,6 @@ def moe_permute( "permue kernel need hidden dim align to 16B" ) permuted_row_size = n_token * topk - if align_block_size is not None: - permuted_row_size = ( - ( - permuted_row_size - + n_expert * (align_block_size - 1) - + align_block_size - - 1 - ) - // align_block_size - * align_block_size - ) if n_local_expert == -1: n_local_expert = n_expert if permuted_hidden_states is None: @@ -78,12 +59,6 @@ def moe_permute( 0, n_token * topk, dtype=torch.int32, device=hidden_states.device ).reshape((n_token, topk)) - m_indices = torch.full( - (permuted_row_size,), - fill_invalid_expert, - dtype=torch.int32, - device=hidden_states.device, - ) expert_first_token_offset = torch.empty( n_local_expert + 1, dtype=torch.int64, device=hidden_states.device ) @@ -105,12 +80,10 @@ def moe_permute( n_expert, n_local_expert, topk, - align_block_size, permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, permuted_idx, - m_indices, ) if a1q_scale is not None and a1q_scale.dim() > 1: @@ -120,7 +93,7 @@ def moe_permute( a1q_scale, expert_first_token_offset, inv_permuted_idx.flatten(), - m_indices, + permuted_idx, )