[Kernel] Optimize grouped topk kernel (#34206)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -8,6 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.batch_invariant as batch_invariant
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
VllmConfig,
|
||||
@@ -27,11 +28,17 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
)
|
||||
@pytest.mark.parametrize("n_token", [1, 33, 64])
|
||||
@pytest.mark.parametrize("n_hidden", [1024, 2048])
|
||||
@pytest.mark.parametrize("n_expert", [16])
|
||||
@pytest.mark.parametrize("topk", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"n_expert,topk,num_expert_group,topk_group",
|
||||
[
|
||||
(16, 2, 8, 2),
|
||||
(128, 2, 8, 2),
|
||||
(256, 8, 8, 4),
|
||||
(384, 8, 1, 1),
|
||||
(512, 22, 1, 1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("renormalize", [True, False])
|
||||
@pytest.mark.parametrize("num_expert_group", [8])
|
||||
@pytest.mark.parametrize("topk_group", [2])
|
||||
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
|
||||
@@ -42,9 +49,9 @@ def test_grouped_topk(
|
||||
n_hidden: int,
|
||||
n_expert: int,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
renormalize: bool,
|
||||
scoring_func: str,
|
||||
routed_scaling_factor: float,
|
||||
input_dtype: torch.dtype,
|
||||
@@ -62,6 +69,7 @@ def test_grouped_topk(
|
||||
|
||||
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
|
||||
m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
|
||||
grouped_topk = GroupedTopk(
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
@@ -89,8 +97,7 @@ def test_grouped_topk(
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
torch.testing.assert_close(
|
||||
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)
|
||||
|
||||
Reference in New Issue
Block a user