[Kernels] Split up fused_moe/layer.py, isolate more modular kernel code (#28064)
@@ -6,6 +6,10 @@ import torch
 
 # Fused experts and PrepareFinalize imports
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe import TritonExperts
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+    maybe_make_prepare_finalize,
+)
 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts,
 )
@@ -21,7 +25,6 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedTritonExperts,
     NaiveBatchedExperts,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP,
 )
@@ -399,9 +402,7 @@ def make_prepare_finalize(
     quant_config: FusedMoEQuantConfig,
 ) -> mk.FusedMoEPrepareAndFinalize:
     if backend != "naive" and backend is not None:
-        prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
-            moe, quant_config
-        )
+        prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
         assert prepare_finalize is not None
         return prepare_finalize
     elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
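The hunk above is the pattern this commit applies at its call sites: prepare/finalize construction moves off the quant-method base class onto the free function added in all2all_utils.py (shown later in this diff). A minimal sketch of the new call shape, assuming `moe` (a FusedMoEConfig) and `quant_config` are already constructed elsewhere:

from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)

# `moe` and `quant_config` are assumed to be built elsewhere; per the new
# module below, the function returns None when the parallel config does
# not use all2all kernels.
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)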
@@ -25,7 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
     modular_triton_fused_moe,
     try_get_optimal_moe_config,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoEModularMethod
+from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
+    FusedMoEModularMethod,
+)
 
 
 class FusedMoEWithLoRA(BaseLayerWithLoRA):
@@ -5,9 +5,11 @@ from contextlib import contextmanager
 from typing import Any
 
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
-    FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
 )
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
vllm/model_executor/layers/fused_moe/all2all_utils.py (new file, 160 lines)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.distributed import (
    get_ep_group,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPrepareAndFinalize,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx

if current_platform.is_cuda_alike():
    if has_pplx():
        from .pplx_prepare_finalize import (
            PplxPrepareAndFinalize,
            pplx_hidden_dim_scale_bytes,
        )
    if has_deep_ep():
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (
            DEEPEP_QUANT_BLOCK_SHAPE,
            DeepEPLLPrepareAndFinalize,
        )


def maybe_roundup_layer_hidden_size(
    hidden_size: int,
    act_dtype: torch.dtype,
    moe_parallel_config: FusedMoEParallelConfig,
) -> int:
    """
    Given layer hidden size and MoE configurations, round up hidden_size
    if necessary.

    Args:
        hidden_size: Layer hidden-size
        act_dtype: Data type of the layer activations.
        moe_parallel_config: Fused MoE parallelization strategy configuration.

    Return:
        Rounded up hidden_size if rounding up is required based on the configs
        and all2all backend.
        Original hidden size otherwise.
    """
    if moe_parallel_config.use_deepep_ht_kernels:
        hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size, act_dtype
        )

    if moe_parallel_config.use_deepep_ll_kernels:
        hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size
        )

    return hidden_size


def maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig | None,
) -> FusedMoEPrepareAndFinalize | None:
    if not moe.moe_parallel_config.use_all2all_kernels:
        return None

    all2all_manager = get_ep_group().device_communicator.all2all_manager
    assert all2all_manager is not None

    prepare_finalize: FusedMoEPrepareAndFinalize | None = None

    # TODO: could allow this now
    assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"

    if moe.use_pplx_kernels:
        assert quant_config is not None

        hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
            moe.max_num_tokens,
            moe.hidden_dim,
            moe.in_dtype,
            quant_config.quant_dtype,
            per_act_token_quant=quant_config.per_act_token_quant,
            block_shape=quant_config.block_shape,
        )

        all_to_all_args = dict(
            max_num_tokens=moe.max_num_tokens,
            num_experts=moe.num_experts,
            experts_per_token=moe.experts_per_token,  # topk
            rank=all2all_manager.rank,
            world_size=all2all_manager.world_size,
            # dp_size actually means tp_size, bug in pplx kernels
            dp_size=all2all_manager.tp_group.world_size,
            hidden_dim=moe.hidden_dim,
            hidden_dim_bytes=hidden_dim_bytes,
            hidden_dim_scale_bytes=hidden_scale_bytes,
        )

        num_dispatchers = (
            all2all_manager.world_size // all2all_manager.tp_group.world_size
        )

        # Intranode pplx a2a takes a group name while internode does not.
        if not all2all_manager.internode:
            all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name

        handle = all2all_manager.get_handle(all_to_all_args)

        prepare_finalize = PplxPrepareAndFinalize(
            handle,
            max_num_tokens=moe.max_num_tokens,
            num_local_experts=moe.num_local_experts,
            num_dispatchers=num_dispatchers,
        )
    elif moe.use_deepep_ht_kernels:
        assert moe.dp_size == all2all_manager.dp_world_size

        all_to_all_args = dict()
        handle = all2all_manager.get_handle(all_to_all_args)
        prepare_finalize = DeepEPHTPrepareAndFinalize(
            handle,
            num_dispatchers=all2all_manager.world_size,
            dp_size=all2all_manager.dp_world_size,
            rank_expert_offset=all2all_manager.rank * moe.num_local_experts,
        )

    elif moe.use_deepep_ll_kernels:
        assert quant_config is not None
        all_to_all_args = dict(
            max_num_tokens_per_dp_rank=moe.max_num_tokens,
            token_hidden_size=moe.hidden_dim,
            num_ep_ranks=all2all_manager.world_size,
            num_global_experts=moe.num_experts,
            num_local_experts=moe.num_experts // all2all_manager.world_size,
        )
        handle = all2all_manager.get_handle(all_to_all_args)

        # Note: We may want to use FP8 dispatch just to reduce
        # data movement.
        use_fp8_dispatch = (
            quant_config.quant_dtype == current_platform.fp8_dtype()
            and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
        )

        prepare_finalize = DeepEPLLPrepareAndFinalize(
            handle,
            max_tokens_per_rank=moe.max_num_tokens,
            num_dispatchers=all2all_manager.world_size,
            use_fp8_dispatch=use_fp8_dispatch,
        )

    return prepare_finalize
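Both helpers in this new module are pure functions over the MoE configs, so they can run before any layer state exists. A hedged usage sketch of the rounding helper (the `moe` config object is assumed to be constructed elsewhere):

import torch

from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_roundup_layer_hidden_size,
)

# Per the module above, only the DeepEP backends may require a padded
# hidden size; every other backend returns the input unchanged.
hidden_size = maybe_roundup_layer_hidden_size(
    4096, torch.bfloat16, moe.moe_parallel_config  # `moe` assumed constructed
)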
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py (new file, 112 lines)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
from collections.abc import Callable

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase,
)

logger = init_logger(__name__)


class FusedMoEMethodBase(QuantizeMethodBase):
    def __init__(self, moe: FusedMoEConfig):
        super().__init__()
        self.moe: FusedMoEConfig = moe
        self.moe_quant_config: FusedMoEQuantConfig | None = None

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        Returns True if this quantization method uses 'weight_scale_2' pattern
        for per-tensor weight scales (e.g., FP4 variants), False otherwise.

        This method should be overridden by subclasses that use the
        'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
        """
        return False

    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
        from .all2all_utils import maybe_make_prepare_finalize

        return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        # based on the all2all implementation, select the appropriate
        # gemm implementation
        raise NotImplementedError(
            f"{self.__class__.__name__} must select appropriate gemm "
            "implementation based on the prepare_finalize"
        )

    @abstractmethod
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        raise NotImplementedError

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        return None

    @property
    def supports_eplb(self) -> bool:
        return False

    @property
    def allow_inplace(self) -> bool:
        return False

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
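For orientation, a concrete quant method now only needs to subclass this file's FusedMoEMethodBase and fill in the abstract hooks; `maybe_make_prepare_finalize` and the `supports_eplb` / `allow_inplace` properties come from the base. A hypothetical minimal subclass sketch (the class name and method bodies are illustrative, not part of this commit; real implementations also override `apply`):

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)


class ExampleMoEMethod(FusedMoEMethodBase):
    """Hypothetical subclass; shown only to illustrate the hook surface."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # A real method registers w13_weight / w2_weight parameters on `layer`.
        raise NotImplementedError

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None  # unquantized in this sketch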
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py (new file, 164 lines)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch

from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel,
    FusedMoEPrepareAndFinalize,
)

logger = init_logger(__name__)


@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
    def __init__(
        self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
    ):
        super().__init__(old_quant_method.moe)
        self.moe_quant_config = old_quant_method.moe_quant_config
        self.fused_experts = experts
        self.disable_expert_map = getattr(
            old_quant_method,
            "disable_expert_map",
            not self.fused_experts.supports_expert_map(),
        )
        self.old_quant_method = old_quant_method
        logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)

    @staticmethod
    def make(
        moe_layer: torch.nn.Module,
        old_quant_method: FusedMoEMethodBase,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        shared_experts: torch.nn.Module | None,
    ) -> "FusedMoEModularMethod":
        return FusedMoEModularMethod(
            old_quant_method,
            FusedMoEModularKernel(
                prepare_finalize,
                old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                shared_experts,
            ),
        )

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        return self.fused_experts.prepare_finalize.topk_indices_dtype()

    @property
    def supports_eplb(self) -> bool:
        return self.old_quant_method.supports_eplb

    @property
    def allow_inplace(self) -> bool:
        return self.old_quant_method.allow_inplace

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return self.moe_quant_config

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # Is getattr needed?
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        if enable_eplb:
            if self.supports_eplb:
                assert expert_load_view is not None
                assert logical_to_physical_map is not None
                assert logical_replica_count is not None
            else:
                raise NotImplementedError(
                    "EPLB is not supported for "
                    f"{self.old_quant_method.__class__.__name__}."
                )

        topk_weights, topk_ids, zero_expert_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
        )

        result = self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=self.allow_inplace,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=None if self.disable_expert_map else expert_map,
        )

        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
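FusedMoEModularMethod.make captures the swap this commit performs in one place: the existing quant method contributes its GEMM implementation (via `select_gemm_impl`), which is paired with a prepare/finalize object into a FusedMoEModularKernel, and the wrapper then stands in for the original method. A hedged wiring sketch, assuming `layer` is a FusedMoE module whose `quant_method` and a `prepare_finalize` object already exist:

from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)

# Replace the layer's quant method with a modular-kernel-backed wrapper;
# `layer` and `prepare_finalize` are assumed to be set up elsewhere.
layer.quant_method = FusedMoEModularMethod.make(
    moe_layer=layer,
    old_quant_method=layer.quant_method,
    prepare_finalize=prepare_finalize,
    shared_experts=None,
)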
(One file's diff is suppressed because it is too large.)
@@ -38,7 +38,7 @@ class SharedFusedMoE(FusedMoE):
             and not (
                 # TODO(wentao): find the root cause and remove this condition
                 self.enable_eplb
-                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+                or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
             )
             and self._shared_experts is not None
         )
New file (578 lines):

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    FusedMoEConfig,
    FusedMoEQuantConfig,
    biased_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEActivationFormat,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

if current_platform.is_cuda_alike():
    from .fused_batched_moe import BatchedTritonExperts
    from .fused_moe import TritonExperts, fused_experts
else:
    fused_experts = None  # type: ignore

if current_platform.is_tpu():
    from .moe_pallas import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore

logger = init_logger(__name__)


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        if self.rocm_aiter_moe_enabled:
            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts

            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
        else:
            self.rocm_aiter_fused_experts = None  # type: ignore

        # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
        self.flashinfer_cutlass_moe_enabled = (
            has_flashinfer_cutlass_fused_moe()
            and envs.VLLM_USE_FLASHINFER_MOE_FP16
            and self.moe.moe_parallel_config.use_ep
            and self.moe.moe_parallel_config.dp_size == 1
            and current_platform.get_device_capability()[0] >= 9
        )
        if self.flashinfer_cutlass_moe_enabled:
            logger.info_once(
                "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
            )
            from functools import partial

            from .flashinfer_cutlass_moe import flashinfer_cutlass_moe

            self.flashinfer_cutlass_moe = partial(
                flashinfer_cutlass_moe,
                quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
                tp_rank=self.moe.moe_parallel_config.tp_rank,
                tp_size=self.moe.moe_parallel_config.tp_size,
                ep_rank=self.moe.moe_parallel_config.ep_rank,
                ep_size=self.moe.moe_parallel_config.ep_size,
            )
        else:
            if (
                self.moe.moe_parallel_config.use_ep
                and self.moe.moe_parallel_config.dp_size == 1
            ):
                logger.info_once(
                    "FlashInfer CUTLASS MoE is available for EP"
                    " but not enabled, consider setting"
                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
                    scope="local",
                )
            elif self.moe.moe_parallel_config.dp_size > 1:
                logger.info_once(
                    "FlashInfer CUTLASS MoE is currently not available for DP.",
                    scope="local",
                )
            self.flashinfer_cutlass_moe = None  # type: ignore

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
        if self.rocm_aiter_moe_enabled:
            return None
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            logger.debug("BatchedTritonExperts %s", self.moe)
            return BatchedTritonExperts(
                max_num_tokens=self.moe.max_num_tokens,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        else:
            logger.debug("TritonExperts %s", self.moe)
            return TritonExperts(self.moe_quant_config)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        if self.moe.has_bias:
            w2_bias = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (
            envs.VLLM_ROCM_MOE_PADDING
            and current_platform.is_rocm()
            and weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0
        ):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()

        return weight

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

        # Padding the weight for better performance on ROCm
        layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
        layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data
            )

            layer.w13_weight.data = shuffled_w13
            layer.w2_weight.data = shuffled_w2

        if self.flashinfer_cutlass_moe_enabled:
            # Swap halves to arrange as [w3; w1] (kernel expectation)
            w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
            layer.w13_weight.data = w13_weight_swapped.contiguous()

        if current_platform.is_xpu():
            import intel_extension_for_pytorch as ipex

            ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
                layer.w13_weight,
                layer.w2_weight,
                use_prepack=True,
                experts_start_id=ep_rank_start,
            )
        elif current_platform.is_cpu():
            from vllm.model_executor.layers.fused_moe import cpu_fused_moe

            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                from vllm.model_executor.layers.utils import check_cpu_sgl_kernel

                dtype_w13 = layer.w13_weight.dtype
                _, n_w13, k_w13 = layer.w13_weight.size()
                dtype_w2 = layer.w2_weight.dtype
                _, n_w2, k_w2 = layer.w2_weight.size()
                if (
                    envs.VLLM_CPU_SGL_KERNEL
                    and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
                    and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
                ):
                    packed_w13_weight = torch.ops._C.convert_weight_packed(
                        layer.w13_weight
                    )
                    assert packed_w13_weight.size() == layer.w13_weight.size()
                    layer.w13_weight.copy_(packed_w13_weight)
                    del packed_w13_weight
                    packed_w2_weight = torch.ops._C.convert_weight_packed(
                        layer.w2_weight
                    )
                    assert packed_w2_weight.size() == layer.w2_weight.size()
                    layer.w2_weight.copy_(packed_w2_weight)
                    layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
                else:
                    layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
            else:
                layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None

        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            enable_eplb=enable_eplb,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        if self.moe.has_bias:
            return biased_moe_quant_config(
                layer.w13_bias,
                layer.w2_bias,
            )
        else:
            return FUSED_MOE_UNQUANTIZED_CONFIG

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        topk_weights, topk_ids, zero_expert_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        if self.rocm_aiter_moe_enabled:
            result = self.rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                expert_map=expert_map,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.flashinfer_cutlass_moe_enabled:
            return self.flashinfer_cutlass_moe(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )

        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if (
            enable_eplb is not False
            or expert_load_view is not None
            or logical_to_physical_map is not None
            or logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for CPU.")
        return layer.cpu_fused_moe(
            layer,
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            global_num_experts,
            expert_map,
            custom_routing_function,
            scoring_func,
            routed_scaling_factor,
            e_score_correction_bias,
            apply_router_weight_on_input,
            activation,
        )

    def forward_xpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if (
            enable_eplb is not False
            or expert_load_view is not None
            or logical_to_physical_map is not None
            or logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for XPU.")
        return layer.ipex_fusion(
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function=custom_routing_function,
        )

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        assert apply_router_weight_on_input is False
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for TPU."
            )
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for TPU."
            )
        assert activation == "silu", f"{activation} is not supported for TPU."
        assert routed_scaling_factor == 1.0, (
            f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU."
        )
        if (
            enable_eplb is not False
            or expert_load_view is not None
            or logical_to_physical_map is not None
            or logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for TPU.")
        return fused_moe_pallas(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk=top_k,
            gating_output=router_logits,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            renormalize=renormalize,
        )

    if current_platform.is_tpu():
        forward_native = forward_tpu
    elif current_platform.is_cpu():
        forward_native = forward_cpu
    elif current_platform.is_xpu():
        forward_native = forward_xpu
    else:
        forward_native = forward_cuda
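The class-body `if`/`elif` at the end of the file binds `forward_native` to the platform-specific method once, at class-creation time, rather than branching on every call. A self-contained illustration of that pattern (plain Python, no vLLM dependency; all names here are illustrative):

class PlatformDispatch:
    def forward_cuda(self, x: int) -> str:
        return f"cuda:{x}"

    def forward_cpu(self, x: int) -> str:
        return f"cpu:{x}"

    # Evaluated once when the class object is built, mirroring the
    # forward_native aliasing in the vLLM class body above.
    _ON_CPU = True
    forward_native = forward_cpu if _ON_CPU else forward_cuda


d = PlatformDispatch()
print(d.forward_native(3))  # -> "cpu:3"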
@@ -741,15 +741,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
             )
 
-            self.w13_weight_triton_tensor = w13_weight
-            self.w2_weight_triton_tensor = w2_weight
-            # need to delete the original weights to save memory on single GPU
-            del layer.w13_weight
-            del layer.w2_weight
-            layer.w13_weight = None
-            layer.w2_weight = None
-            torch.cuda.empty_cache()
+            self.w13_weight = w13_weight
+            self.w2_weight = w2_weight
+            layer.w13_weight = w13_weight
+            layer.w2_weight = w2_weight
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
@@ -824,18 +819,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 "EP batched experts format"
             )
         else:
-            layer.w13_weight = (
-                self.w13_weight_triton_tensor
-                if layer.w13_weight is None
-                else layer.w13_weight
-            )
-            layer.w2_weight = (
-                self.w2_weight_triton_tensor
-                if layer.w2_weight is None
-                else layer.w2_weight
-            )
-            assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
-
             assert self.moe_quant_config is not None
             if (
                 self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
@@ -1070,8 +1053,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         return triton_kernel_moe_forward(
             hidden_states=x,
-            w1=self.w13_weight_triton_tensor,
-            w2=self.w2_weight_triton_tensor,
+            w1=self.w13_weight,
+            w2=self.w2_weight,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
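Net effect of the three Mxfp4MoEMethod hunks: the packed Triton tensors drop the `_triton_tensor` suffix, the layer keeps live `w13_weight` / `w2_weight` references instead of having them deleted and later restored, and the forward path reads the renamed attributes. A sketch of the invariant this leaves behind (assuming `method` is shorthand for the Mxfp4MoEMethod instance attached to `layer`, and that the assignments succeed exactly as written in the diff):

# Both the method and the layer now hold the same packed weight objects,
# so the removed fallback re-assignment block is no longer needed.
assert layer.w13_weight is method.w13_weight
assert layer.w2_weight is method.w2_weight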