[Kernel] Add topk_sigmoid kernel (#31246)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -18,7 +18,9 @@ from vllm.model_executor.layers.activation import (
|
||||
SiluAndMul,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
|
||||
dispatch_topk_func,
|
||||
dispatch_topk_sigmoid_func,
|
||||
dispatch_topk_softmax_func,
|
||||
vllm_topk_sigmoid,
|
||||
vllm_topk_softmax,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
@@ -133,8 +135,8 @@ def test_enabled_ops_invalid(env: str):
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
def test_topk_dispatch(use_rocm_aiter: bool):
|
||||
topk_func = dispatch_topk_func(use_rocm_aiter)
|
||||
def test_topk_softmax_dispatch(use_rocm_aiter: bool):
|
||||
topk_func = dispatch_topk_softmax_func(use_rocm_aiter)
|
||||
|
||||
if current_platform.is_rocm() and use_rocm_aiter:
|
||||
assert topk_func == rocm_aiter_ops.topk_softmax
|
||||
@@ -142,6 +144,18 @@ def test_topk_dispatch(use_rocm_aiter: bool):
|
||||
assert topk_func == vllm_topk_softmax
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
|
||||
topk_func = dispatch_topk_sigmoid_func(use_rocm_aiter)
|
||||
|
||||
if current_platform.is_rocm() and use_rocm_aiter:
|
||||
assert topk_func == rocm_aiter_ops.topk_sigmoid
|
||||
else:
|
||||
assert topk_func == vllm_topk_sigmoid
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
|
||||
|
||||
Reference in New Issue
Block a user