# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE permute/unpermute kernel

Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
"""

import numpy as np
import pytest
import torch

from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute,
    moe_permute_unpermute_supported,
    moe_unpermute,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

NUM_EXPERTS = [16, 64, 256]
TOP_KS = [2, 6, 8]
EP_SIZE = [1, 4, 16]

set_random_seed(0)

if current_platform.is_rocm():
    pytest.skip(
        "moe_permute_unpermute_supported is not defined for ROCm",
        allow_module_level=True,
    )


def torch_permute(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    # token_expert_indices: torch.Tensor,
    topk: int,
    n_expert: int,
    n_local_expert: int,
    start_expert: int,
    expert_map: torch.Tensor | None = None,
) -> list[torch.Tensor]:
    """Reference (eager PyTorch) implementation of the MOE permute step.

    Sorts the `n_token * topk` token/expert assignments by expert id so that
    rows destined for the same expert are contiguous, and computes the
    bookkeeping tensors the fused kernel is expected to produce.

    Args:
        hidden_states: (n_token, n_hidden) input activations.
        topk_ids: (n_token, topk) selected expert ids per token.
        topk: number of experts selected per token.
        n_expert: total number of experts across all EP ranks.
        n_local_expert: number of experts owned by this EP rank.
        start_expert: global id of this rank's first local expert.
        expert_map: optional (n_expert,) map; -1 marks non-local experts.

    Returns:
        [permuted_hidden_states, expert_first_token_offset,
         src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, valid_row_idx]
        where valid_row_idx is a list[int] of permuted rows that belong to
        local experts.
    """
    n_token = hidden_states.shape[0]
    if expert_map is not None:
        # Remap local experts into [0, n_local_expert); shift non-local
        # experts past n_expert so they sort after every local expert.
        is_local_expert = expert_map[topk_ids] != -1
        not_local_expert = expert_map[topk_ids] == -1
        topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * (
            topk_ids + n_expert
        )

    token_expert_indices = torch.arange(
        0, n_token * topk, dtype=torch.int32, device=hidden_states.device
    ).reshape((n_token, topk))

    # Stable sort preserves the original row order within each expert bucket.
    sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True)
    dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]

    # expert_first_token_offset[i] = first permuted row of local expert i.
    # Because the ids are sorted, the original Python while-loop scan is
    # equivalent to a histogram + prefix sum; entries remapped to
    # >= n_expert (non-local experts) fall outside the slice and are ignored.
    counts = torch.bincount(sorted_topk_ids.to(torch.int64), minlength=n_local_expert)[
        :n_local_expert
    ]
    expert_first_token_offset = torch.zeros(
        n_local_expert + 1, dtype=torch.int64, device="cuda"
    )
    expert_first_token_offset[1:] = torch.cumsum(counts, dim=0)

    _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
    # Gather must happen BEFORE the sentinel masking below.
    permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
    src_row_id2dst_row_id_map = torch.arange(
        0, n_token * topk, device="cuda", dtype=torch.int32
    )[src2dst_idx].reshape((n_token, topk))

    n_valid = int(expert_first_token_offset[-1])
    valid_row_idx = list(range(n_valid))
    # Rows past the last local expert belong to other ranks; mark them with
    # an out-of-range sentinel.
    dst_row_id2src_row_id_map[n_valid:] = n_token * topk
    return [
        permuted_hidden_states,
        expert_first_token_offset,
        src_row_id2dst_row_id_map,
        dst_row_id2src_row_id_map,
        valid_row_idx,
    ]


def torch_unpermute(
    permuted_hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    src_row_id2dst_row_id_map: torch.Tensor,
    valid_row_idx: list[int],
    topk: int,
    n_expert: int,
) -> torch.Tensor:
    """Reference implementation of the MOE unpermute step.

    Scatters the expert-grouped rows back to token order and reduces the
    topk expert outputs per token with their routing weights.

    Args:
        permuted_hidden_states: (n_token * topk, n_hidden) expert-grouped rows.
        topk_weights: (n_token, topk) routing weights.
        topk_ids: unused here; kept for signature parity with the kernel path.
        token_expert_indices: unused here; kept for signature parity.
        src_row_id2dst_row_id_map: (n_token, topk) source→destination row map.
        valid_row_idx: rows that belong to local experts; all others are
            zeroed before the reduction.
        topk: number of experts selected per token.
        n_expert: unused; kept for signature parity with the kernel path.

    Returns:
        (n_token, n_hidden) combined output in original token order.
    """
    n_hidden = permuted_hidden_states.shape[1]
    mask = torch.zeros(permuted_hidden_states.shape[0], dtype=torch.bool, device="cuda")
    mask[valid_row_idx] = True
    # Zero invalid rows on a clone so the caller's tensor is not mutated
    # (the previous implementation clobbered it in place).
    permuted_hidden_states = permuted_hidden_states.clone()
    permuted_hidden_states[~mask] = 0

    gathered = permuted_hidden_states[src_row_id2dst_row_id_map.flatten(), ...]
    gathered = gathered.view(-1, topk, n_hidden)
    # Weighted sum over the topk expert outputs of each token.
    output = (gathered * topk_weights.unsqueeze(2)).sum(1).to(gathered.dtype)
    return output


@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
def test_moe_permute_unpermute(
    n_token: int,
    n_hidden: int,
    topk: int,
    n_expert: int,
    ep_size: int,
    dtype: torch.dtype,
):
    """Compare the fused permute/unpermute kernels against the torch reference."""
    if not moe_permute_unpermute_supported():
        pytest.skip("moe_permute_unpermute is not supported on this platform.")
    # Pick a random EP rank to exercise different expert-map slices.
    ep_rank = np.random.randint(0, ep_size)
    expert_map = None
    n_local_expert = n_expert
    if ep_size != 1:
        n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert)
        expert_map = expert_map.cuda()
    start_expert = n_local_expert * ep_rank
    set_random_seed(0)
    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        hidden_states, gating_output, topk, False
    )
    (
        gold_permuted_hidden_states,
        gold_expert_first_token_offset,
        gold_inv_permuted_idx,
        gold_permuted_idx,
        valid_row_idx,
    ) = torch_permute(
        hidden_states,
        topk_ids,
        # token_expert_indices,
        topk,
        n_expert,
        n_local_expert,
        start_expert,
        expert_map=expert_map,
    )

    (
        permuted_hidden_states,
        _,
        expert_first_token_offset,
        inv_permuted_idx,
        _,
    ) = moe_permute(
        hidden_states=hidden_states,
        a1q_scale=None,
        topk_ids=topk_ids,
        n_expert=n_expert,
        n_local_expert=n_local_expert,
        expert_map=expert_map,
    )

    # check expert_first_token_offset
    torch.testing.assert_close(
        gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0
    )
    # check src_row_id2dst_row_id_map
    torch.testing.assert_close(
        gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
    )
    # check permuted_hidden_states, only valid token
    torch.testing.assert_close(
        gold_permuted_hidden_states[valid_row_idx],
        permuted_hidden_states[valid_row_idx],
        atol=0,
        rtol=0,
    )

    # add a random tensor to simulate group gemm
    result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states)

    result4 = torch.empty_like(hidden_states)
    moe_unpermute(
        result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset
    )
    # NOTE: torch_unpermute ignores its n_expert argument, so passing
    # n_local_expert here is harmless.
    gold4 = torch_unpermute(
        result0,
        topk_weights,
        topk_ids,
        token_expert_indices,
        inv_permuted_idx,
        valid_row_idx,
        topk,
        n_local_expert,
    )
    # check unpermuted hidden
    torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)