FusedMoE support for the Transformers backend (#22650)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
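In practice this means MoE checkpoints served through the Transformers fallback now resolve to FusedMoE-backed TransformersMoE* model classes. A hedged usage sketch (the checkpoint name is only an illustrative placeholder; model_impl="transformers" forces the backend, while "auto" falls back to it when vLLM has no native implementation):

    # Illustrative only: any MoE checkpoint that Transformers can load should behave the same.
    from vllm import LLM

    llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", model_impl="transformers")
    outputs = llm.generate("The capital of France is")
    print(outputs[0].outputs[0].text)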
@@ -20,7 +20,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                      MultiModalConfig)
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
-from vllm.config.utils import assert_hashable, config
+from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -667,6 +667,8 @@ class ModelConfig:
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
+        prefix = "Transformers"
+        prefix += "MoE" if self.get_num_experts() > 1 else ""
         # Check if the architecture we're wrapping has defaults
         runner = None
         convert = None
@@ -685,15 +687,15 @@ class ModelConfig:
         # Resolve Transformers backend pooling classes
         if runner == "pooling":
             if convert == "embed":
-                return "TransformersEmbeddingModel"
+                return prefix + "EmbeddingModel"
             if convert == "classify":
-                return "TransformersForSequenceClassification"
+                return prefix + "ForSequenceClassification"
         # Resolve Transformers backend generate classes
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
-            return "TransformersForMultimodalLM"
-        return "TransformersForCausalLM"
+            return prefix + "ForMultimodalLM"
+        return prefix + "ForCausalLM"
 
     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""
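Taken together, the two hunks above build the resolved class name from a prefix: "Transformers" for dense models, "TransformersMoE" when get_num_experts() reports more than one expert, so an MoE text model resolves to TransformersMoEForCausalLM, an MoE multimodal one to TransformersMoEForMultimodalLM, and so on. A minimal standalone sketch of the same resolution logic (the runner/convert parameters are assumptions standing in for the architecture-defaults lookup the diff elides):

    # Standalone sketch of the name resolution; not the vLLM implementation itself.
    def resolve_backend_cls(num_experts: int, is_multimodal: bool,
                            runner: str | None = None,
                            convert: str | None = None) -> str:
        prefix = "Transformers"
        prefix += "MoE" if num_experts > 1 else ""
        if runner == "pooling":
            if convert == "embed":
                return prefix + "EmbeddingModel"
            if convert == "classify":
                return prefix + "ForSequenceClassification"
        if is_multimodal:
            return prefix + "ForMultimodalLM"
        return prefix + "ForCausalLM"

    assert resolve_backend_cls(8, False) == "TransformersMoEForCausalLM"
    assert resolve_backend_cls(1, True) == "TransformersForMultimodalLM"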
@@ -1025,17 +1027,7 @@ class ModelConfig:
             self.enforce_eager = True
 
     def _verify_with_expert_parallelism(self) -> None:
-        num_expert_names = [
-            "moe_num_experts",  # Dbrx
-            "num_experts",  # Jamba
-            "n_routed_experts",  # DeepSeek
-            "num_local_experts",  # Mixtral
-        ]
-        num_experts = 0
-        for name in num_expert_names:
-            num_experts = getattr(self.hf_text_config, name, 0)
-            if num_experts > 0:
-                break
+        num_experts = self.get_num_experts()
         if num_experts < 1:
             raise ValueError(
                 "Number of experts in the model must be greater than 0 "
@@ -1220,6 +1212,21 @@ class ModelConfig:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
+    def get_num_experts(self) -> int:
+        """Returns the number of experts in the model."""
+        num_expert_names = [
+            "num_experts",  # Jamba
+            "moe_num_experts",  # Dbrx
+            "n_routed_experts",  # DeepSeek
+            "num_local_experts",  # Mixtral
+        ]
+        num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
+        if isinstance(num_experts, list):
+            # Ernie VL's remote code uses list[int]...
+            # The values are always the same so we just take the first one.
+            return num_experts[0]
+        return num_experts
+
     def get_layers_start_end_indices(
             self, parallel_config: ParallelConfig) -> tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
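getattr_iter is imported from vllm.config.utils but its body is not part of this diff; from the call site above it presumably walks the candidate attribute names in order and returns the first one present on the config, falling back to the default. A rough sketch of that assumed behaviour (not the actual helper):

    # Assumed behaviour of getattr_iter, inferred from its call site above.
    from typing import Any, Iterable


    def getattr_iter(obj: Any, names: Iterable[str], default: Any) -> Any:
        """Return the first attribute in `names` that exists on `obj`, else `default`."""
        for name in names:
            if hasattr(obj, name):
                return getattr(obj, name)
        return default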