[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
@@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
self.dp_size = dp_size
|
||||
self.rank_expert_offset = rank_expert_offset
|
||||
self.async_prepare = True
|
||||
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
@@ -56,10 +58,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return None
|
||||
return deep_ep.Buffer.get_combine_config(self.dp_size)
|
||||
|
||||
def _do_dispatch(self, tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor],
|
||||
rank_topk_ids: torch.Tensor,
|
||||
rank_topk_weights: torch.Tensor, num_experts: int):
|
||||
def _do_dispatch(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor],
|
||||
rank_topk_ids: torch.Tensor,
|
||||
rank_topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> Callable:
|
||||
|
||||
has_scales = token_scales is not None
|
||||
|
||||
@@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_alignment=1,
|
||||
config=self._get_dispatch_config(),
|
||||
previous_event=None,
|
||||
async_finish=False,
|
||||
async_finish=self.async_prepare,
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
return lambda: self._receiver(
|
||||
event,
|
||||
has_scales,
|
||||
token_data,
|
||||
expert_topk_ids,
|
||||
num_experts,
|
||||
expert_num_tokens_per_expert_list,
|
||||
expert_topk_weights,
|
||||
a1_scale,
|
||||
quant_config,
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
event: deep_ep.EventOverlap,
|
||||
has_scales: bool,
|
||||
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
expert_topk_ids: Optional[torch.Tensor],
|
||||
num_experts: int,
|
||||
expert_num_tokens_per_expert_list: list[int],
|
||||
expert_topk_weights: Optional[torch.Tensor],
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
if self.async_prepare:
|
||||
event.current_stream_wait()
|
||||
|
||||
if has_scales:
|
||||
expert_x, expert_x_scale = token_data
|
||||
else:
|
||||
@@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# DeepEP's topk_ids output refers to the local experts directly. Offset
|
||||
# the topk_ids to move it back to the global experts space so it aligns
|
||||
# with existing vLLM interfaces.
|
||||
assert expert_topk_ids is not None
|
||||
expert_topk_ids = torch.where(
|
||||
expert_topk_ids == -1,
|
||||
num_experts - 1 if self.rank_expert_offset == 0 else 0,
|
||||
@@ -123,10 +159,28 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
|
||||
expert_num_tokens_per_expert_list, device=expert_x.device)
|
||||
|
||||
# Dispatch and Quant
|
||||
# DeepEP kernels only support dispatching block-quantized
|
||||
# activation scales.
|
||||
# Dispatch in bfloat16 and quantize afterwards
|
||||
if not quant_config.is_block_quantized:
|
||||
# Quantize after dispatch.
|
||||
expert_x_scale = None
|
||||
if expert_x.numel() != 0:
|
||||
expert_x, expert_x_scale = moe_kernel_quantize_input(
|
||||
expert_x,
|
||||
a1_scale,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights)
|
||||
|
||||
def prepare(
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
@@ -137,9 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> Callable:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
@@ -159,37 +211,37 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
)
|
||||
if a1q_scale is not None and a1q_scale.numel() == 1:
|
||||
a1q_scale = a1q_scale.view(1, 1)
|
||||
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts)
|
||||
a1_post_scale = None
|
||||
else:
|
||||
# Dispatch and Quant
|
||||
# DeepEP kernels only support dispatching block-quantized
|
||||
# activation scales.
|
||||
# Dispatch in bfloat16
|
||||
(expert_x, _, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1,
|
||||
token_scales=None,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts)
|
||||
# Quantize after dispatch.
|
||||
expert_x_scale = None
|
||||
if expert_x.numel() != 0:
|
||||
expert_x, expert_x_scale = moe_kernel_quantize_input(
|
||||
expert_x,
|
||||
a1_scale,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape)
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
a1_post_scale = a1_scale
|
||||
|
||||
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
|
||||
expert_topk_weights)
|
||||
return self._do_dispatch(tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts,
|
||||
a1_scale=a1_post_scale,
|
||||
quant_config=quant_config)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
|
||||
topk_ids, num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
return receiver()
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
@@ -75,7 +75,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self,
|
||||
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
a1_dtype: torch.dtype,
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
per_act_token_quant: bool,
|
||||
@@ -110,7 +109,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
return x, x_scales
|
||||
|
||||
def prepare(
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
@@ -121,9 +123,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.ReceiverType:
|
||||
|
||||
hidden_size = a1.size(1)
|
||||
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
|
||||
@@ -155,16 +155,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
async_finish=False,
|
||||
return_recv_hook=False)
|
||||
return_recv_hook=True)
|
||||
|
||||
return lambda: self._receiver(hook, expert_x, expert_num_tokens,
|
||||
a1_scale, a1.dtype, quant_config)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
hook: Callable,
|
||||
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
expert_num_tokens: torch.Tensor,
|
||||
a1_scale,
|
||||
a1_dtype,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
hook()
|
||||
|
||||
expert_x, expert_x_scale = self._do_quant(
|
||||
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
|
||||
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
|
||||
return expert_x, expert_x_scale, expert_tokens_meta, None, None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
|
||||
topk_ids, num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
return receiver()
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
|
||||
@@ -56,9 +56,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
# TODO(bnell): use quant_config + scales instead of ctor args
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
@@ -506,9 +506,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.PrepareResultType:
|
||||
assert a1.dim() == 2
|
||||
assert topk_ids.dim() == 2
|
||||
assert topk_ids.size(0) == a1.size(0)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, overload
|
||||
from typing import Callable, Literal, Optional, Union, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -215,6 +215,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
layer.shared_experts,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
@@ -252,7 +253,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -409,7 +410,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
@@ -461,7 +462,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@@ -547,7 +548,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb is not False or expert_load_view is not None or \
|
||||
logical_to_physical_map is not None or \
|
||||
logical_replica_count is not None:
|
||||
@@ -594,7 +595,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb is not False or expert_load_view is not None or \
|
||||
logical_to_physical_map is not None or \
|
||||
logical_replica_count is not None:
|
||||
@@ -633,7 +634,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
@@ -948,6 +949,10 @@ class FusedMoE(CustomOp):
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> Optional[torch.nn.Module]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
@@ -1400,6 +1405,7 @@ class FusedMoE(CustomOp):
|
||||
return [
|
||||
weight.view(self.local_num_experts, -1) for name, weight in weights
|
||||
if name not in NON_EXPERT_WEIGHTS
|
||||
and not name.startswith("_shared_experts.")
|
||||
]
|
||||
|
||||
def set_eplb_state(
|
||||
@@ -1582,25 +1588,45 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
og_hidden_states = hidden_states.shape[-1]
|
||||
if self.hidden_size != og_hidden_states:
|
||||
hidden_states = F.pad(hidden_states,
|
||||
(0, self.hidden_size - og_hidden_states),
|
||||
mode='constant',
|
||||
value=0.0)
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we will
|
||||
# switch to using the moe_forward custom op.
|
||||
if current_platform.is_tpu():
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
else:
|
||||
return torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits,
|
||||
self.layer_name)[..., :og_hidden_states]
|
||||
|
||||
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor):
|
||||
if self.shared_experts is None:
|
||||
if current_platform.is_tpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
fused_output = self.forward_impl(hidden_states, router_logits)
|
||||
assert not isinstance(fused_output, tuple)
|
||||
else:
|
||||
fused_output = torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits, self.layer_name)
|
||||
return fused_output[..., :og_hidden_states]
|
||||
else:
|
||||
if current_platform.is_tpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
shared_output, fused_output = self.forward_impl(
|
||||
hidden_states, router_logits)
|
||||
else:
|
||||
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
|
||||
hidden_states, router_logits, self.layer_name)
|
||||
return (shared_output[..., :og_hidden_states],
|
||||
fused_output[..., :og_hidden_states])
|
||||
|
||||
def forward_impl_chunked(
|
||||
self,
|
||||
full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
|
||||
@@ -1611,7 +1637,10 @@ class FusedMoE(CustomOp):
|
||||
assert (
|
||||
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
||||
|
||||
full_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
full_shared_final_hidden_states = torch.empty_like(
|
||||
full_hidden_states)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
@@ -1652,9 +1681,21 @@ class FusedMoE(CustomOp):
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
|
||||
assert self.shared_experts is None or isinstance(
|
||||
final_hidden_states, tuple)
|
||||
|
||||
if not skip_result_store:
|
||||
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states, non_blocking=True)
|
||||
if self.shared_experts is None:
|
||||
full_fused_final_hidden_states[
|
||||
chunk_start:chunk_end, :].copy_(final_hidden_states,
|
||||
non_blocking=True)
|
||||
else:
|
||||
full_shared_final_hidden_states[
|
||||
chunk_start:chunk_end, :].copy_(final_hidden_states[0],
|
||||
non_blocking=True)
|
||||
full_fused_final_hidden_states[
|
||||
chunk_start:chunk_end, :].copy_(final_hidden_states[1],
|
||||
non_blocking=True)
|
||||
|
||||
ctx = get_forward_context()
|
||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||
@@ -1675,10 +1716,17 @@ class FusedMoE(CustomOp):
|
||||
chunk_end,
|
||||
skip_result_store=chunk_start_ >= num_tokens)
|
||||
|
||||
return full_final_hidden_states
|
||||
if self.shared_experts is None:
|
||||
return full_fused_final_hidden_states
|
||||
else:
|
||||
return (full_shared_final_hidden_states,
|
||||
full_fused_final_hidden_states)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def forward_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.quant_method is not None
|
||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
||||
# only when data parallelism (DP) is enabled.
|
||||
@@ -1698,6 +1746,15 @@ class FusedMoE(CustomOp):
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
|
||||
# If there are shared experts but we are not using a modular kernel, the
|
||||
# shared experts must be called here
|
||||
if (not isinstance(self.quant_method.fused_experts,
|
||||
FusedMoEModularKernel)
|
||||
and self.shared_experts is not None):
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
@@ -1722,14 +1779,30 @@ class FusedMoE(CustomOp):
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
final_hidden_states = get_ep_group().combine(final_hidden_states)
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
# Default set to False. (May have to add shared expert outputs.
|
||||
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
return final_hidden_states
|
||||
def reduce_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
states = get_ep_group().combine(states)
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is None:
|
||||
return reduce_output(final_hidden_states)
|
||||
else:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0]),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
@@ -1784,17 +1857,22 @@ class FusedMoE(CustomOp):
|
||||
return s
|
||||
|
||||
|
||||
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
def moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.quant_method is not None
|
||||
|
||||
assert self.shared_experts is None
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
def moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
@@ -1807,6 +1885,37 @@ direct_register_custom_op(
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
def moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.shared_experts is not None
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
def moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward_shared",
|
||||
op_func=moe_forward_shared,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=moe_forward_shared_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
|
||||
# to avoid expensive runtime reflection in model loading code
|
||||
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
|
||||
|
||||
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Optional, final
|
||||
from typing import Callable, Optional, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
@@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#
|
||||
# PrepareResultType is a tuple of:
|
||||
# - quantized + dispatched a.
|
||||
# - quantized + dispatched a1_scales.
|
||||
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||
# as big as the number of local experts with the information about the
|
||||
# number of tokens assigned to each local expert.
|
||||
# - Optional dispatched expert topk IDs
|
||||
# - Optional dispatched expert topk weight
|
||||
#
|
||||
# See `prepare` method below.
|
||||
#
|
||||
PrepareResultType = tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]
|
||||
|
||||
ReceiverType = Callable[[], PrepareResultType]
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
@@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
) -> PrepareResultType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed
|
||||
for this kernel.
|
||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- a1_scale: Optional scales for a1
|
||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||
@@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_async(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not this class implements prepare_async.
|
||||
"""
|
||||
return False
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> ReceiverType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- a1_scale: Optional scales for a1
|
||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||
sure the quantization is consistent for both gemms.
|
||||
- topk_ids: The topk ids.
|
||||
- topk_weights: The topk weights.
|
||||
- num_experts: The total number of experts in the global expert space.
|
||||
- expert_map: A tensor mapping expert indices from the global expert
|
||||
space to the local expert space of the expert parallel shard.
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `prepare`, e.g.
|
||||
|
||||
receiver = obj.prepare_async(...)
|
||||
a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
|
||||
|
||||
is equivalent to:
|
||||
|
||||
a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def finalize(
|
||||
self,
|
||||
@@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
assert prepare_finalize.activation_format == \
|
||||
fused_experts.activation_formats[0], (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
@@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
of weights, w1 and w2, and top-k gating mechanism.
|
||||
@@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
shared_output: torch.Tensor
|
||||
|
||||
if (not self.prepare_finalize.supports_async()
|
||||
or self.shared_experts is None):
|
||||
|
||||
# Run shared experts serially with dispatch.
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
receiver = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
assert self.shared_experts is not None
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
@@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
return output
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
return shared_output, output
|
||||
|
||||
@@ -84,12 +84,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return self.max_num_tokens
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
return torch.int32
|
||||
return torch.uint32
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def prepare(
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
@@ -100,9 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.ReceiverType:
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
@@ -138,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
_validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant,
|
||||
quant_config.block_shape)
|
||||
|
||||
orig_a_scale_block_shape: Optional[int] = None
|
||||
|
||||
if a1q_scale is not None:
|
||||
scalar_scales = a1q_scale.numel() == 1
|
||||
|
||||
@@ -205,8 +208,45 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids.view(dtype=torch.uint32),
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=True,
|
||||
do_recv=False,
|
||||
)
|
||||
|
||||
return lambda: self._receiver(
|
||||
expert_num_tokens,
|
||||
expert_x,
|
||||
expert_x_scale,
|
||||
a1q,
|
||||
a1q_scale,
|
||||
topk_ids,
|
||||
bound_m,
|
||||
orig_a_scale_block_shape,
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
expert_x: torch.Tensor,
|
||||
expert_x_scale: Optional[torch.Tensor],
|
||||
a1q: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
topk_ids: torch.Tensor,
|
||||
bound_m: Optional[torch.Tensor],
|
||||
orig_a_scale_block_shape: Optional[int],
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=False,
|
||||
do_recv=True,
|
||||
)
|
||||
|
||||
if expert_x_scale is not None:
|
||||
@@ -218,6 +258,31 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
return expert_x, expert_x_scale, expert_tokens_meta, None, None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
)
|
||||
return receiver()
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
|
||||
@@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@@ -505,7 +505,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -474,7 +474,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert self.fused_experts is None
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
@@ -358,7 +358,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@@ -819,7 +819,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for "
|
||||
@@ -1069,7 +1069,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@@ -1375,7 +1375,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@@ -1608,7 +1608,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -128,7 +128,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -988,7 +988,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@@ -540,7 +540,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -654,7 +654,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -491,7 +491,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||
@@ -1366,7 +1366,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -305,7 +305,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -554,7 +554,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -226,7 +226,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@@ -390,7 +390,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -291,7 +291,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
6
vllm/model_executor/layers/shared_fused_moe/__init__.py
Normal file
6
vllm/model_executor/layers/shared_fused_moe/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import (
|
||||
SharedFusedMoE)
|
||||
|
||||
__all__ = ["SharedFusedMoE"]
|
||||
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
|
||||
# TODO(bnell): Add shared + fused combo function? e.g. +
|
||||
class SharedFusedMoE(FusedMoE):
|
||||
"""
|
||||
A FusedMoE operation that also computes the results of shared experts.
|
||||
If an all2all communicator is being used the shared expert computation
|
||||
can be interleaved with the fused all2all dispatch communication step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: torch.nn.Module,
|
||||
use_overlapped: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._shared_experts = shared_experts
|
||||
self.use_overlapped = use_overlapped
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> Optional[torch.nn.Module]:
|
||||
return self._shared_experts if self.use_overlapped else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.use_overlapped:
|
||||
shared_out = self._shared_experts(hidden_states)
|
||||
|
||||
# Reduce outputs if necessary, since the MLP should
|
||||
# have been created with reduce_results=False.
|
||||
if (self.reduce_results and self.tp_size > 1
|
||||
and self.must_reduce_shared_expert_outputs()):
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
|
||||
fused_out = super().forward(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
shared_out, fused_out = super().forward(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
return shared_out, fused_out
|
||||
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@@ -147,63 +148,85 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.physical_expert_end = (self.physical_expert_start +
|
||||
self.n_local_physical_experts)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
if config.n_shared_experts is None:
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
self.shared_experts = None
|
||||
else:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(
|
||||
),
|
||||
reduce_results=False,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
shared_experts=self.shared_experts,
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits) * self.routed_scaling_factor
|
||||
fused_moe_out = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output, final_hidden_states = fused_moe_out
|
||||
else:
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
if shared_output is not None:
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
else:
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
final_hidden_states = final_hidden_states + shared_output \
|
||||
* (1. / self.routed_scaling_factor)
|
||||
shared_output = None
|
||||
final_hidden_states = fused_moe_out
|
||||
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
elif self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
shared_output *= (1. / self.routed_scaling_factor)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
final_hidden_states += shared_output
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
|
||||
@@ -184,6 +184,8 @@ class Glm4MoE(nn.Module):
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
|
||||
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
|
||||
@@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.router")
|
||||
|
||||
self.experts = FusedMoE(
|
||||
self.shared_expert = LlamaMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
shared_experts=self.shared_expert,
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -83,22 +95,13 @@ class Llama4MoE(nn.Module):
|
||||
reduce_results=False,
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts")
|
||||
|
||||
self.shared_expert = LlamaMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(),
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
shared_out = self.shared_expert(hidden_states)
|
||||
routed_out = self.experts(
|
||||
|
||||
shared_out, routed_out = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user