FusedMoE support for the Transformers backend (#22650)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-10-03 07:12:15 +01:00
committed by GitHub
parent 39b643dc1a
commit 10d765482d
10 changed files with 485 additions and 91 deletions


@@ -20,7 +20,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                      MultiModalConfig)
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
-from vllm.config.utils import assert_hashable, config
+from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -667,6 +667,8 @@ class ModelConfig:
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
+        prefix = "Transformers"
+        prefix += "MoE" if self.get_num_experts() > 1 else ""
         # Check if the architecture we're wrapping has defaults
         runner = None
         convert = None
@@ -685,15 +687,15 @@ class ModelConfig:
         # Resolve Transformers backend pooling classes
         if runner == "pooling":
             if convert == "embed":
-                return "TransformersEmbeddingModel"
+                return prefix + "EmbeddingModel"
             if convert == "classify":
-                return "TransformersForSequenceClassification"
+                return prefix + "ForSequenceClassification"
         # Resolve Transformers backend generate classes
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is not the same as 'hf_config', it is
             # probably a composite config, i.e. multimodal
-            return "TransformersForMultimodalLM"
-        return "TransformersForCausalLM"
+            return prefix + "ForMultimodalLM"
+        return prefix + "ForCausalLM"

     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""
@@ -1025,17 +1027,7 @@ class ModelConfig:
             self.enforce_eager = True

     def _verify_with_expert_parallelism(self) -> None:
-        num_expert_names = [
-            "moe_num_experts",  # Dbrx
-            "num_experts",  # Jamba
-            "n_routed_experts",  # DeepSeek
-            "num_local_experts",  # Mixtral
-        ]
-        num_experts = 0
-        for name in num_expert_names:
-            num_experts = getattr(self.hf_text_config, name, 0)
-            if num_experts > 0:
-                break
+        num_experts = self.get_num_experts()
         if num_experts < 1:
             raise ValueError(
                 "Number of experts in the model must be greater than 0 "
@@ -1220,6 +1212,21 @@ class ModelConfig:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size
def get_num_experts(self) -> int:
"""Returns the number of experts in the model."""
num_expert_names = [
"num_experts", # Jamba
"moe_num_experts", # Dbrx
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
if isinstance(num_experts, list):
# Ernie VL's remote code uses list[int]...
# The values are always the same so we just take the first one.
return num_experts[0]
return num_experts
def get_layers_start_end_indices(
self, parallel_config: ParallelConfig) -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
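The new `get_num_experts` method centralizes the per-architecture attribute lookup that `_verify_with_expert_parallelism` previously did inline, and `_get_transformers_backend_cls` reuses it to decide on the `MoE` prefix. The `getattr_iter` helper imported above returns the first of several candidate attribute names found on the config, falling back to a default. A rough, self-contained sketch of that lookup, with a toy object standing in for `hf_text_config` (the sketch function name and toy config are illustrative, not the library API):

from collections.abc import Iterable
from types import SimpleNamespace
from typing import Any

def getattr_iter_sketch(obj: object, names: Iterable[str], default: Any) -> Any:
    # Return the value of the first attribute in `names` present on `obj`,
    # otherwise `default` (approximates vllm.config.utils.getattr_iter).
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default

# A DeepSeek-style text config only defines `n_routed_experts`.
toy_config = SimpleNamespace(n_routed_experts=64)
names = ["num_experts", "moe_num_experts", "n_routed_experts", "num_local_experts"]
num_experts = getattr_iter_sketch(toy_config, names, 0)
if isinstance(num_experts, list):
    # Ernie VL-style configs store list[int]; the entries match, take the first.
    num_experts = num_experts[0]
print(num_experts)  # 64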