[Kernel] Optimize grouped topk kernel (#34206)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-02-20 01:34:45 -08:00
committed by GitHub
parent 8de7c636cc
commit b1c4f0b265
3 changed files with 642 additions and 99 deletions

View File

@@ -8,6 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
import pytest
import torch
import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm.config import (
CompilationConfig,
VllmConfig,
@@ -27,11 +28,17 @@ from vllm.utils.torch_utils import set_random_seed
)
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize(
"n_expert,topk,num_expert_group,topk_group",
[
(16, 2, 8, 2),
(128, 2, 8, 2),
(256, 8, 8, 4),
(384, 8, 1, 1),
(512, 22, 1, 1),
],
)
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
@@ -42,9 +49,9 @@ def test_grouped_topk(
n_hidden: int,
n_expert: int,
topk: int,
renormalize: bool,
num_expert_group: int,
topk_group: int,
renormalize: bool,
scoring_func: str,
routed_scaling_factor: float,
input_dtype: torch.dtype,
@@ -62,6 +69,7 @@ def test_grouped_topk(
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
grouped_topk = GroupedTopk(
topk=topk,
renormalize=renormalize,
@@ -89,8 +97,7 @@ def test_grouped_topk(
e_score_correction_bias=e_score_correction_bias,
)
if renormalize:
torch.testing.assert_close(
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
)
torch.testing.assert_close(
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
)
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)