[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
@@ -3,27 +3,30 @@
from abc import abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Literal, Optional, Union, overload
from typing import Callable, Literal, Optional, overload

import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_world_group,
tensor_model_parallel_all_reduce)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import (
@@ -36,14 +39,12 @@ from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
if has_pplx():
from .pplx_prepare_finalize import PplxPrepareAndFinalize
from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes)
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize)
else:
fused_experts = None # type: ignore
@@ -60,209 +61,10 @@ if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore

logger = init_logger(__name__)


@dataclass
class FusedMoEParallelConfig:
tp_size: int
dp_size: int
ep_size: int
tp_rank: int
dp_rank: int
ep_rank: int

use_ep: bool # whether to use EP or not

@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep

@property
def use_pplx_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "pplx")

@property
def use_deepep_ht_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")

@property
def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input tp_size_,
dp_size_, ep_size_ and vllm's parallel config, determine what
levels of parallelism to use in the fused moe layer.

Args:
tp_size_ (int): tp_size passed into the FusedMoE constructor.
dp_size_ (int): dp_size passed into the FusedMoE constructor.
ep_size_ (int): ep_size passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vllm's parallel config
object.

Examples:
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
we simply return the sizes unaltered and the ranks set to 0.

Expert Parallelism is considered only when either dp_size_ or tp_size_
is non-trivial.

When TP = 2, DP = 1 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // legend : {size, rank}
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.

When TP = 1, DP = 2 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 2 devices.

When TP = 2, DP = 2 and EP = False, the configuration on different
devices,
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.

When TP = 2, DP = 1 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.

When TP = 1, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.

When TP = 2, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
- Comment: There are 2 engine instances and the experts are split
between the 4 devices.
"""

def flatten_tp_across_dp(dp_rank: int):
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size_ * tp_size_ devices. Update tp_size
# and tp_rank so we shard across all devices.
tp_size = dp_size_ * tp_size_
tp_rank = dp_rank * tp_size_ + tp_rank
return tp_size, tp_rank

use_ep = (dp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel)

dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)

if not use_ep:
return FusedMoEParallelConfig(tp_size=tp_size,
tp_rank=tp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
use_ep=False)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallelism. Update tp_size, tp_rank, ep_size and ep_rank to reflect that.
ep_size = tp_size
ep_rank = tp_rank
return FusedMoEParallelConfig(tp_size=1,
tp_rank=0,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
use_ep=True)
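
A minimal arithmetic sketch of the mapping described in the docstring above (the names are local to the sketch, not part of the diff); it reproduces the TP = 2, DP = 2, EP = True example for device 2:

```python
# Sketch only: mirrors flatten_tp_across_dp() and the EP branch of make().
def sketch_parallel_layout(tp_size_: int, dp_size_: int, dp_rank: int,
                           tp_rank_in_group: int, use_ep: bool):
    # Flatten TP across DP: there are dp_size_ * tp_size_ devices in total.
    tp_size = dp_size_ * tp_size_
    tp_rank = dp_rank * tp_size_ + tp_rank_in_group
    if not use_ep:
        return dict(tp=(tp_size, tp_rank), ep=(1, 0))
    # With EP, the flattened TP group becomes the EP group and TP collapses.
    return dict(tp=(1, 0), ep=(tp_size, tp_rank))

# Device 2 in the TP = 2, DP = 2, EP = True example: dp_rank = 1, TP rank 0.
assert sketch_parallel_layout(2, 2, dp_rank=1, tp_rank_in_group=0,
                              use_ep=True) == {"tp": (1, 0), "ep": (4, 2)}
```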

# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int

num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig

in_dtype: torch.dtype # The activation type.
quant_dtype: torch.dtype = None

# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128

max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

def __post_init__(self):
if self.dp_size > 1:
logger.debug("Using MOEConfig::max_num_tokens=%d",
self.max_num_tokens)

@property
def tp_size(self):
return self.moe_parallel_config.tp_size

@property
def dp_size(self):
return self.moe_parallel_config.dp_size

@property
def ep_size(self):
return self.moe_parallel_config.ep_size

@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank

@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank

@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank

@property
def use_ep(self):
return self.moe_parallel_config.use_ep

@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels

@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels

@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
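
The config above simply forwards all parallel-layout queries to its FusedMoEParallelConfig; an illustrative helper (not from the diff) showing how call sites typically use those delegating properties:

```python
# Illustration only: any all2all backend implies batched dispatch buffers.
def uses_all2all_kernels(moe) -> bool:
    # Equivalent to moe.moe_parallel_config.use_all2all_kernels, i.e.
    # dp_size > 1 and expert parallelism enabled, specialised per backend
    # via the VLLM_ALL2ALL_BACKEND environment variable.
    return (moe.use_pplx_kernels or moe.use_deepep_ht_kernels
            or moe.use_deepep_ll_kernels)
```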

class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
@@ -270,21 +72,9 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"

def get_quant_config_input_activations(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(
"input_activations")
else:
return None
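
The helper above only performs attribute and dictionary lookups, so a stand-in object is enough to exercise it; a hypothetical example (the scheme contents are made up for illustration):

```python
# Hypothetical stand-in for a compressed-tensors style quantization config.
from types import SimpleNamespace

fake_scheme = SimpleNamespace(num_bits=8, type="float", strategy="token")
fake_quant_config = SimpleNamespace(
    target_scheme_map={"Linear": {"input_activations": fake_scheme}})

assert get_quant_config_input_activations(fake_quant_config) is fake_scheme
assert get_quant_config_input_activations(None) is None  # no quant config
```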

class FusedMoEMethodBase(QuantizeMethodBase):

moe: MoEConfig
moe: FusedMoEConfig

@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -292,23 +82,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError

def init_prepare_finalize(self, moe: MoEConfig,
def init_prepare_finalize(self, moe: FusedMoEConfig,
quant_config: Optional[QuantizationConfig]):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None

self.moe = moe
quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if isinstance(quant_config, Fp8Config):
act_quant_block_size = quant_config.weight_block_size
quant_dtype = torch.float8_e4m3fn

prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize]] = None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None

if moe.use_pplx_kernels:
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
moe.quant_dtype,
per_act_token_quant=moe.per_act_token_quant,
block_shape=moe.block_shape,
)

all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
@@ -318,14 +110,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(
0 if moe.quant_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
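
A worked example of the scale-byte sizing rule from the comment above (illustrative numbers; the refactored code delegates this calculation to pplx_hidden_dim_scale_bytes):

```python
import torch

hidden_dim, block_size = 4096, 128
scale_itemsize = torch.float32.itemsize  # 4 bytes per scale

# Quantized activations are 1 byte each (e.g. fp8), so scales are exchanged.
# Blocked per-token quant: one float32 scale per block of the hidden dim.
blocked_scale_bytes = ((hidden_dim + block_size - 1) // block_size
                       ) * scale_itemsize  # ceil(4096 / 128) * 4 = 128
# Plain per-token quant: a single float32 scale per token.
per_token_scale_bytes = scale_itemsize  # 4
# Unquantized activations (itemsize != 1): no scales are sent at all.
unquantized_scale_bytes = 0
```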

# Intranode pplx a2a takes a group name while internode does not.
@@ -335,9 +121,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):

handle = all2all_manager.get_handle(all_to_all_args)

input_activations = get_quant_config_input_activations(
quant_config)

prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
@@ -345,10 +128,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.quant_dtype,
per_act_token=(input_activations.strategy
== QuantizationStrategy.TOKEN
if input_activations is not None else False),
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
@@ -362,8 +141,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
)
elif moe.use_deepep_ll_kernels:
@@ -380,25 +157,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):

# Note: We may want to use FP8 dispatch even otherwise just to
# reduce data movement
assert act_quant_block_size is not None
use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
and act_quant_block_size[1]
== DEEPEP_QUANT_BLOCK_SIZE)
use_fp8_dispatch = (moe.quant_config is not None
and moe.quant_config.quant_dtype
== current_platform.fp8_dtype()
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)

# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=use_fp8_dispatch,
)

self.topk_indices_dtype = None
if prepare_finalize is not None:
logger.debug("%s", prepare_finalize.__class__.__name__)
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, moe)
self.fused_experts = FusedMoEModularKernel(
@@ -407,13 +184,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)

def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
"Subclass must select appropriate gemm implementation"
" based on the prepare_finalize")
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")

@abstractmethod
def apply(
@@ -445,7 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""

def __init__(self, moe: MoEConfig):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
@@ -458,44 +237,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore

def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:

assert self.fused_experts == fused_experts

all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None

experts: Optional[FusedMoEPermuteExpertsUnpermute] = None

use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
) is not None
if use_batched_experts:
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("BatchedTritonExperts %s", self.moe)
assert self.moe.dp_size == all2all_manager.dp_world_size
experts = BatchedTritonExperts(
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
return experts
return TritonExperts()
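
For context, the experts object returned here is paired with the prepare/finalize backend in init_prepare_finalize() above; a short sketch of that wiring (assuming the two-argument FusedMoEModularKernel constructor implied by the truncated hunk, so treat the exact signature as an assumption):

```python
def build_modular_kernel(method, prepare_finalize, moe):
    # method is a FusedMoEMethodBase subclass instance (e.g. the unquantized
    # method above); prepare_finalize comes from the chosen all2all backend.
    experts = method.select_gemm_impl(prepare_finalize, moe)
    # Assumed wiring: combine dispatch/combine with the expert GEMMs.
    return FusedMoEModularKernel(prepare_finalize, experts)
```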

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -883,13 +648,18 @@ class FusedMoE(torch.nn.Module):
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

tp_size_ = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
world_size_ = get_world_group().world_size

vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=(tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()),
dp_size_=(dp_size if dp_size is not None else
get_dp_group().world_size),
tp_size_=tp_size_,
dp_size_=dp_size_,
world_size_=world_size_,
vllm_parallel_config=vllm_config.parallel_config))

self.global_num_experts = num_experts + num_redundant_experts
@@ -948,25 +718,22 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

# Only support float8 for now.
quant_dtype = params_dtype
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8
and input_activations.type == QuantizationType.FLOAT):
quant_dtype = torch.float8_e4m3fn
if vllm_config.model_config is not None:
model_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
model_dtype = params_dtype

moe = MoEConfig(
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=params_dtype,
quant_dtype=quant_dtype,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
@@ -1017,16 +784,15 @@ class FusedMoE(torch.nn.Module):
self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())

# Note here we use `num_experts` which is the logical expert count
self.batched_router_logits = torch.zeros(
(envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts),
dtype=act_dtype,
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())

@property
@@ -1588,7 +1354,7 @@ class FusedMoE(torch.nn.Module):

assert (self.batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore
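
The preallocated buffers from the earlier hunk are reused for every DP chunk; a hedged sketch of the staging step this hunk belongs to (the router-logits line and the copy calls are assumptions based on the truncated hunk, and the helper name is hypothetical):

```python
def stage_chunk(layer, hidden_states, router_logits, max_num_tokens):
    # Hypothetical helper (not in the diff): copy the current DP chunk into
    # the preallocated batched buffers before the all2all dispatch runs.
    chunk_size = min(hidden_states.shape[0], max_num_tokens)
    staged_h = layer.batched_hidden_states[:chunk_size, :]
    staged_r = layer.batched_router_logits[:chunk_size, :]
    staged_h.copy_(hidden_states[:chunk_size], non_blocking=True)
    staged_r.copy_(router_logits[:chunk_size], non_blocking=True)
    return staged_h, staged_r
```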