[Kernel] Support bias type in grouped_topk kernel (#31781)
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -34,7 +34,8 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
@pytest.mark.parametrize("topk_group", [2])
|
||||
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
|
||||
-@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
|
||||
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
|
||||
+@pytest.mark.parametrize("bias_dtype", [torch.float32])
|
||||
def test_grouped_topk(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
n_token: int,
|
||||
@@ -46,7 +47,8 @@ def test_grouped_topk(
|
||||
topk_group: int,
|
||||
scoring_func: str,
|
||||
routed_scaling_factor: float,
|
||||
-dtype: torch.dtype,
|
||||
+input_dtype: torch.dtype,
|
||||
+bias_dtype: torch.dtype,
|
||||
):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
|
||||
@@ -54,11 +56,9 @@ def test_grouped_topk(
|
||||
get_cached_compilation_config.cache_clear()
|
||||
|
||||
set_random_seed(0)
|
||||
-hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
|
||||
-gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
|
||||
-e_score_correction_bias = torch.randn(
|
||||
-(n_expert,), dtype=torch.float32, device="cuda"
|
||||
-)
|
||||
+hidden_states = torch.randn((n_token, n_hidden), dtype=input_dtype, device="cuda")
|
||||
+gating_output = torch.randn((n_token, n_expert), dtype=input_dtype, device="cuda")
|
||||
+e_score_correction_bias = torch.randn((n_expert,), dtype=bias_dtype, device="cuda")
|
||||
|
||||
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
|
||||
|
||||
Reference in New Issue
Block a user