diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 2ef67b414..b58c42b7d 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -64,8 +64,10 @@ from vllm.utils.torch_utils import set_random_seed from vllm.v1.worker.workspace import init_workspace_manager NUM_EXPERTS = [8, 64, 192] +NUM_EXPERTS_LARGE = [128, 256] EP_SIZE = [1, 4] TOP_KS = [2, 6] +TOP_KS_SMALL = [1, 2] MOE_MARLIN_QUANT_TEST_CONFIGS = [ # AWQ-INT4 @@ -133,6 +135,13 @@ FUSED_MOE_MNK_FACTORS = [ (40000, 1024, 1024), ] +FUSED_MOE_MNK_FACTORS_SMALL_M = [ + (1, 128, 128), + (1, 2048, 128), + (2, 2048, 128), + (2, 2048, 511), +] + FUSED_MOE_WN16_MNK_FACTORS = [ (1, 128, 128), (1, 1024, 1024), @@ -330,6 +339,111 @@ def test_fused_moe( ) +@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS_SMALL_M) +@pytest.mark.parametrize("e", NUM_EXPERTS_LARGE) +@pytest.mark.parametrize("topk", TOP_KS_SMALL) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("padding", [True, False]) +@pytest.mark.parametrize("chunk_size", [8192]) +def test_naive_block_assignment_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + padding: bool, + chunk_size: int, + monkeypatch, + workspace_init, +): + current_platform.seed_everything(7) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + # + # Setup test data + # + + # + # Setup test data + # + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + e_map = None + + # + # Setup test functions + # + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + + m_fused_moe_fn = modular_triton_fused_moe(quant_config) + + def m_fused_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + return m_fused_moe_fn( + a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + fused_moe_fn = functools.partial(fused_moe, renormalize=False) + + # + # Run tests + # + runner = functools.partial( + run_moe_test, + a=a, + w1=w1, + w2=w2, + score=score, + topk=topk, + global_num_experts=e, + expert_map=e_map, + padding=padding, + ) + + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. + use_compile = False + + use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike() + + with set_current_vllm_config(vllm_config): + baseline_output = runner(torch_moe, iterative_moe) + runner( + baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) + runner( + baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) + + @pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 577d9353c..735c4aa8f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -351,6 +351,7 @@ def fused_moe_kernel( # Block size for block-wise quantization group_n: tl.constexpr, group_k: tl.constexpr, + naive_block_assignment: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -386,6 +387,9 @@ def fused_moe_kernel( - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. + - naive_block_assignment: A boolean flag indicating whether to use naive + token wise block assignment. If True, each block corresponds to a + single token. This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by @@ -411,11 +415,20 @@ def fused_moe_kernel( # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + if not naive_block_assignment: + offs_token_id = pid_m * BLOCK_SIZE_M + offs + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + else: + offs_token = tl.where( + offs == 0, + pid_m, # first element = pid_m + num_valid_tokens, # remaining elements = constant + ) + token_mask = offs_token < num_valid_tokens off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) @@ -557,7 +570,7 @@ def invoke_fused_moe_wna16_cuda_kernel( B_scale: torch.Tensor | None, B_zp: torch.Tensor | None, topk_weights: torch.Tensor | None, - sorted_token_ids: torch.Tensor, + sorted_token_ids: torch.Tensor | None, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, @@ -705,7 +718,7 @@ def invoke_fused_moe_triton_kernel( A_scale: torch.Tensor | None, B_scale: torch.Tensor | None, topk_weights: torch.Tensor | None, - sorted_token_ids: torch.Tensor, + sorted_token_ids: torch.Tensor | None, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, @@ -722,7 +735,7 @@ def invoke_fused_moe_triton_kernel( ): assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 - assert sorted_token_ids.stride(0) == 1 + assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1 if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None @@ -741,14 +754,18 @@ def invoke_fused_moe_triton_kernel( M = A.size(0) num_tokens = M * top_k - - EM = sorted_token_ids.size(0) - if A.size(0) < config["BLOCK_SIZE_M"]: - # optimize for small batch_size. - # We assume that top_ids of each token is unique, - # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, - # and we can skip some invalid blocks. - EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]) + if sorted_token_ids is not None: + EM = sorted_token_ids.size(0) + if A.size(0) < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min( + sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"] + ) + else: + EM = num_tokens * config["BLOCK_SIZE_M"] grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]), @@ -798,6 +815,7 @@ def invoke_fused_moe_triton_kernel( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, per_channel_quant=per_channel_quant, + naive_block_assignment=(sorted_token_ids is None), HAS_BIAS=HAS_BIAS, BLOCK_SIZE_K=BLOCK_SIZE_K, **config, @@ -812,7 +830,7 @@ def dispatch_fused_moe_kernel( B_scale: torch.Tensor | None, B_zp: torch.Tensor | None, topk_weights: torch.Tensor | None, - sorted_token_ids: torch.Tensor, + sorted_token_ids: torch.Tensor | None, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, @@ -829,7 +847,7 @@ def dispatch_fused_moe_kernel( ) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 - assert sorted_token_ids.stride(0) == 1 + assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1 M = A.size(0) num_tokens = M * top_k @@ -2165,14 +2183,37 @@ def fused_experts_impl( block_shape=block_shape, ) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, - config["BLOCK_SIZE_M"], - global_num_experts, - expert_map, - ignore_invalid_experts=True, + # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k + # activates only a small fraction of total experts + SPARSITY_FACTOR = 4 + # block quantized code path is not implemented yet. + naive_block_assignment = ( + expert_map is None + and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts + and not ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ) ) + if not naive_block_assignment: + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, + config["BLOCK_SIZE_M"], + global_num_experts, + expert_map, + ignore_invalid_experts=True, + ) + else: + max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"] + expert_ids = curr_topk_ids.view(-1) + num_tokens_post_padded = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_padded.fill_(max_num_tokens_padded) + sorted_token_ids = None + dispatch_fused_moe_kernel( qcurr_hidden_states, w1,