CustomOp: grouped topk (#29575)

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
This commit is contained in:
Xinyu Chen
2025-12-17 17:43:00 +08:00
committed by GitHub
parent a9e15c21ef
commit 3b1d440ede
4 changed files with 75 additions and 14 deletions

View File

@@ -9,8 +9,8 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
fused_grouped_topk,
grouped_topk,
)
from vllm.platforms import current_platform
@@ -50,15 +50,17 @@ def test_grouped_topk(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
grouped_topk = GroupedTopk(
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
)
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
e_score_correction_bias=e_score_correction_bias,
)