CustomOp: test forward dispatch for grouped_topk (#31530)
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
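This adds explicit CustomOp dispatch coverage to the grouped_topk kernel test: the test now builds a VllmConfig whose CompilationConfig enables the op (custom_ops=["all", "+grouped_topk"]), clears the cached compilation config, wraps the existing monkeypatch block in set_current_vllm_config, and asserts that the constructed GroupedTopk instance resolved its _forward_method to forward_cuda.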
@@ -8,6 +8,12 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
 import pytest
 import torch
 
+from vllm.config import (
+    CompilationConfig,
+    VllmConfig,
+    get_cached_compilation_config,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     GroupedTopk,
     fused_grouped_topk,
@@ -41,6 +47,11 @@ def test_grouped_topk(
     routed_scaling_factor: float,
     dtype: torch.dtype,
 ):
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
+    )
+    get_cached_compilation_config.cache_clear()
+
     current_platform.seed_everything(0)
     hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
     gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
@@ -48,7 +59,7 @@ def test_grouped_topk(
         (n_expert,), dtype=torch.float32, device="cuda"
     )
 
-    with monkeypatch.context() as m:
+    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
         grouped_topk = GroupedTopk(
             topk=topk,
@@ -58,6 +69,7 @@ def test_grouped_topk(
             scoring_func=scoring_func,
             routed_scaling_factor=routed_scaling_factor,
         )
+        assert grouped_topk._forward_method.__name__ == "forward_cuda"
         baseline_topk_weights, baseline_topk_ids = grouped_topk(
             hidden_states=hidden_states,
             gating_output=gating_output,
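For context, a minimal sketch of the complementary dispatch case (not part of this commit): with the custom op disabled via a "-grouped_topk" entry in CompilationConfig.custom_ops, the CustomOp machinery is expected to bind the native PyTorch implementation rather than the CUDA path. The GroupedTopk constructor arguments and literal values below are illustrative assumptions that mirror the pattern in the diff above, not an exact copy of the test.

from vllm.config import (
    CompilationConfig,
    VllmConfig,
    get_cached_compilation_config,
    set_current_vllm_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk

# Sketch: disable the grouped_topk custom op and expect native dispatch.
vllm_config = VllmConfig(
    compilation_config=CompilationConfig(custom_ops=["all", "-grouped_topk"])
)
get_cached_compilation_config.cache_clear()

with set_current_vllm_config(vllm_config):
    # Constructor arguments here are assumed for illustration.
    grouped_topk = GroupedTopk(
        topk=4,
        num_expert_group=2,
        topk_group=2,
        renormalize=True,
        scoring_func="softmax",
        routed_scaling_factor=1.0,
    )
    # With the op disabled, dispatch should fall back to the native path.
    assert grouped_topk._forward_method.__name__ == "forward_native"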