Fix RoutingMethodType logic (#33919)
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
committed by
GitHub
parent
ae2e93f89b
commit
207c3a0c20
@@ -124,6 +124,23 @@ class RoutingMethodType(IntEnum):
|
||||
Unspecified = 8.0
|
||||
|
||||
|
||||
def get_routing_method_type(
|
||||
scoring_func: str, top_k: int, renormalize: bool
|
||||
) -> RoutingMethodType:
|
||||
if scoring_func == "sigmoid":
|
||||
if top_k == 1:
|
||||
return RoutingMethodType.Llama4
|
||||
else:
|
||||
return RoutingMethodType.DeepSeekV3
|
||||
elif scoring_func == "softmax":
|
||||
if renormalize:
|
||||
return RoutingMethodType.Renormalize
|
||||
else:
|
||||
return RoutingMethodType.Default
|
||||
else:
|
||||
return RoutingMethodType.Unspecified
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantDesc:
|
||||
"""
|
||||
|
||||
@@ -61,6 +61,8 @@ def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
@@ -72,10 +74,8 @@ def _supports_routing_method(
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.Llama4,
|
||||
# NOTE(mgoin): Disabled to investigate accuracy issues.
|
||||
# See https://github.com/vllm-project/vllm/issues/33532
|
||||
# RoutingMethodType.Renormalize,
|
||||
# RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
@@ -302,6 +302,8 @@ def fi_trtllm_fp8_per_tensor_moe(
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
|
||||
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
@@ -323,6 +325,9 @@ def fi_trtllm_fp8_per_tensor_moe(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||
routing_method_type=routing_method_type,
|
||||
# TODO: Required for flashinfer==0.6.3, remove with update
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/2508
|
||||
activation_type=ActivationType.Swiglu,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
RoutingMethodType,
|
||||
get_routing_method_type,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@@ -158,10 +161,10 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
|
||||
@property
|
||||
def routing_method_type(self) -> RoutingMethodType:
|
||||
return (
|
||||
RoutingMethodType.Renormalize
|
||||
if not self.renormalize
|
||||
else RoutingMethodType.RenormalizeNaive
|
||||
return get_routing_method_type(
|
||||
scoring_func=self.scoring_func,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
)
|
||||
|
||||
def _compute_routing(
|
||||
|
||||
@@ -7,7 +7,10 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
RoutingMethodType,
|
||||
get_routing_method_type,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@@ -135,10 +138,10 @@ class FusedTopKRouter(BaseRouter):
|
||||
|
||||
@property
|
||||
def routing_method_type(self) -> RoutingMethodType:
|
||||
return (
|
||||
RoutingMethodType.Renormalize
|
||||
if not self.renormalize
|
||||
else RoutingMethodType.RenormalizeNaive
|
||||
return get_routing_method_type(
|
||||
scoring_func=self.scoring_func,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
)
|
||||
|
||||
def _compute_routing(
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
|
||||
CustomRoutingRouter,
|
||||
)
|
||||
@@ -106,7 +107,7 @@ def create_fused_moe_router(
|
||||
"num_expert_group and topk_group must be provided when "
|
||||
"use_grouped_topk is True"
|
||||
)
|
||||
return GroupedTopKRouter(
|
||||
grouped_topk_router = GroupedTopKRouter(
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
eplb_state=eplb_state,
|
||||
@@ -120,6 +121,18 @@ def create_fused_moe_router(
|
||||
enable_eplb=enable_eplb,
|
||||
indices_type_getter=indices_type_getter,
|
||||
)
|
||||
if (
|
||||
grouped_topk_router.routing_method_type != RoutingMethodType.Unspecified
|
||||
or num_expert_group > 1
|
||||
or topk_group > 1
|
||||
):
|
||||
return grouped_topk_router
|
||||
|
||||
# If routing_method for GroupedTopKRouter is Unspecified and there is only
|
||||
# one group, fallback to standard top-k routing
|
||||
use_grouped_topk = False
|
||||
num_expert_group = None
|
||||
topk_group = None
|
||||
|
||||
if custom_routing_function is not None:
|
||||
return CustomRoutingRouter(
|
||||
|
||||
Reference in New Issue
Block a user