[Kernel] Support bias type in grouped_topk kernel (#31781)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Xin Yang
2026-01-07 12:16:32 -08:00
committed by GitHub
parent c907d22158
commit 0ada960a20
3 changed files with 104 additions and 72 deletions

View File

@@ -34,7 +34,8 @@ from vllm.utils.torch_utils import set_random_seed
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
-@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("bias_dtype", [torch.float32])
def test_grouped_topk(
monkeypatch: pytest.MonkeyPatch,
n_token: int,
@@ -46,7 +47,8 @@ def test_grouped_topk(
topk_group: int,
scoring_func: str,
routed_scaling_factor: float,
-dtype: torch.dtype,
+input_dtype: torch.dtype,
+bias_dtype: torch.dtype,
):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
@@ -54,11 +56,9 @@ def test_grouped_topk(
get_cached_compilation_config.cache_clear()
set_random_seed(0)
-hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
-gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
-e_score_correction_bias = torch.randn(
-(n_expert,), dtype=torch.float32, device="cuda"
-)
+hidden_states = torch.randn((n_token, n_hidden), dtype=input_dtype, device="cuda")
+gating_output = torch.randn((n_token, n_expert), dtype=input_dtype, device="cuda")
+e_score_correction_bias = torch.randn((n_expert,), dtype=bias_dtype, device="cuda")
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")