Fix RoutingMethodType logic (#33919)
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
committed by
GitHub
parent
ae2e93f89b
commit
207c3a0c20
@@ -582,7 +582,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# This is ~1.1GB and only changes when FlashInfer version bumps
|
||||
# https://docs.flashinfer.ai/installation.html
|
||||
# From versions.json: .flashinfer.version
|
||||
ARG FLASHINFER_VERSION=0.6.2
|
||||
ARG FLASHINFER_VERSION=0.6.3
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \
|
||||
&& uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
|
||||
|
||||
@@ -217,13 +217,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
|
||||
|
||||
|
||||
# build flashinfer for torch nightly from source around 10 mins
|
||||
# release version: v0.6.2
|
||||
# release version: v0.6.3
|
||||
# todo(elainewy): cache flashinfer build result for faster build
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
echo "git clone flashinfer..." \
|
||||
&& git clone --depth 1 --branch v0.6.2 --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
||||
&& git clone --depth 1 --branch v0.6.3 --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
||||
&& cd flashinfer \
|
||||
&& git submodule update --init --recursive \
|
||||
&& echo "finish git clone flashinfer..." \
|
||||
|
||||
@@ -68,7 +68,7 @@
|
||||
"default": "true"
|
||||
},
|
||||
"FLASHINFER_VERSION": {
|
||||
"default": "0.6.2"
|
||||
"default": "0.6.3"
|
||||
},
|
||||
"GDRCOPY_CUDA_VERSION": {
|
||||
"default": "12.8"
|
||||
|
||||
@@ -10,4 +10,4 @@ torchaudio==2.9.1
|
||||
# These must be updated alongside torch
|
||||
torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.6.2
|
||||
flashinfer-python==0.6.3
|
||||
|
||||
@@ -124,6 +124,23 @@ class RoutingMethodType(IntEnum):
|
||||
Unspecified = 8.0
|
||||
|
||||
|
||||
def get_routing_method_type(
|
||||
scoring_func: str, top_k: int, renormalize: bool
|
||||
) -> RoutingMethodType:
|
||||
if scoring_func == "sigmoid":
|
||||
if top_k == 1:
|
||||
return RoutingMethodType.Llama4
|
||||
else:
|
||||
return RoutingMethodType.DeepSeekV3
|
||||
elif scoring_func == "softmax":
|
||||
if renormalize:
|
||||
return RoutingMethodType.Renormalize
|
||||
else:
|
||||
return RoutingMethodType.Default
|
||||
else:
|
||||
return RoutingMethodType.Unspecified
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantDesc:
|
||||
"""
|
||||
|
||||
@@ -61,6 +61,8 @@ def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
@@ -72,10 +74,8 @@ def _supports_routing_method(
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.Llama4,
|
||||
# NOTE(mgoin): Disabled to investigate accuracy issues.
|
||||
# See https://github.com/vllm-project/vllm/issues/33532
|
||||
# RoutingMethodType.Renormalize,
|
||||
# RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
@@ -302,6 +302,8 @@ def fi_trtllm_fp8_per_tensor_moe(
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
|
||||
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
@@ -323,6 +325,9 @@ def fi_trtllm_fp8_per_tensor_moe(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||
routing_method_type=routing_method_type,
|
||||
# TODO: Required for flashinfer==0.6.3, remove with update
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/2508
|
||||
activation_type=ActivationType.Swiglu,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
RoutingMethodType,
|
||||
get_routing_method_type,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@@ -158,10 +161,10 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
|
||||
@property
|
||||
def routing_method_type(self) -> RoutingMethodType:
|
||||
return (
|
||||
RoutingMethodType.Renormalize
|
||||
if not self.renormalize
|
||||
else RoutingMethodType.RenormalizeNaive
|
||||
return get_routing_method_type(
|
||||
scoring_func=self.scoring_func,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
)
|
||||
|
||||
def _compute_routing(
|
||||
|
||||
@@ -7,7 +7,10 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
RoutingMethodType,
|
||||
get_routing_method_type,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@@ -135,10 +138,10 @@ class FusedTopKRouter(BaseRouter):
|
||||
|
||||
@property
|
||||
def routing_method_type(self) -> RoutingMethodType:
|
||||
return (
|
||||
RoutingMethodType.Renormalize
|
||||
if not self.renormalize
|
||||
else RoutingMethodType.RenormalizeNaive
|
||||
return get_routing_method_type(
|
||||
scoring_func=self.scoring_func,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
)
|
||||
|
||||
def _compute_routing(
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
|
||||
CustomRoutingRouter,
|
||||
)
|
||||
@@ -106,7 +107,7 @@ def create_fused_moe_router(
|
||||
"num_expert_group and topk_group must be provided when "
|
||||
"use_grouped_topk is True"
|
||||
)
|
||||
return GroupedTopKRouter(
|
||||
grouped_topk_router = GroupedTopKRouter(
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
eplb_state=eplb_state,
|
||||
@@ -120,6 +121,18 @@ def create_fused_moe_router(
|
||||
enable_eplb=enable_eplb,
|
||||
indices_type_getter=indices_type_getter,
|
||||
)
|
||||
if (
|
||||
grouped_topk_router.routing_method_type != RoutingMethodType.Unspecified
|
||||
or num_expert_group > 1
|
||||
or topk_group > 1
|
||||
):
|
||||
return grouped_topk_router
|
||||
|
||||
# If routing_method for GroupedTopKRouter is Unspecified and there is only
|
||||
# one group, fallback to standard top-k routing
|
||||
use_grouped_topk = False
|
||||
num_expert_group = None
|
||||
topk_group = None
|
||||
|
||||
if custom_routing_function is not None:
|
||||
return CustomRoutingRouter(
|
||||
|
||||
Reference in New Issue
Block a user