[MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -603,7 +603,6 @@ def make_shared_experts(
|
|||||||
def modular_triton_fused_moe(
|
def modular_triton_fused_moe(
|
||||||
moe_config: FusedMoEConfig,
|
moe_config: FusedMoEConfig,
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
shared_experts: torch.nn.Module | None = None,
|
|
||||||
) -> FusedMoEKernel:
|
) -> FusedMoEKernel:
|
||||||
return FusedMoEKernel(
|
return FusedMoEKernel(
|
||||||
maybe_make_prepare_finalize(
|
maybe_make_prepare_finalize(
|
||||||
@@ -613,6 +612,5 @@ def modular_triton_fused_moe(
|
|||||||
use_monolithic=False,
|
use_monolithic=False,
|
||||||
),
|
),
|
||||||
TritonExperts(moe_config, quant_config),
|
TritonExperts(moe_config, quant_config),
|
||||||
shared_experts,
|
|
||||||
inplace=False,
|
inplace=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ pushd "$WORKSPACE"
|
|||||||
echo "Downloading NVSHMEM ${NVSHMEM_VER} for ${NVSHMEM_SUBDIR} ..."
|
echo "Downloading NVSHMEM ${NVSHMEM_VER} for ${NVSHMEM_SUBDIR} ..."
|
||||||
curl -fSL "${NVSHMEM_URL}" -o "${NVSHMEM_FILE}"
|
curl -fSL "${NVSHMEM_URL}" -o "${NVSHMEM_FILE}"
|
||||||
tar -xf "${NVSHMEM_FILE}"
|
tar -xf "${NVSHMEM_FILE}"
|
||||||
|
rm -rf nvshmem
|
||||||
mv "${NVSHMEM_FILE%.tar.xz}" nvshmem
|
mv "${NVSHMEM_FILE%.tar.xz}" nvshmem
|
||||||
rm -f "${NVSHMEM_FILE}"
|
rm -f "${NVSHMEM_FILE}"
|
||||||
rm -rf nvshmem/lib/bin nvshmem/lib/share
|
rm -rf nvshmem/lib/bin nvshmem/lib/share
|
||||||
|
|||||||
@@ -410,8 +410,7 @@ class ElasticEPScalingExecutor:
|
|||||||
# for the new EP size by resetting quant_method to base
|
# for the new EP size by resetting quant_method to base
|
||||||
for module in moe_modules:
|
for module in moe_modules:
|
||||||
if hasattr(module.quant_method, "old_quant_method"):
|
if hasattr(module.quant_method, "old_quant_method"):
|
||||||
module.quant_method = module.quant_method.old_quant_method
|
module._replace_quant_method(module.quant_method.old_quant_method)
|
||||||
module.runner = module._init_runner()
|
|
||||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||||
|
|
||||||
eplb_model_state.communicator = create_eplb_communicator(
|
eplb_model_state.communicator = create_eplb_communicator(
|
||||||
|
|||||||
@@ -595,10 +595,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
|
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
|
||||||
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
|
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
|
||||||
def _shared_experts(self):
|
|
||||||
return self.base_layer._shared_experts
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def quant_method(self):
|
def quant_method(self):
|
||||||
return self.base_layer.quant_method
|
return self.base_layer.quant_method
|
||||||
|
|||||||
@@ -937,6 +937,15 @@ class FusedMoEParallelConfig:
|
|||||||
all2all_backend: str # all2all backend for MoE communication
|
all2all_backend: str # all2all backend for MoE communication
|
||||||
enable_eplb: bool # whether to enable expert load balancing
|
enable_eplb: bool # whether to enable expert load balancing
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_dp_chunking(self) -> bool:
|
||||||
|
return (
|
||||||
|
self.use_deepep_ll_kernels
|
||||||
|
or self.use_mori_kernels
|
||||||
|
or self.use_fi_nvl_two_sided_kernels
|
||||||
|
or self.use_nixl_ep_kernels
|
||||||
|
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_sequence_parallel(self) -> bool:
|
def is_sequence_parallel(self) -> bool:
|
||||||
return self.sp_size > 1
|
return self.sp_size > 1
|
||||||
|
|||||||
@@ -1194,6 +1194,8 @@ def cutlass_moe_w4a8_fp8(
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
),
|
),
|
||||||
|
shared_experts=None,
|
||||||
|
inplace=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return fn.apply(
|
return fn.apply(
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class TrtLlmFp8ExpertsBase:
|
|||||||
self.local_num_experts = moe_config.num_local_experts
|
self.local_num_experts = moe_config.num_local_experts
|
||||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||||
|
|
||||||
|
self.moe_config = moe_config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -40,9 +40,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
def mk_owns_shared_expert(self) -> bool:
|
def mk_owns_shared_expert(self) -> bool:
|
||||||
# NOTE(rob): temporary attribute to indicate support for
|
# NOTE(rob): temporary attribute to indicate support for
|
||||||
# completed migration to the new internal MK interface.
|
# completed migration to the new internal MK interface.
|
||||||
return (
|
return self.moe_kernel is not None and self.moe_kernel.owns_shared_experts
|
||||||
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_weights(
|
def create_weights(
|
||||||
@@ -163,7 +161,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def apply_monolithic(
|
def apply_monolithic(
|
||||||
@@ -171,5 +169,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
|||||||
FusedMoEKernel,
|
FusedMoEKernel,
|
||||||
FusedMoEPrepareAndFinalizeModular,
|
FusedMoEPrepareAndFinalizeModular,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||||
|
SharedExperts,
|
||||||
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -44,7 +47,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
moe_layer: torch.nn.Module,
|
moe_layer: torch.nn.Module,
|
||||||
old_quant_method: FusedMoEMethodBase,
|
old_quant_method: FusedMoEMethodBase,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||||
shared_experts: torch.nn.Module | None,
|
shared_experts: SharedExperts | None,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
) -> "FusedMoEModularMethod":
|
) -> "FusedMoEModularMethod":
|
||||||
return FusedMoEModularMethod(
|
return FusedMoEModularMethod(
|
||||||
@@ -52,8 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
FusedMoEKernel(
|
FusedMoEKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
|
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
|
||||||
shared_experts,
|
shared_experts=shared_experts,
|
||||||
moe_parallel_config=moe_layer.moe_parallel_config,
|
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -89,7 +91,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
|
|||||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
|
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
|
||||||
DefaultMoERunner,
|
DefaultMoERunner,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||||
|
SharedExperts,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||||
UnquantizedFusedMoEMethod,
|
UnquantizedFusedMoEMethod,
|
||||||
)
|
)
|
||||||
@@ -275,8 +278,6 @@ class FusedMoE(CustomOp):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._gate = gate
|
|
||||||
self._shared_experts = shared_experts
|
|
||||||
self._routed_input_transform = routed_input_transform
|
self._routed_input_transform = routed_input_transform
|
||||||
|
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
@@ -486,7 +487,7 @@ class FusedMoE(CustomOp):
|
|||||||
device=vllm_config.device_config.device,
|
device=vllm_config.device_config.device,
|
||||||
routing_method=self.routing_method_type,
|
routing_method=self.routing_method_type,
|
||||||
# TODO: in_dtype == out_dtype?
|
# TODO: in_dtype == out_dtype?
|
||||||
disable_inplace=disable_inplace() or self._shared_experts is not None,
|
disable_inplace=disable_inplace() or shared_experts is not None,
|
||||||
)
|
)
|
||||||
if self.moe_config.use_mori_kernels:
|
if self.moe_config.use_mori_kernels:
|
||||||
assert self.rocm_aiter_fmoe_enabled, (
|
assert self.rocm_aiter_fmoe_enabled, (
|
||||||
@@ -564,34 +565,20 @@ class FusedMoE(CustomOp):
|
|||||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||||
|
|
||||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||||
|
|
||||||
|
# TODO(bnell): this is un-needed and removed in a follow up PR.
|
||||||
self.base_quant_method = self.quant_method
|
self.base_quant_method = self.quant_method
|
||||||
|
|
||||||
# Disable shared expert overlap if:
|
|
||||||
# - we are using eplb with non-default backend, because of correctness issues
|
|
||||||
# - we are using flashinfer with DP, since there nothing to gain
|
|
||||||
# - we are using marlin kernels
|
|
||||||
backend = self.moe_parallel_config.all2all_backend
|
|
||||||
self.use_overlapped = (
|
|
||||||
not (
|
|
||||||
(self.enable_eplb and backend != "allgather_reducescatter")
|
|
||||||
or self.moe_parallel_config.use_fi_nvl_two_sided_kernels
|
|
||||||
)
|
|
||||||
and self._shared_experts is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.runner = self._init_runner()
|
|
||||||
|
|
||||||
def _init_runner(self):
|
|
||||||
# Storing the runner in the FusedMoE is an intermediate state, eventually
|
# Storing the runner in the FusedMoE is an intermediate state, eventually
|
||||||
# the runner will own the FusedMoE layer and provide the execution interface
|
# the runner will own the FusedMoE layer and provide the execution interface
|
||||||
# for MoE ops.
|
# for MoE ops.
|
||||||
return DefaultMoERunner(
|
self.runner = DefaultMoERunner(
|
||||||
layer=self,
|
layer=self,
|
||||||
moe_config=self.moe_config,
|
moe_config=self.moe_config,
|
||||||
router=self.router,
|
router=self.router,
|
||||||
routed_input_transform=self._routed_input_transform,
|
routed_input_transform=self._routed_input_transform,
|
||||||
gate=self.gate,
|
gate=gate,
|
||||||
shared_experts=self.shared_experts,
|
shared_experts=shared_experts,
|
||||||
quant_method=self.quant_method,
|
quant_method=self.quant_method,
|
||||||
reduce_results=self.reduce_results,
|
reduce_results=self.reduce_results,
|
||||||
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
|
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
|
||||||
@@ -602,10 +589,7 @@ class FusedMoE(CustomOp):
|
|||||||
# intrusive way to do this.
|
# intrusive way to do this.
|
||||||
def _replace_quant_method(self, mk: FusedMoEMethodBase):
|
def _replace_quant_method(self, mk: FusedMoEMethodBase):
|
||||||
self.quant_method = mk
|
self.quant_method = mk
|
||||||
# We need to force reconstruction of runner because we're swapping out
|
self.runner._replace_quant_method(mk)
|
||||||
# the quant_method with a FusedMoEModularMethod. This logic can go
|
|
||||||
# away once the FusedMoEModularMethod is eliminated.
|
|
||||||
self.runner = self._init_runner()
|
|
||||||
|
|
||||||
# Note: maybe_init_modular_kernel should only be called by
|
# Note: maybe_init_modular_kernel should only be called by
|
||||||
# prepare_communication_buffer_for_model.
|
# prepare_communication_buffer_for_model.
|
||||||
@@ -639,8 +623,8 @@ class FusedMoE(CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shared_experts(self) -> torch.nn.Module | None:
|
def shared_experts(self) -> SharedExperts | None:
|
||||||
return self._shared_experts if self.use_overlapped else None
|
return self.runner.shared_experts
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layer_id(self):
|
def layer_id(self):
|
||||||
@@ -649,10 +633,6 @@ class FusedMoE(CustomOp):
|
|||||||
|
|
||||||
return extract_layer_index(self.layer_name)
|
return extract_layer_index(self.layer_name)
|
||||||
|
|
||||||
@property
|
|
||||||
def gate(self) -> torch.nn.Module | None:
|
|
||||||
return self._gate if self.use_overlapped else None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tp_size(self):
|
def tp_size(self):
|
||||||
return self.moe_parallel_config.tp_size
|
return self.moe_parallel_config.tp_size
|
||||||
@@ -676,7 +656,7 @@ class FusedMoE(CustomOp):
|
|||||||
@property
|
@property
|
||||||
def is_internal_router(self) -> bool:
|
def is_internal_router(self) -> bool:
|
||||||
# By default, router/gate is called before FusedMoE forward pass
|
# By default, router/gate is called before FusedMoE forward pass
|
||||||
return self.gate is not None
|
return self.runner.is_internal_router()
|
||||||
|
|
||||||
def _maybe_init_expert_routing_tables(
|
def _maybe_init_expert_routing_tables(
|
||||||
self,
|
self,
|
||||||
@@ -1467,7 +1447,12 @@ class FusedMoE(CustomOp):
|
|||||||
assert all(
|
assert all(
|
||||||
weight.is_contiguous()
|
weight.is_contiguous()
|
||||||
for name, weight in weights
|
for name, weight in weights
|
||||||
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
|
if not (
|
||||||
|
name.startswith("_shared_experts.")
|
||||||
|
or name.startswith("_gate.")
|
||||||
|
or name.startswith("_routed_input_transform.")
|
||||||
|
or name.startswith("_routed_output_transform.")
|
||||||
|
)
|
||||||
and name not in NON_EXPERT_WEIGHTS
|
and name not in NON_EXPERT_WEIGHTS
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1477,8 +1462,11 @@ class FusedMoE(CustomOp):
|
|||||||
if name not in NON_EXPERT_WEIGHTS
|
if name not in NON_EXPERT_WEIGHTS
|
||||||
and weight.shape != torch.Size([])
|
and weight.shape != torch.Size([])
|
||||||
and not name.startswith("_shared_experts.")
|
and not name.startswith("_shared_experts.")
|
||||||
# exclude parameters from non-expert submodules (e.g. gate/shared)
|
# exclude parameters from non-expert submodules,
|
||||||
|
# e.g. gate/shared/transforms.
|
||||||
and not name.startswith("_gate.")
|
and not name.startswith("_gate.")
|
||||||
|
and not name.startswith("_routed_input_transform.")
|
||||||
|
and not name.startswith("_routed_output_transform.")
|
||||||
]
|
]
|
||||||
|
|
||||||
def set_eplb_state(
|
def set_eplb_state(
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
RoutingMethodType,
|
RoutingMethodType,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||||
|
SharedExperts,
|
||||||
|
SharedExpertsOrder,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
_resize_cache,
|
_resize_cache,
|
||||||
disable_inplace,
|
disable_inplace,
|
||||||
@@ -235,6 +239,13 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def supports_async(self) -> bool:
|
||||||
|
"""
|
||||||
|
Indicates whether or not this class implements prepare_async and
|
||||||
|
finalize_async.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||||
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
|
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
|
||||||
@@ -281,13 +292,6 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def supports_async(self) -> bool:
|
|
||||||
"""
|
|
||||||
Indicates whether or not this class implements prepare_async and
|
|
||||||
finalize_async.
|
|
||||||
"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
def prepare_async(
|
def prepare_async(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
@@ -1003,15 +1007,20 @@ class FusedMoEKernelModularImpl:
|
|||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||||
fused_experts: FusedMoEExpertsModular,
|
fused_experts: FusedMoEExpertsModular,
|
||||||
shared_experts: torch.nn.Module | None = None,
|
shared_experts: SharedExperts | None,
|
||||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
):
|
):
|
||||||
self.prepare_finalize = prepare_finalize
|
self.prepare_finalize = prepare_finalize
|
||||||
self.fused_experts = fused_experts
|
self.fused_experts = fused_experts
|
||||||
self.shared_experts = shared_experts
|
# Only accept shared experts if they can be run w/async.
|
||||||
self.moe_parallel_config = moe_parallel_config
|
# The MoERunner/SharedExperts class will coordinate with the MK to ensure
|
||||||
|
# that the SharedExperts are executed only once.
|
||||||
|
self.shared_experts = (
|
||||||
|
shared_experts if prepare_finalize.supports_async() else None
|
||||||
|
)
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
|
moe_parallel_config = fused_experts.moe_config.moe_parallel_config
|
||||||
|
self.moe_parallel_config = moe_parallel_config
|
||||||
self.is_dp_ep = (
|
self.is_dp_ep = (
|
||||||
moe_parallel_config is not None
|
moe_parallel_config is not None
|
||||||
and moe_parallel_config.dp_size > 1
|
and moe_parallel_config.dp_size > 1
|
||||||
@@ -1081,6 +1090,17 @@ class FusedMoEKernelModularImpl:
|
|||||||
|
|
||||||
return workspace13, workspace2, fused_out
|
return workspace13, workspace2, fused_out
|
||||||
|
|
||||||
|
def _maybe_apply_shared_experts(
|
||||||
|
self,
|
||||||
|
shared_experts_input: torch.Tensor | None,
|
||||||
|
):
|
||||||
|
if self.shared_experts is not None:
|
||||||
|
assert shared_experts_input is not None
|
||||||
|
self.shared_experts.apply(
|
||||||
|
shared_experts_input,
|
||||||
|
SharedExpertsOrder.MK_INTERNAL_OVERLAPPED,
|
||||||
|
)
|
||||||
|
|
||||||
def _prepare(
|
def _prepare(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -1253,15 +1273,6 @@ class FusedMoEKernelModularImpl:
|
|||||||
shared_experts_input is the original hidden_states (full
|
shared_experts_input is the original hidden_states (full
|
||||||
dimension) needed by the shared expert MLP.
|
dimension) needed by the shared expert MLP.
|
||||||
"""
|
"""
|
||||||
shared_output: torch.Tensor | None = None
|
|
||||||
|
|
||||||
# For latent MoE: shared experts need the original hidden_states
|
|
||||||
# (full hidden_size), not the latent-projected version used by
|
|
||||||
# routed experts.
|
|
||||||
se_hidden_states = (
|
|
||||||
shared_experts_input if shared_experts_input is not None else hidden_states
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.prepare_finalize.supports_async():
|
if not self.prepare_finalize.supports_async():
|
||||||
assert not dbo_enabled()
|
assert not dbo_enabled()
|
||||||
|
|
||||||
@@ -1273,8 +1284,6 @@ class FusedMoEKernelModularImpl:
|
|||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||||
)
|
)
|
||||||
if self.shared_experts is not None:
|
|
||||||
shared_output = self.shared_experts(se_hidden_states)
|
|
||||||
else:
|
else:
|
||||||
finalize_ret = self.prepare_finalize.finalize_async(
|
finalize_ret = self.prepare_finalize.finalize_async(
|
||||||
output,
|
output,
|
||||||
@@ -1284,8 +1293,7 @@ class FusedMoEKernelModularImpl:
|
|||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||||
)
|
)
|
||||||
if self.shared_experts is not None:
|
self._maybe_apply_shared_experts(shared_experts_input)
|
||||||
shared_output = self.shared_experts(se_hidden_states)
|
|
||||||
|
|
||||||
# TODO(lucas): refactor this in the alternative schedules followup
|
# TODO(lucas): refactor this in the alternative schedules followup
|
||||||
# currently unpack if we have hook + receiver pair or just
|
# currently unpack if we have hook + receiver pair or just
|
||||||
@@ -1308,11 +1316,7 @@ class FusedMoEKernelModularImpl:
|
|||||||
|
|
||||||
receiver()
|
receiver()
|
||||||
|
|
||||||
if self.shared_experts is None:
|
return output
|
||||||
return output
|
|
||||||
else:
|
|
||||||
assert shared_output is not None
|
|
||||||
return shared_output, output
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -1326,7 +1330,7 @@ class FusedMoEKernelModularImpl:
|
|||||||
expert_map: torch.Tensor | None = None,
|
expert_map: torch.Tensor | None = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
shared_experts_input: torch.Tensor | None = None,
|
shared_experts_input: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||||
of weights, w1 and w2, and top-k gating mechanism.
|
of weights, w1 and w2, and top-k gating mechanism.
|
||||||
@@ -1469,12 +1473,10 @@ class FusedMoEKernel:
|
|||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
fused_experts: FusedMoEExperts,
|
fused_experts: FusedMoEExperts,
|
||||||
shared_experts: torch.nn.Module | None = None,
|
shared_experts: SharedExperts | None = None,
|
||||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shared_experts = shared_experts # NOTE: check if we can remove
|
|
||||||
|
|
||||||
# Initialize the implementation (monolithic or modular).
|
# Initialize the implementation (monolithic or modular).
|
||||||
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
|
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
|
||||||
@@ -1485,14 +1487,12 @@ class FusedMoEKernel:
|
|||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
fused_experts,
|
fused_experts,
|
||||||
shared_experts,
|
shared_experts,
|
||||||
moe_parallel_config,
|
|
||||||
inplace,
|
inplace,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(
|
elif isinstance(
|
||||||
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
|
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
|
||||||
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
|
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
|
||||||
assert shared_experts is None
|
|
||||||
assert not inplace
|
assert not inplace
|
||||||
self.impl = FusedMoEKernelMonolithicImpl(
|
self.impl = FusedMoEKernelMonolithicImpl(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
@@ -1508,6 +1508,13 @@ class FusedMoEKernel:
|
|||||||
|
|
||||||
self._post_init_setup()
|
self._post_init_setup()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def owns_shared_experts(self) -> bool:
|
||||||
|
if isinstance(self.impl, FusedMoEKernelModularImpl):
|
||||||
|
return self.impl.shared_experts is not None
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_monolithic(self) -> bool:
|
def is_monolithic(self) -> bool:
|
||||||
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
fp8_w8a8_moe_quant_config,
|
fp8_w8a8_moe_quant_config,
|
||||||
fp8_w8a16_moe_quant_config,
|
fp8_w8a16_moe_quant_config,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||||
|
SharedExperts,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||||
FlashinferMoeBackend,
|
FlashinferMoeBackend,
|
||||||
get_flashinfer_moe_backend,
|
get_flashinfer_moe_backend,
|
||||||
@@ -545,7 +548,7 @@ def make_fp8_moe_kernel(
|
|||||||
experts_cls: type[mk.FusedMoEExperts],
|
experts_cls: type[mk.FusedMoEExperts],
|
||||||
fp8_backend: Fp8MoeBackend,
|
fp8_backend: Fp8MoeBackend,
|
||||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
shared_experts: torch.nn.Module | None = None,
|
shared_experts: SharedExperts | None = None,
|
||||||
) -> mk.FusedMoEKernel:
|
) -> mk.FusedMoEKernel:
|
||||||
# Create Prepare/Finalize.
|
# Create Prepare/Finalize.
|
||||||
prepare_finalize = maybe_make_prepare_finalize(
|
prepare_finalize = maybe_make_prepare_finalize(
|
||||||
@@ -581,12 +584,7 @@ def make_fp8_moe_kernel(
|
|||||||
kernel = mk.FusedMoEKernel(
|
kernel = mk.FusedMoEKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
experts,
|
experts,
|
||||||
shared_experts=(
|
shared_experts=shared_experts,
|
||||||
shared_experts
|
|
||||||
if moe_config.moe_parallel_config.use_deepep_ll_kernels
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
moe_parallel_config=moe_config.moe_parallel_config,
|
|
||||||
inplace=(
|
inplace=(
|
||||||
not moe_config.disable_inplace
|
not moe_config.disable_inplace
|
||||||
and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
|
and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||||
|
|||||||
@@ -859,7 +859,6 @@ def make_mxfp4_moe_kernel(
|
|||||||
if moe_config.moe_parallel_config.use_deepep_ll_kernels
|
if moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
moe_parallel_config=moe_config.moe_parallel_config,
|
|
||||||
inplace=(
|
inplace=(
|
||||||
not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS
|
not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
nvfp4_moe_quant_config,
|
nvfp4_moe_quant_config,
|
||||||
nvfp4_w4a16_moe_quant_config,
|
nvfp4_w4a16_moe_quant_config,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||||
|
SharedExperts,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||||
)
|
)
|
||||||
@@ -386,7 +389,7 @@ def make_nvfp4_moe_kernel(
|
|||||||
moe_config: FusedMoEConfig,
|
moe_config: FusedMoEConfig,
|
||||||
experts_cls: type[mk.FusedMoEExperts],
|
experts_cls: type[mk.FusedMoEExperts],
|
||||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
shared_experts: torch.nn.Module | None = None,
|
shared_experts: SharedExperts | None = None,
|
||||||
) -> mk.FusedMoEKernel:
|
) -> mk.FusedMoEKernel:
|
||||||
# Create Prepare/Finalize.
|
# Create Prepare/Finalize.
|
||||||
prepare_finalize = maybe_make_prepare_finalize(
|
prepare_finalize = maybe_make_prepare_finalize(
|
||||||
@@ -422,12 +425,7 @@ def make_nvfp4_moe_kernel(
|
|||||||
kernel = mk.FusedMoEKernel(
|
kernel = mk.FusedMoEKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
experts,
|
experts,
|
||||||
shared_experts=(
|
shared_experts=shared_experts,
|
||||||
shared_experts
|
|
||||||
if moe_config.moe_parallel_config.use_deepep_ll_kernels
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
moe_parallel_config=moe_config.moe_parallel_config,
|
|
||||||
inplace=False,
|
inplace=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
FusedMoEConfig,
|
FusedMoEConfig,
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||||
|
SharedExperts,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||||
FlashinferMoeBackend,
|
FlashinferMoeBackend,
|
||||||
convert_moe_weights_to_flashinfer_trtllm_block_layout,
|
convert_moe_weights_to_flashinfer_trtllm_block_layout,
|
||||||
@@ -321,7 +324,7 @@ def make_unquantized_moe_kernel(
|
|||||||
backend: UnquantizedMoeBackend,
|
backend: UnquantizedMoeBackend,
|
||||||
experts_cls: type[mk.FusedMoEExperts],
|
experts_cls: type[mk.FusedMoEExperts],
|
||||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
shared_experts: torch.nn.Module | None = None,
|
shared_experts: SharedExperts | None = None,
|
||||||
) -> mk.FusedMoEKernel:
|
) -> mk.FusedMoEKernel:
|
||||||
# Create Prepare/Finalize
|
# Create Prepare/Finalize
|
||||||
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
|
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
|
||||||
@@ -355,12 +358,7 @@ def make_unquantized_moe_kernel(
|
|||||||
kernel = mk.FusedMoEKernel(
|
kernel = mk.FusedMoEKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
experts,
|
experts,
|
||||||
shared_experts=(
|
shared_experts=shared_experts,
|
||||||
shared_experts
|
|
||||||
if moe_config.moe_parallel_config.use_deepep_ll_kernels
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
moe_parallel_config=moe_config.moe_parallel_config,
|
|
||||||
inplace=(not moe_config.disable_inplace and not is_monolithic),
|
inplace=(not moe_config.disable_inplace and not is_monolithic),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -325,7 +325,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
|||||||
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
|
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
|
||||||
**(
|
**(
|
||||||
dict(x_global_scale=qc_a1_gscale_or_scale)
|
dict(x_global_scale=qc_a1_gscale_or_scale)
|
||||||
if qc_a1_gscale_or_scale is not None
|
if qc_a1_gscale_or_scale is not None and nvfp4_dispatch
|
||||||
else dict()
|
else dict()
|
||||||
),
|
),
|
||||||
async_finish=False,
|
async_finish=False,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -13,6 +14,13 @@ class FusedMoERouter(ABC):
|
|||||||
method that is used for routing hidden states based on router logits.
|
method that is used for routing hidden states based on router logits.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_capture_fn(
|
||||||
|
self,
|
||||||
|
capture_fn: Callable[[torch.Tensor], None] | None,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def routing_method_type(self) -> RoutingMethodType:
|
def routing_method_type(self) -> RoutingMethodType:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_ep_group,
|
get_ep_group,
|
||||||
get_pcp_group,
|
get_pcp_group,
|
||||||
@@ -29,13 +28,15 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
|
|||||||
FusedMoERouter,
|
FusedMoERouter,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||||
|
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||||
|
SharedExperts,
|
||||||
|
SharedExpertsOrder,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.torch_utils import (
|
from vllm.utils.torch_utils import (
|
||||||
HAS_OPAQUE_TYPE,
|
HAS_OPAQUE_TYPE,
|
||||||
ModuleName,
|
ModuleName,
|
||||||
aux_stream,
|
|
||||||
current_stream,
|
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
)
|
)
|
||||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||||
@@ -74,6 +75,9 @@ def _resolve_layer_name(layer_name: str | ModuleName) -> str:
|
|||||||
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
|
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
|
||||||
|
|
||||||
|
|
||||||
|
# Note: _moe_forward and _moe_forward_shared should not contain any
|
||||||
|
# implementation details. They should merely pass along control to
|
||||||
|
# the runner's 'forward_dispatch' method.
|
||||||
def _moe_forward(
|
def _moe_forward(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
@@ -81,24 +85,12 @@ def _moe_forward(
|
|||||||
layer_name: _layer_name_type,
|
layer_name: _layer_name_type,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||||
# TODO(bnell): this can be removed after MK migration is complete.
|
return layer.runner.forward_dispatch(
|
||||||
layer.ensure_moe_quant_config_init()
|
layer,
|
||||||
runner = layer.runner
|
hidden_states,
|
||||||
with runner._sequence_parallel_context():
|
router_logits,
|
||||||
if runner.use_dp_chunking:
|
shared_experts_input,
|
||||||
return runner.forward_impl_chunked(
|
)
|
||||||
layer,
|
|
||||||
hidden_states,
|
|
||||||
router_logits,
|
|
||||||
shared_experts_input,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return runner.forward_impl(
|
|
||||||
layer,
|
|
||||||
hidden_states,
|
|
||||||
router_logits,
|
|
||||||
shared_experts_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _moe_forward_fake(
|
def _moe_forward_fake(
|
||||||
@@ -117,24 +109,12 @@ def _moe_forward_shared(
|
|||||||
layer_name: _layer_name_type,
|
layer_name: _layer_name_type,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||||
# TODO(bnell): this can be removed after MK migration is complete.
|
return layer.runner.forward_dispatch(
|
||||||
layer.ensure_moe_quant_config_init()
|
layer,
|
||||||
runner = layer.runner
|
hidden_states,
|
||||||
with runner._sequence_parallel_context():
|
router_logits,
|
||||||
if runner.use_dp_chunking:
|
shared_experts_input,
|
||||||
return runner.forward_impl_chunked(
|
)
|
||||||
layer,
|
|
||||||
hidden_states,
|
|
||||||
router_logits,
|
|
||||||
shared_experts_input,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return runner.forward_impl(
|
|
||||||
layer,
|
|
||||||
hidden_states,
|
|
||||||
router_logits,
|
|
||||||
shared_experts_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _moe_forward_shared_fake(
|
def _moe_forward_shared_fake(
|
||||||
@@ -159,7 +139,7 @@ def _moe_forward_shared_fake(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="moe_forward",
|
op_name="moe_forward",
|
||||||
op_func=_moe_forward,
|
op_func=_moe_forward,
|
||||||
mutates_args=["hidden_states"],
|
mutates_args=["hidden_states"], # is this still true?
|
||||||
fake_impl=_moe_forward_fake,
|
fake_impl=_moe_forward_fake,
|
||||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||||
)
|
)
|
||||||
@@ -168,7 +148,6 @@ direct_register_custom_op(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="moe_forward_shared",
|
op_name="moe_forward_shared",
|
||||||
op_func=_moe_forward_shared,
|
op_func=_moe_forward_shared,
|
||||||
mutates_args=["hidden_states"],
|
|
||||||
fake_impl=_moe_forward_shared_fake,
|
fake_impl=_moe_forward_shared_fake,
|
||||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||||
)
|
)
|
||||||
@@ -213,87 +192,68 @@ class DefaultMoERunner(MoERunner):
|
|||||||
self.router = router
|
self.router = router
|
||||||
self.routed_input_transform = routed_input_transform
|
self.routed_input_transform = routed_input_transform
|
||||||
self.gate = gate
|
self.gate = gate
|
||||||
self.shared_experts = shared_experts
|
|
||||||
self.quant_method = quant_method
|
self.quant_method = quant_method
|
||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
self.enable_dbo = enable_dbo
|
self.enable_dbo = enable_dbo
|
||||||
|
|
||||||
|
self.shared_experts: SharedExperts | None = None
|
||||||
|
if shared_experts is not None:
|
||||||
|
self.shared_experts = SharedExperts(
|
||||||
|
shared_experts,
|
||||||
|
moe_config=moe_config,
|
||||||
|
# Note: For now we must pass quant_method along to SharedExperts so it
|
||||||
|
# can properly determine where the shared experts are supposed to be
|
||||||
|
# called, i.e. by a MK or by the MoERunner.
|
||||||
|
# Once the MK can be created upfront, we can just pass in the proper
|
||||||
|
# flags derived from the quant_method's MK.
|
||||||
|
reduce_results=reduce_results,
|
||||||
|
quant_method=quant_method,
|
||||||
|
enable_dbo=enable_dbo,
|
||||||
|
)
|
||||||
|
|
||||||
# Chunked all2all staging tensor
|
# Chunked all2all staging tensor
|
||||||
# TODO(bnell) rename these?
|
# These need to exist ahead of time due to CUDAgraph construction
|
||||||
|
# needing a fixed buffer address.
|
||||||
|
self.use_dp_chunking = self.moe_config.moe_parallel_config.use_dp_chunking
|
||||||
self.batched_hidden_states: torch.Tensor | None = None
|
self.batched_hidden_states: torch.Tensor | None = None
|
||||||
self.batched_router_logits: torch.Tensor | None = None
|
self.batched_router_logits: torch.Tensor | None = None
|
||||||
self._maybe_init_dp_chunking()
|
self._maybe_init_dp_chunking()
|
||||||
|
|
||||||
# Allow disabling of the separate shared experts stream for
|
|
||||||
# debug purposes.
|
|
||||||
# TODO: Remove this after more extensive testings with TP/DP
|
|
||||||
# and other execution modes
|
|
||||||
self.use_shared_experts_stream = False
|
|
||||||
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
|
|
||||||
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
|
|
||||||
self.shared_experts_stream = None
|
|
||||||
else:
|
|
||||||
# TODO(rob): enable shared expert overlap with non-cuda-alike.
|
|
||||||
# aux_stream() returns None on non-cuda-alike platforms.
|
|
||||||
self.shared_experts_stream = aux_stream()
|
|
||||||
if self.shared_experts_stream is not None:
|
|
||||||
logger.debug_once(
|
|
||||||
"Enabled separate cuda stream for MoE shared_experts", scope="local"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Needed for string -> FusedMoE layer lookup in custom ops.
|
# Needed for string -> FusedMoE layer lookup in custom ops.
|
||||||
self.layer_name = layer.layer_name
|
self.layer_name = layer.layer_name
|
||||||
|
|
||||||
self.moe_forward = self._select_forward(layer)
|
self.forward_entry, self.forward_impl = self._select_forward(layer)
|
||||||
|
|
||||||
|
def _select_forward(self, layer: torch.nn.Module) -> tuple[Callable, Callable]:
|
||||||
|
# Select implementation based on presence of DP chunking.
|
||||||
|
forward_impl_fn = (
|
||||||
|
self._forward_impl_chunked if self.use_dp_chunking else self._forward_impl
|
||||||
|
)
|
||||||
|
|
||||||
def _select_forward(self, layer: torch.nn.Module) -> Callable:
|
|
||||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||||
# will switch to using the moe_forward custom op.
|
# will switch to using the moe_forward custom op.
|
||||||
# Note: CPU doesn't require wrapped forward_impl.
|
# Note: CPU doesn't require wrapped forward_impl.
|
||||||
return _moe_forward if self.shared_experts is None else _moe_forward_shared
|
return (
|
||||||
|
_moe_forward if self.shared_experts is None else _moe_forward_shared,
|
||||||
|
forward_impl_fn,
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
torch.ops.vllm.moe_forward
|
torch.ops.vllm.moe_forward
|
||||||
if self.shared_experts is None
|
if self.shared_experts is None
|
||||||
else torch.ops.vllm.moe_forward_shared
|
else torch.ops.vllm.moe_forward_shared,
|
||||||
|
forward_impl_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
# TODO(bnell): temporary hack, do not call this method.
|
||||||
def use_dp_chunking(self) -> bool:
|
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
|
||||||
return (
|
if self.shared_experts is not None:
|
||||||
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
self.shared_experts._quant_method = quant_method
|
||||||
or self.moe_config.moe_parallel_config.use_mori_kernels
|
self.quant_method = quant_method
|
||||||
or self.moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels
|
|
||||||
or self.moe_config.moe_parallel_config.use_nixl_ep_kernels
|
|
||||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
|
||||||
|
|
||||||
def _maybe_setup_shared_experts_stream(
|
def is_internal_router(self) -> bool:
|
||||||
self,
|
return self.gate is not None
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
shared_input: torch.Tensor | None,
|
|
||||||
):
|
|
||||||
if self.use_shared_experts_stream:
|
|
||||||
assert self.shared_experts_stream is not None
|
|
||||||
assert self.moe_config.disable_inplace
|
|
||||||
|
|
||||||
shared_experts_input = (
|
|
||||||
shared_input if shared_input is not None else hidden_states
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record that the shared_experts_input will be used in the
|
|
||||||
# shared_experts_stream to avoid gc issue from
|
|
||||||
# deallocation. For more details:
|
|
||||||
# https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
|
||||||
# NOTE: We don't need shared_output.record_stream(current_stream())
|
|
||||||
# because we synch the streams before using shared_output.
|
|
||||||
shared_experts_input.record_stream(self.shared_experts_stream)
|
|
||||||
|
|
||||||
# Mark sync start point for the separate shared experts
|
|
||||||
# stream here since we want to run in parallel with the
|
|
||||||
# router/gate (next op below)
|
|
||||||
assert self.shared_experts_stream is not None
|
|
||||||
self.shared_experts_stream.wait_stream(current_stream())
|
|
||||||
|
|
||||||
def _maybe_init_dp_chunking(self):
|
def _maybe_init_dp_chunking(self):
|
||||||
if not self.use_dp_chunking:
|
if not self.use_dp_chunking:
|
||||||
@@ -325,38 +285,6 @@ class DefaultMoERunner(MoERunner):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def has_separate_shared_experts(self) -> bool:
|
|
||||||
return (
|
|
||||||
not self.quant_method.mk_owns_shared_expert
|
|
||||||
and self.shared_experts is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _apply_shared_experts(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
allow_streaming: bool = False,
|
|
||||||
) -> torch.Tensor | None:
|
|
||||||
shared_output: torch.Tensor | None = None
|
|
||||||
if self.has_separate_shared_experts:
|
|
||||||
assert self.shared_experts is not None
|
|
||||||
|
|
||||||
if self.use_shared_experts_stream and allow_streaming:
|
|
||||||
# Run shared experts in parallel on a separate stream
|
|
||||||
# NOTE: We start the separate stream here and mark the
|
|
||||||
# sync end point immediately after it is done. This is
|
|
||||||
# important to avoid excessive stream allocations by the cuda
|
|
||||||
# graph replay later.
|
|
||||||
with torch.cuda.stream(self.shared_experts_stream):
|
|
||||||
# Note that hidden_states clone() is necessary here to avoid
|
|
||||||
# conflict with the main stream
|
|
||||||
shared_output = self.shared_experts(hidden_states)
|
|
||||||
current_stream().wait_stream(self.shared_experts_stream)
|
|
||||||
else:
|
|
||||||
shared_output = self.shared_experts(hidden_states)
|
|
||||||
|
|
||||||
return shared_output
|
|
||||||
|
|
||||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||||
"""
|
"""
|
||||||
The shared_experts are typically computed using the RowParallelLinear
|
The shared_experts are typically computed using the RowParallelLinear
|
||||||
@@ -384,7 +312,9 @@ class DefaultMoERunner(MoERunner):
|
|||||||
else:
|
else:
|
||||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def apply_routed_input_transform(
|
||||||
|
self, hidden_states: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
"""Apply transform for routed experts (e.g., latent projection).
|
"""Apply transform for routed experts (e.g., latent projection).
|
||||||
|
|
||||||
This is called by FusedMoE.forward_native. The original hidden_states
|
This is called by FusedMoE.forward_native. The original hidden_states
|
||||||
@@ -394,15 +324,22 @@ class DefaultMoERunner(MoERunner):
|
|||||||
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
|
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
|
||||||
moved inside SharedFusedMoE to all-reduce on the smaller latent
|
moved inside SharedFusedMoE to all-reduce on the smaller latent
|
||||||
dimension.
|
dimension.
|
||||||
|
|
||||||
|
Returns (possibly transformed) hidden states and the input for shared
|
||||||
|
experts (or None if there are no shared experts).
|
||||||
"""
|
"""
|
||||||
if self.routed_input_transform is not None:
|
if self.routed_input_transform is not None:
|
||||||
result = self.routed_input_transform(hidden_states)
|
result = self.routed_input_transform(hidden_states)
|
||||||
# ReplicatedLinear returns (output, extra_bias) tuple.
|
# ReplicatedLinear returns (output, extra_bias) tuple.
|
||||||
# We only need the output tensor; extra_bias is not used here.
|
# We only need the output tensor; extra_bias is not used here.
|
||||||
if isinstance(result, tuple):
|
if isinstance(result, tuple):
|
||||||
return result[0]
|
return result[0], hidden_states
|
||||||
return result
|
return result, hidden_states
|
||||||
return hidden_states
|
|
||||||
|
return (
|
||||||
|
hidden_states,
|
||||||
|
hidden_states if self.shared_experts is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
def _maybe_reduce_output(
|
def _maybe_reduce_output(
|
||||||
self,
|
self,
|
||||||
@@ -446,13 +383,11 @@ class DefaultMoERunner(MoERunner):
|
|||||||
|
|
||||||
def _maybe_pad_hidden_states(
|
def _maybe_pad_hidden_states(
|
||||||
self,
|
self,
|
||||||
original_hidden_states: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, list[int]]:
|
) -> tuple[torch.Tensor, list[int]]:
|
||||||
original_hidden_dim = (
|
shared_experts_hidden_dim = (
|
||||||
original_hidden_states.shape[-1]
|
shared_experts_input.shape[-1] if shared_experts_input is not None else 0
|
||||||
if original_hidden_states is not None
|
|
||||||
else 0
|
|
||||||
)
|
)
|
||||||
transformed_hidden_dim = hidden_states.shape[-1]
|
transformed_hidden_dim = hidden_states.shape[-1]
|
||||||
if (
|
if (
|
||||||
@@ -467,29 +402,37 @@ class DefaultMoERunner(MoERunner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.shared_experts is not None:
|
if self.shared_experts is not None:
|
||||||
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
|
orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim]
|
||||||
else:
|
else:
|
||||||
orig_hidden_dims = [transformed_hidden_dim]
|
orig_hidden_dims = [transformed_hidden_dim]
|
||||||
|
|
||||||
return hidden_states, orig_hidden_dims
|
return hidden_states, orig_hidden_dims
|
||||||
|
|
||||||
|
def _maybe_apply_shared_experts(
|
||||||
|
self,
|
||||||
|
shared_experts_input: torch.Tensor | None,
|
||||||
|
order: SharedExpertsOrder,
|
||||||
|
):
|
||||||
|
if self.shared_experts is not None:
|
||||||
|
assert shared_experts_input is not None
|
||||||
|
self.shared_experts.apply(shared_experts_input, order)
|
||||||
|
|
||||||
def _apply_quant_method(
|
def _apply_quant_method(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
shared_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
run_shared_experts_before: bool = True,
|
|
||||||
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
||||||
shared_input = shared_input if shared_input is not None else hidden_states
|
|
||||||
shared_output: torch.Tensor | None = None
|
|
||||||
|
|
||||||
# Run this before quant_method to avoid inplace issues.
|
# Run this before quant_method to avoid inplace issues.
|
||||||
if run_shared_experts_before:
|
# TODO(bnell): probably not needed anymore since inplace is
|
||||||
shared_output = self._apply_shared_experts(shared_input, False)
|
# disabled when shared experts are present.
|
||||||
|
self._maybe_apply_shared_experts(
|
||||||
|
shared_experts_input, SharedExpertsOrder.NO_OVERLAP
|
||||||
|
)
|
||||||
|
|
||||||
if self.quant_method.is_monolithic:
|
if self.quant_method.is_monolithic:
|
||||||
result = self.quant_method.apply_monolithic(
|
fused_out = self.quant_method.apply_monolithic(
|
||||||
layer=layer,
|
layer=layer,
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -500,25 +443,25 @@ class DefaultMoERunner(MoERunner):
|
|||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.quant_method.apply(
|
# Passing shared_experts_input in case SharedExpertsOrder is
|
||||||
|
# NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
|
||||||
|
fused_out = self.quant_method.apply(
|
||||||
layer=layer,
|
layer=layer,
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
shared_experts_input=shared_input,
|
shared_experts_input=shared_experts_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
self._maybe_apply_shared_experts(
|
||||||
assert shared_output is None
|
shared_experts_input,
|
||||||
shared_output, hidden_states = result
|
SharedExpertsOrder.MULTI_STREAM_OVERLAPPED,
|
||||||
else:
|
)
|
||||||
hidden_states = result
|
|
||||||
|
|
||||||
if not run_shared_experts_before and self.has_separate_shared_experts:
|
return (
|
||||||
assert shared_output is None
|
self.shared_experts.output if self.shared_experts is not None else None,
|
||||||
shared_output = self._apply_shared_experts(shared_input, True)
|
fused_out,
|
||||||
|
)
|
||||||
return shared_output, hidden_states
|
|
||||||
|
|
||||||
def _sequence_parallel_context(self):
|
def _sequence_parallel_context(self):
|
||||||
ctx = get_forward_context()
|
ctx = get_forward_context()
|
||||||
@@ -558,18 +501,16 @@ class DefaultMoERunner(MoERunner):
|
|||||||
|
|
||||||
return final_shared_hidden_states, final_fused_hidden_states
|
return final_shared_hidden_states, final_fused_hidden_states
|
||||||
|
|
||||||
def _maybe_gate(
|
def _maybe_sync_shared_experts_stream(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
shared_experts_input: torch.Tensor | None,
|
||||||
router_logits: torch.Tensor,
|
):
|
||||||
) -> torch.Tensor:
|
|
||||||
# If router/gate provided, then apply it here.
|
# If router/gate provided, then apply it here.
|
||||||
# (Note: This code runs only when "overlapped mode" is on to allow
|
# (Note: This code runs only when "overlapped mode" is on to allow
|
||||||
# parallel execution of shared experts with the FusedMoE via
|
# parallel execution of shared experts with the FusedMoE via
|
||||||
# separate cuda stream)
|
# separate cuda stream)
|
||||||
if self.gate is not None:
|
if self.shared_experts is not None:
|
||||||
router_logits, _ = self.gate(hidden_states)
|
self.shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
|
||||||
return router_logits
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def do_naive_dispatch_combine(self) -> bool:
|
def do_naive_dispatch_combine(self) -> bool:
|
||||||
@@ -624,7 +565,6 @@ class DefaultMoERunner(MoERunner):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
# need RS for shared_output?
|
|
||||||
|
|
||||||
if self.shared_experts is not None:
|
if self.shared_experts is not None:
|
||||||
assert shared_output is not None
|
assert shared_output is not None
|
||||||
@@ -637,30 +577,86 @@ class DefaultMoERunner(MoERunner):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
# For latent MoE: save ORIGINAL hidden_states before transform
|
"""Invoke the fused moe layer.
|
||||||
# (shared_experts need original dimension, routed experts use transformed)
|
|
||||||
if self.shared_experts is not None:
|
Input:
|
||||||
original_hidden_states = hidden_states
|
- hidden_states
|
||||||
else:
|
- router_logits
|
||||||
original_hidden_states = None
|
|
||||||
|
Output:
|
||||||
|
- The new hidden_states.
|
||||||
|
or
|
||||||
|
- A tuple of (shared experts output, new hidden_states).
|
||||||
|
|
||||||
|
Calling sequence
|
||||||
|
- forward
|
||||||
|
- self.forward_entry (_moe_forward or _moe_forward_shared custom op)
|
||||||
|
- forward_dispatch
|
||||||
|
- forward_impl (_forward_impl or _forward_impl_chunked)
|
||||||
|
|
||||||
|
Note: The existence of _moe_forward and _moe_forward_shared custom ops is due
|
||||||
|
to the following reasons:
|
||||||
|
1. the chunking loop in _forward_impl_chunked cannot be compiled by
|
||||||
|
torch.compile
|
||||||
|
2. pytorch cannot handle union types in custom op signatures so _moe_forward
|
||||||
|
and _moe_forward_shared must be split.
|
||||||
|
|
||||||
|
If _forward_impl_chunked can be implemented via torch.scan we can potentially
|
||||||
|
get rid of _moe_forward and _moe_forward_shared and collapse the whole sequence
|
||||||
|
into the 'forward' method.
|
||||||
|
"""
|
||||||
|
|
||||||
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||||
hidden_states = self.apply_routed_input_transform(hidden_states)
|
hidden_states, shared_experts_input = self.apply_routed_input_transform(
|
||||||
|
hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
|
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
|
||||||
original_hidden_states,
|
shared_experts_input,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
fused_output = self.moe_forward(
|
fused_output = self.forward_entry(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
original_hidden_states,
|
shared_experts_input,
|
||||||
self._encode_layer_name(),
|
self._encode_layer_name(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._maybe_reduce_output(fused_output, og_hidden_dims)
|
return self._maybe_reduce_output(fused_output, og_hidden_dims)
|
||||||
|
|
||||||
|
def forward_dispatch(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
shared_experts_input: torch.Tensor | None,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# TODO(bnell): this can be removed after MK migration is complete.
|
||||||
|
layer.ensure_moe_quant_config_init()
|
||||||
|
|
||||||
|
# Sync aux and main stream for shared expert multi-stream overlap.
|
||||||
|
self._maybe_sync_shared_experts_stream(shared_experts_input)
|
||||||
|
|
||||||
|
# If the Runner holds the gate, apply it after the stream sync,
|
||||||
|
# so it can run overlapped with the shared experts.
|
||||||
|
# NOTE: in a future PR, the MoE runner will always hold the gate.
|
||||||
|
if self.gate is not None:
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|
||||||
|
self._maybe_apply_shared_experts(
|
||||||
|
shared_experts_input,
|
||||||
|
SharedExpertsOrder.EXTERNAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._sequence_parallel_context():
|
||||||
|
return self.forward_impl(
|
||||||
|
layer,
|
||||||
|
hidden_states,
|
||||||
|
router_logits,
|
||||||
|
shared_experts_input,
|
||||||
|
)
|
||||||
|
|
||||||
def _slice_and_copy_input(
|
def _slice_and_copy_input(
|
||||||
self,
|
self,
|
||||||
out_slice: torch.Tensor,
|
out_slice: torch.Tensor,
|
||||||
@@ -681,17 +677,13 @@ class DefaultMoERunner(MoERunner):
|
|||||||
out_slice.copy_(orig_slice, non_blocking=True)
|
out_slice.copy_(orig_slice, non_blocking=True)
|
||||||
return out_slice
|
return out_slice
|
||||||
|
|
||||||
def forward_impl_chunked(
|
def _forward_impl_chunked(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
shared_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Gate overlap not supported when chunking is enabled. Run the
|
|
||||||
# gate first.
|
|
||||||
router_logits = self._maybe_gate(hidden_states, router_logits)
|
|
||||||
|
|
||||||
final_shared_hidden_states, final_fused_hidden_states = (
|
final_shared_hidden_states, final_fused_hidden_states = (
|
||||||
self._allocate_dp_chunking_outputs(hidden_states, router_logits)
|
self._allocate_dp_chunking_outputs(hidden_states, router_logits)
|
||||||
)
|
)
|
||||||
@@ -737,9 +729,9 @@ class DefaultMoERunner(MoERunner):
|
|||||||
chunk_end,
|
chunk_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
shared_input_chunk = (
|
shared_experts_input_chunk = (
|
||||||
shared_input[chunk_start:chunk_end, :]
|
shared_experts_input[chunk_start:chunk_end, :]
|
||||||
if shared_input is not None
|
if shared_experts_input is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -747,7 +739,7 @@ class DefaultMoERunner(MoERunner):
|
|||||||
layer=layer,
|
layer=layer,
|
||||||
hidden_states=hidden_states_chunk,
|
hidden_states=hidden_states_chunk,
|
||||||
router_logits=router_logits_chunk,
|
router_logits=router_logits_chunk,
|
||||||
shared_input=shared_input_chunk,
|
shared_experts_input=shared_experts_input_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store outputs
|
# Store outputs
|
||||||
@@ -769,40 +761,13 @@ class DefaultMoERunner(MoERunner):
|
|||||||
assert final_shared_hidden_states is not None
|
assert final_shared_hidden_states is not None
|
||||||
return (final_shared_hidden_states, final_fused_hidden_states)
|
return (final_shared_hidden_states, final_fused_hidden_states)
|
||||||
|
|
||||||
def forward_impl(
|
def _forward_impl(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
shared_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
self.use_shared_experts_stream = (
|
|
||||||
current_platform.is_cuda()
|
|
||||||
and self.has_separate_shared_experts
|
|
||||||
and not self.use_dp_chunking
|
|
||||||
and self.shared_experts_stream is not None
|
|
||||||
and (
|
|
||||||
hidden_states.shape[0]
|
|
||||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if we need to run shared experts before matrix multiply because
|
|
||||||
# matrix multiply may modify the hidden_states.
|
|
||||||
run_shared_experts_before = (
|
|
||||||
self.has_separate_shared_experts and not self.use_shared_experts_stream
|
|
||||||
)
|
|
||||||
|
|
||||||
# The shared experts stream must be set up before calling the gate so they
|
|
||||||
# can be overlapped.
|
|
||||||
if not run_shared_experts_before:
|
|
||||||
self._maybe_setup_shared_experts_stream(
|
|
||||||
hidden_states,
|
|
||||||
shared_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
router_logits = self._maybe_gate(hidden_states, router_logits)
|
|
||||||
|
|
||||||
# TODO(bnell): parts of the dispatch/combine steps will go away once
|
# TODO(bnell): parts of the dispatch/combine steps will go away once
|
||||||
# #32567 lands and the remaining kernels are made MKs. The PCP
|
# #32567 lands and the remaining kernels are made MKs. The PCP
|
||||||
# code will probably remain
|
# code will probably remain
|
||||||
@@ -816,8 +781,7 @@ class DefaultMoERunner(MoERunner):
|
|||||||
layer=layer,
|
layer=layer,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
shared_input=shared_input,
|
shared_experts_input=shared_experts_input,
|
||||||
run_shared_experts_before=run_shared_experts_before,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._maybe_combine(
|
return self._maybe_combine(
|
||||||
|
|||||||
@@ -32,3 +32,7 @@ class MoERunner(ABC):
|
|||||||
final_hidden_states: torch.Tensor,
|
final_hidden_states: torch.Tensor,
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
def is_internal_router(self) -> bool:
    """Abstract hook; concrete runners must override it and return a bool.

    NOTE(review): presumably reports whether routing is handled inside the
    runner itself -- inferred from the name only, confirm against the
    concrete implementations.
    """
    raise NotImplementedError
|
||||||
|
|||||||
216
vllm/model_executor/layers/fused_moe/runner/shared_experts.py
Normal file
216
vllm/model_executor/layers/fused_moe/runner/shared_experts.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEConfig,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import (
|
||||||
|
aux_stream,
|
||||||
|
current_stream,
|
||||||
|
)
|
||||||
|
from vllm.v1.worker.ubatching import (
|
||||||
|
dbo_current_ubatch_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SharedExpertsOrder(IntEnum):
    """Identifies how the shared experts are scheduled for a forward pass.

    Members are plain integers. The previous definition assigned one-element
    tuples (e.g. ``(0,)``) through stray trailing commas, which only produced
    integer members because ``Enum`` unpacks tuple values as constructor
    arguments. The resulting member values (0..4) are unchanged.
    """

    # No shared experts.
    NONE = 0

    # Get rid of this one? combine with BEFORE?
    # Note: this might be important for torch.compile reasons. Can
    # get rid of it after _moe_forward is undone.
    EXTERNAL = 1

    # No overlap - defensively called before MK.
    NO_OVERLAP = 2

    # Overlapped with dispatch/combine in DP/EP - called by the MK.
    MK_INTERNAL_OVERLAPPED = 3

    # Overlapped with the gate, router, experts in aux stream.
    MULTI_STREAM_OVERLAPPED = 4
|
||||||
|
|
||||||
|
|
||||||
|
class SharedExperts:
    """Runs a shared-experts layer and holds its result until it is consumed.

    For every call the class works out a ``SharedExpertsOrder`` from the
    parallel config, the quant method and the input size. ``apply`` only
    executes the layer when the caller's requested order matches that
    decision, and stores the result in a per-ubatch slot that the ``output``
    property hands back exactly once.
    """

    def __init__(
        self,
        layer: torch.nn.Module,
        moe_config: FusedMoEConfig,
        quant_method: QuantizeMethodBase,
        reduce_results: bool,
        enable_dbo: bool,
    ):
        """Store the collaborators and pick the auxiliary stream (if any).

        Args:
            layer: callable module invoked on the shared experts input.
            moe_config: MoE configuration; its ``moe_parallel_config`` is
                consulted when deciding the execution order.
            quant_method: must be a ``FusedMoEMethodBase`` instance.
            reduce_results: whether the shared output may need a TP
                all-reduce (see ``_maybe_reduce_shared_out``).
            enable_dbo: when true, outputs are indexed by the current DBO
                ubatch id instead of always using slot 0.
        """
        from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
            FusedMoEMethodBase,
        )

        # quant_method must be a FusedMoEMethodBase but we can't use the type
        # due to circular imports.
        assert isinstance(quant_method, FusedMoEMethodBase)

        # The SharedExperts need to handle DBO since they can be called from
        # an MK's finalize method. We keep a list of outputs indexed by current
        # DBO ubatch id to handle this case. If DBO is not enabled, the
        # index is always 0 and the second output list element is ignored.
        self.enable_dbo = enable_dbo
        self._output: list[torch.Tensor | None] = [None, None]
        self._layer = layer
        self._moe_config = moe_config
        self._quant_method = quant_method
        self._reduce_results = reduce_results
        self._use_dp_chunking = moe_config.moe_parallel_config.use_dp_chunking

        # Allow disabling of the separate shared experts stream for
        # debug purposes.
        # TODO: Remove this after more extensive testings with TP/DP
        # and other execution modes
        if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
            logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
            self._stream = None
            return

        # TODO(rob): enable shared expert overlap with non-cuda-alike.
        # aux_stream() returns None on non-cuda-alike platforms.
        self._stream = aux_stream()
        if self._stream is None:
            return
        logger.debug_once(
            "Enabled separate cuda stream for MoE shared_experts", scope="local"
        )

    @property
    def _has_external_experts(self) -> bool:
        """False when either overlap-disabling condition below holds."""
        # Disable shared expert overlap if:
        # - we are using eplb with non-default backend, because of correctness issues
        # - we are using flashinfer with DP, since there is nothing to gain
        parallel_config = self._moe_config.moe_parallel_config
        backend = parallel_config.all2all_backend
        if parallel_config.enable_eplb and backend != "allgather_reducescatter":
            return False
        if parallel_config.use_fi_nvl_two_sided_kernels:
            return False
        return True

    def _determine_shared_experts_order(
        self,
        hidden_states: torch.Tensor,
    ) -> SharedExpertsOrder:
        """Pick the execution order for this input.

        Checks are made in priority order: EXTERNAL, then
        MK_INTERNAL_OVERLAPPED, then MULTI_STREAM_OVERLAPPED, with NO_OVERLAP
        as the fallback. Only ``hidden_states.shape[0]`` is inspected.
        """
        if self._has_external_experts and not self._use_dp_chunking:
            return SharedExpertsOrder.EXTERNAL

        if self._quant_method.mk_owns_shared_expert:
            return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED

        # The aux stream is only used on CUDA, without DP chunking, when a
        # stream exists and the leading dimension is within the threshold.
        use_aux_stream = (
            current_platform.is_cuda()
            and not self._use_dp_chunking
            and self._stream is not None
            and hidden_states.shape[0]
            <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
        )

        return (
            SharedExpertsOrder.MULTI_STREAM_OVERLAPPED
            if use_aux_stream
            else SharedExpertsOrder.NO_OVERLAP
        )

    def maybe_sync_shared_experts_stream(
        self,
        shared_experts_input: torch.Tensor,
    ):
        """Make the aux stream wait on the current stream, if it will be used.

        Does nothing unless the order chosen for ``shared_experts_input`` is
        MULTI_STREAM_OVERLAPPED.
        """
        chosen = self._determine_shared_experts_order(shared_experts_input)
        if chosen != SharedExpertsOrder.MULTI_STREAM_OVERLAPPED:
            return

        assert self._stream is not None
        assert self._moe_config.disable_inplace

        # Record that the clone will be used by shared_experts_stream
        # to avoid gc issue from deallocation of hidden_states_clone
        # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
        # NOTE: We don't need shared_output.record_stream(current_stream())
        # because we synch the streams before using shared_output.
        shared_experts_input.record_stream(self._stream)

        # Mark sync start point for the aux stream since we will
        # run in parallel with router/gate.
        self._stream.wait_stream(current_stream())

    def _run_in_aux_stream(
        self,
        shared_experts_input: torch.Tensor,
    ) -> torch.Tensor:
        """Invoke the layer on the aux stream, then re-join the current stream."""
        # TODO: assert that maybe_sync_shared_experts_stream has been called.

        # Run shared experts in parallel on a separate stream.
        with torch.cuda.stream(self._stream):
            result = self._layer(shared_experts_input)
        # The caller's stream must not read the result before the aux stream
        # has finished producing it.
        current_stream().wait_stream(self._stream)

        return result

    def _maybe_reduce_shared_out(self, shared_out: torch.Tensor) -> torch.Tensor:
        """All-reduce ``shared_out`` across TP ranks when all conditions hold."""
        # Reduce shared expert outputs if necessary, since the MLP
        # should have been created with reduce_results=False.
        needs_reduce = (
            self._reduce_results
            and self._quant_method.moe_kernel is not None
            and self._quant_method.moe_kernel.output_is_reduced()
            and get_tensor_model_parallel_world_size() > 1
        )
        if not needs_reduce:
            return shared_out
        return tensor_model_parallel_all_reduce(shared_out)

    @property
    def _output_idx(self) -> int:
        """Slot index: the current DBO ubatch id, or 0 when DBO is disabled."""
        if self.enable_dbo:
            return dbo_current_ubatch_id()
        return 0

    @property
    def output(self) -> torch.Tensor:
        """Return the pending output for the current slot and clear the slot.

        Asserts that an output is pending; reading twice without an
        intervening ``apply`` therefore fails.
        """
        slot = self._output_idx
        pending = self._output[slot]
        assert pending is not None
        self._output[slot] = None
        return pending

    def apply(
        self,
        shared_experts_input: torch.Tensor,
        order: SharedExpertsOrder,
    ):
        """Run the shared experts if ``order`` is the order chosen for the input.

        When ``order`` differs from the computed order nothing is executed
        and ``None`` is returned. Otherwise the layer output is stored in the
        current slot (retrieve it through ``output``); the slot must be empty
        on entry.
        """
        chosen = self._determine_shared_experts_order(shared_experts_input)
        if order != chosen:
            return None

        assert self._output[self._output_idx] is None

        if order == SharedExpertsOrder.MULTI_STREAM_OVERLAPPED:
            runner = self._run_in_aux_stream
        else:
            runner = self._layer
        self._output[self._output_idx] = runner(shared_experts_input)

        if order == SharedExpertsOrder.EXTERNAL:
            # TODO: figure out how to combine this with maybe_reduce_output?
            # or get rid of it completely.
            unreduced = self._output[self._output_idx]
            assert unreduced is not None
            self._output[self._output_idx] = self._maybe_reduce_shared_out(unreduced)

        assert self._output[self._output_idx] is not None
|
||||||
@@ -3,14 +3,10 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||||
|
|
||||||
|
|
||||||
# TODO(bnell): Add shared + fused combo function? e.g. +
|
# TODO(bnell): Remove this entirely
|
||||||
class SharedFusedMoE(FusedMoE):
|
class SharedFusedMoE(FusedMoE):
|
||||||
"""
|
"""
|
||||||
A FusedMoE operation that also computes the results of shared experts.
|
A FusedMoE operation that also computes the results of shared experts.
|
||||||
@@ -23,36 +19,11 @@ class SharedFusedMoE(FusedMoE):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
if not self.use_overlapped:
|
result = super().forward(
|
||||||
if self._shared_experts is not None:
|
hidden_states=hidden_states,
|
||||||
shared_out = self._shared_experts(hidden_states)
|
router_logits=router_logits,
|
||||||
|
)
|
||||||
# Reduce shared expert outputs if necessary, since the MLP
|
if self.shared_experts is None:
|
||||||
# should have been created with reduce_results=False.
|
return None, result
|
||||||
if (
|
|
||||||
self.reduce_results
|
|
||||||
and get_tensor_model_parallel_world_size() > 1
|
|
||||||
and self.must_reduce_shared_expert_outputs()
|
|
||||||
):
|
|
||||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
|
||||||
else:
|
|
||||||
shared_out = None
|
|
||||||
|
|
||||||
fused_out = super().forward(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
shared_out, fused_out = super().forward(
|
return result
|
||||||
hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
)
|
|
||||||
# ensure early TP reduction of shared expert outputs when required
|
|
||||||
if (
|
|
||||||
shared_out is not None
|
|
||||||
and self.reduce_results
|
|
||||||
and get_tensor_model_parallel_world_size() > 1
|
|
||||||
and self.must_reduce_shared_expert_outputs()
|
|
||||||
):
|
|
||||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
|
||||||
return shared_out, fused_out
|
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
return self.forward(
|
return self.forward(
|
||||||
layer=layer,
|
layer=layer,
|
||||||
x=x,
|
x=x,
|
||||||
@@ -261,7 +261,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -283,7 +283,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
return self.forward_native(
|
return self.forward_native(
|
||||||
layer, x, topk_weights, topk_ids, shared_experts_input
|
layer, x, topk_weights, topk_ids, shared_experts_input
|
||||||
)
|
)
|
||||||
@@ -293,7 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.is_monolithic
|
assert self.is_monolithic
|
||||||
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
|
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
|
||||||
assert self.moe_kernel is None
|
assert self.moe_kernel is None
|
||||||
|
|||||||
@@ -811,7 +811,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
return fused_marlin_moe(
|
return fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_qweight,
|
layer.w13_qweight,
|
||||||
|
|||||||
@@ -483,7 +483,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
|
||||||
# TODO(bnell): Do these need to be called on the hot path?
|
# TODO(bnell): Do these need to be called on the hot path?
|
||||||
|
|||||||
@@ -355,7 +355,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
x,
|
x,
|
||||||
@@ -603,7 +603,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer: FusedMoE,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.is_monolithic
|
assert self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply_monolithic(
|
return self.moe_kernel.apply_monolithic(
|
||||||
@@ -628,7 +628,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
x,
|
x,
|
||||||
@@ -963,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer: FusedMoE,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply_monolithic(
|
return self.moe_kernel.apply_monolithic(
|
||||||
x,
|
x,
|
||||||
@@ -987,7 +987,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert not self.is_monolithic
|
assert not self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
@@ -1127,7 +1127,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@@ -1611,7 +1611,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer: FusedMoE,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.kernel_backend == "Flashinfer"
|
assert self.kernel_backend == "Flashinfer"
|
||||||
return flashinfer_trtllm_mxint4_moe(
|
return flashinfer_trtllm_mxint4_moe(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -1638,7 +1638,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.kernel_backend == "Marlin"
|
assert self.kernel_backend == "Marlin"
|
||||||
return fused_marlin_moe(
|
return fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
@@ -1887,7 +1887,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@@ -2502,7 +2502,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
if layer.enable_eplb:
|
if layer.enable_eplb:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
|
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
|
|||||||
@@ -877,7 +877,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: FusedMoE,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.is_monolithic
|
assert self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply_monolithic(
|
return self.moe_kernel.apply_monolithic(
|
||||||
@@ -902,7 +902,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert not self.is_monolithic
|
assert not self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
|
|||||||
@@ -650,7 +650,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
if layer.apply_router_weight_on_input:
|
if layer.apply_router_weight_on_input:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Apply router weight on input is not supported for"
|
"Apply router weight on input is not supported for"
|
||||||
|
|||||||
@@ -907,7 +907,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
return fused_marlin_moe(
|
return fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_qweight,
|
layer.w13_qweight,
|
||||||
|
|||||||
@@ -935,7 +935,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: FusedMoE,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.is_monolithic
|
assert self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply_monolithic(
|
return self.moe_kernel.apply_monolithic(
|
||||||
@@ -960,7 +960,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert not self.is_monolithic
|
assert not self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
@@ -1419,7 +1419,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
layer: FusedMoE,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.is_monolithic
|
assert self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply_monolithic(
|
return self.moe_kernel.apply_monolithic(
|
||||||
@@ -1444,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert not self.is_monolithic
|
assert not self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
|
||||||
assert layer.activation == MoEActivation.SILU, (
|
assert layer.activation == MoEActivation.SILU, (
|
||||||
|
|||||||
@@ -377,7 +377,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert not self.is_monolithic
|
assert not self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply(
|
return self.moe_kernel.apply(
|
||||||
@@ -398,7 +398,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: FusedMoE,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
assert self.is_monolithic
|
assert self.is_monolithic
|
||||||
assert self.moe_kernel is not None
|
assert self.moe_kernel is not None
|
||||||
return self.moe_kernel.apply_monolithic(
|
return self.moe_kernel.apply_monolithic(
|
||||||
|
|||||||
@@ -444,7 +444,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
rocm_aiter_fused_experts,
|
rocm_aiter_fused_experts,
|
||||||
@@ -634,7 +634,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
rocm_aiter_fused_experts,
|
rocm_aiter_fused_experts,
|
||||||
)
|
)
|
||||||
@@ -1027,7 +1027,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
if not self.emulate:
|
if not self.emulate:
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
rocm_aiter_fused_experts,
|
rocm_aiter_fused_experts,
|
||||||
|
|||||||
@@ -94,6 +94,8 @@ def transformers_moe_forward(
|
|||||||
self = forward_context.no_compile_layers[layer_name]
|
self = forward_context.no_compile_layers[layer_name]
|
||||||
self._topk_ids = topk_ids
|
self._topk_ids = topk_ids
|
||||||
# Clone hidden_states because it will be mutated in-place in FusedMoE
|
# Clone hidden_states because it will be mutated in-place in FusedMoE
|
||||||
|
# TODO(bnell): figure out a way to avoid calling runner directly.
|
||||||
|
# It is a hack that the weights are being passed via logits.
|
||||||
return self.runner.forward(hidden_states.clone(), topk_weights)
|
return self.runner.forward(hidden_states.clone(), topk_weights)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user