[Kernels] Overlap shared experts with send/recv (#23273)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-03 12:35:18 -04:00 (committed by GitHub)
Commit: e9b92dcd89 (parent: fa4311d85f)
32 changed files with 885 additions and 227 deletions
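This commit threads each layer's shared experts into FusedMoEModularKernel (see the first hunk below) so the dense shared-expert computation can run while the expert-parallel send/recv (dispatch/combine) for routed tokens is in flight, and it widens the MoE forward paths to return a (shared_output, fused_output) tuple when shared experts are present. The sketch below is a minimal illustration of the overlap idea only, not the vLLM implementation; dispatch_async, handle.wait, and combine are hypothetical stand-ins for the kernel's prepare/finalize phases.

import torch


def overlapped_moe(hidden_states: torch.Tensor,
                   router_logits: torch.Tensor,
                   shared_experts: torch.nn.Module,
                   routed_experts: torch.nn.Module,
                   dispatch_async,
                   combine) -> tuple[torch.Tensor, torch.Tensor]:
    # Start the all-to-all send of tokens to their routed experts
    # (hypothetical async handle; in vLLM this lives in the
    # prepare/finalize objects of the modular kernel).
    handle = dispatch_async(hidden_states, router_logits)

    # While the send/recv is in flight, run the dense shared experts on
    # the local tokens -- this is the compute that hides the communication.
    shared_output = shared_experts(hidden_states)

    # Receive the dispatched tokens, run the routed experts, and combine.
    routed_output = routed_experts(handle.wait())
    fused_output = combine(routed_output)

    return shared_output, fused_output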


@@ -4,7 +4,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from enum import Enum
-from typing import Callable, Literal, Optional, overload
+from typing import Callable, Literal, Optional, Union, overload
import torch
import torch.nn.functional as F
@@ -215,6 +215,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
+layer.shared_experts,
)
def select_gemm_impl(
@@ -252,7 +253,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
raise NotImplementedError
@@ -409,7 +410,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
@@ -461,7 +462,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -547,7 +548,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
-):
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
@@ -594,7 +595,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
-):
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
@@ -633,7 +634,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
@@ -948,6 +949,10 @@ class FusedMoE(CustomOp):
dtype=moe.in_dtype,
device=torch.cuda.current_device())
+@property
+def shared_experts(self) -> Optional[torch.nn.Module]:
+return None
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
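The base FusedMoE now exposes shared_experts as None; a subclass that owns a shared-expert module can override the property, and weights under the `_shared_experts.` prefix are excluded from the expert-weight views in the next hunk. The commit adds a concrete layer along these lines elsewhere in the tree; the subclass below is only an illustrative sketch with made-up names.

from typing import Optional

import torch

from vllm.model_executor.layers.fused_moe.layer import FusedMoE


class MoEWithSharedExperts(FusedMoE):  # hypothetical subclass, for illustration
    def __init__(self, shared_experts: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)
        # Stored under "_shared_experts" so its weights match the prefix
        # filtered out of the expert-weight views below.
        self._shared_experts = shared_experts

    @property
    def shared_experts(self) -> Optional[torch.nn.Module]:
        # Overriding the new property lets the base class hand the module to
        # the modular kernel and return (shared, fused) outputs from forward.
        return self._shared_experts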
@@ -1400,6 +1405,7 @@ class FusedMoE(CustomOp):
return [
weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS
+and not name.startswith("_shared_experts.")
]
def set_eplb_state(
@@ -1582,25 +1588,45 @@ class FusedMoE(CustomOp):
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
-def forward(self, hidden_states: torch.Tensor,
-router_logits: torch.Tensor):
+def forward(
+self,
+hidden_states: torch.Tensor,
+router_logits: torch.Tensor,
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
hidden_states = F.pad(hidden_states,
(0, self.hidden_size - og_hidden_states),
mode='constant',
value=0.0)
-# TODO: Once the OOM issue for the TPU backend is resolved, we will
-# switch to using the moe_forward custom op.
-if current_platform.is_tpu():
-return self.forward_impl(hidden_states, router_logits)
-else:
-return torch.ops.vllm.moe_forward(
-hidden_states, router_logits,
-self.layer_name)[..., :og_hidden_states]
-def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
-full_router_logits: torch.Tensor):
+if self.shared_experts is None:
+if current_platform.is_tpu():
+# TODO: Once the OOM issue for the TPU backend is resolved, we
+# will switch to using the moe_forward custom op.
+fused_output = self.forward_impl(hidden_states, router_logits)
+assert not isinstance(fused_output, tuple)
+else:
+fused_output = torch.ops.vllm.moe_forward(
+hidden_states, router_logits, self.layer_name)
+return fused_output[..., :og_hidden_states]
+else:
+if current_platform.is_tpu():
+# TODO: Once the OOM issue for the TPU backend is resolved, we
+# will switch to using the moe_forward custom op.
+shared_output, fused_output = self.forward_impl(
+hidden_states, router_logits)
+else:
+shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
+hidden_states, router_logits, self.layer_name)
+return (shared_output[..., :og_hidden_states],
+fused_output[..., :og_hidden_states])
+def forward_impl_chunked(
+self,
+full_hidden_states: torch.Tensor,
+full_router_logits: torch.Tensor,
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
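When shared experts are attached, forward now returns both outputs, each sliced back to the original hidden size. Below is a hedged sketch of how a calling MLP block might consume the pair (the actual model-side wiring lives in other files of this commit and is not shown here); summing the two follows the usual DeepSeek-style shared-expert formulation.

import torch


def moe_block_output(experts: torch.nn.Module,
                     hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
    # `experts` stands for a FusedMoE layer with shared experts attached;
    # the names here are illustrative only.
    out = experts(hidden_states, router_logits)
    if isinstance(out, tuple):
        shared_output, fused_output = out
        # Shared experts see every token, so their contribution is added
        # elementwise to the routed-expert output.
        return shared_output + fused_output
    return out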
@@ -1611,7 +1637,10 @@ class FusedMoE(CustomOp):
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
-full_final_hidden_states = torch.empty_like(full_hidden_states)
+full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
+if self.shared_experts is not None:
+full_shared_final_hidden_states = torch.empty_like(
+full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
@@ -1652,9 +1681,21 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count,
)
+assert self.shared_experts is None or isinstance(
+final_hidden_states, tuple)
if not skip_result_store:
-full_final_hidden_states[chunk_start:chunk_end, :].copy_(
-final_hidden_states, non_blocking=True)
+if self.shared_experts is None:
+full_fused_final_hidden_states[
+chunk_start:chunk_end, :].copy_(final_hidden_states,
+non_blocking=True)
+else:
+full_shared_final_hidden_states[
+chunk_start:chunk_end, :].copy_(final_hidden_states[0],
+non_blocking=True)
+full_fused_final_hidden_states[
+chunk_start:chunk_end, :].copy_(final_hidden_states[1],
+non_blocking=True)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
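In the chunked path above, forward_impl_chunked now preallocates separate full-size buffers for the shared and fused outputs and copies each chunk's results into them. Stripped of the vLLM specifics, the pattern is roughly the following sketch, where process_chunk is a stand-in for the per-chunk expert computation:

import torch


def chunked_moe(full_hidden_states: torch.Tensor, process_chunk,
                chunk_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    num_tokens = full_hidden_states.size(0)
    full_shared = torch.empty_like(full_hidden_states)
    full_fused = torch.empty_like(full_hidden_states)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        shared, fused = process_chunk(full_hidden_states[start:end])
        # Async copies into the preallocated outputs, mirroring the
        # copy_(..., non_blocking=True) calls in the diff.
        full_shared[start:end].copy_(shared, non_blocking=True)
        full_fused[start:end].copy_(fused, non_blocking=True)
    return full_shared, full_fused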
@@ -1675,10 +1716,17 @@ class FusedMoE(CustomOp):
chunk_end,
skip_result_store=chunk_start_ >= num_tokens)
-return full_final_hidden_states
+if self.shared_experts is None:
+return full_fused_final_hidden_states
+else:
+return (full_shared_final_hidden_states,
+full_fused_final_hidden_states)
-def forward_impl(self, hidden_states: torch.Tensor,
-router_logits: torch.Tensor):
+def forward_impl(
+self,
+hidden_states: torch.Tensor,
+router_logits: torch.Tensor,
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
@@ -1698,6 +1746,15 @@ class FusedMoE(CustomOp):
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
+# If there are shared experts but we are not using a modular kernel, the
+# shared experts must be called here
+if (not isinstance(self.quant_method.fused_experts,
+FusedMoEModularKernel)
+and self.shared_experts is not None):
+shared_output = self.shared_experts(hidden_states)
+else:
+shared_output = None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@@ -1722,14 +1779,30 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count,
)
-if do_naive_dispatch_combine:
-final_hidden_states = get_ep_group().combine(final_hidden_states)
-if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-# Default set to False. (May have to add shared expert outputs.
-final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
-final_hidden_states)
+if shared_output is not None:
+assert not isinstance(final_hidden_states, tuple)
+assert self.shared_experts is not None
+final_hidden_states = (
+shared_output,
+final_hidden_states,
+)
-return final_hidden_states
+def reduce_output(states: torch.Tensor) -> torch.Tensor:
+if do_naive_dispatch_combine:
+states = get_ep_group().combine(states)
+if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+states = self.maybe_all_reduce_tensor_model_parallel(states)
+return states
+if self.shared_experts is None:
+return reduce_output(final_hidden_states)
+else:
+return (
+reduce_output(final_hidden_states[0]),
+reduce_output(final_hidden_states[1]),
+)
@classmethod
def make_expert_params_mapping(
@@ -1784,17 +1857,22 @@ class FusedMoE(CustomOp):
return s
-def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
-layer_name: str) -> torch.Tensor:
+def moe_forward(
+hidden_states: torch.Tensor,
+router_logits: torch.Tensor,
+layer_name: str,
+) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
+assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits)
-def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
-layer_name: str) -> torch.Tensor:
+def moe_forward_fake(
+hidden_states: torch.Tensor,
+router_logits: torch.Tensor,
+layer_name: str,
+) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1807,6 +1885,37 @@ direct_register_custom_op(
tags=(torch.Tag.needs_fixed_stride_order, ),
)
+def moe_forward_shared(
+hidden_states: torch.Tensor,
+router_logits: torch.Tensor,
+layer_name: str,
+) -> tuple[torch.Tensor, torch.Tensor]:
+forward_context: ForwardContext = get_forward_context()
+self = forward_context.no_compile_layers[layer_name]
+assert self.shared_experts is not None
+return self.forward_impl(hidden_states, router_logits)
+def moe_forward_shared_fake(
+hidden_states: torch.Tensor,
+router_logits: torch.Tensor,
+layer_name: str,
+) -> tuple[torch.Tensor, torch.Tensor]:
+shared_out = torch.empty_like(hidden_states)
+fused_out = torch.empty_like(hidden_states)
+return shared_out, fused_out
+direct_register_custom_op(
+op_name="moe_forward_shared",
+op_func=moe_forward_shared,
+mutates_args=["hidden_states"],
+fake_impl=moe_forward_shared_fake,
+dispatch_key=current_platform.dispatch_key,
+tags=(torch.Tag.needs_fixed_stride_order, ),
+)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
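For the moe_forward_shared op registered above, the fake implementation is what lets torch.compile trace the two-output call without running any MoE kernels: fake tensors flow through moe_forward_shared_fake, so the traced graph only records two outputs shaped like hidden_states. Below is a hedged usage sketch, assuming the op has been registered and a live forward context maps layer_name to a FusedMoE layer that owns shared experts.

import torch


def moe_block(hidden_states: torch.Tensor, router_logits: torch.Tensor,
              layer_name: str) -> torch.Tensor:
    # Under torch.compile this call is traced against moe_forward_shared_fake;
    # at runtime it dispatches to moe_forward_shared, which looks the layer up
    # in the forward context and runs its forward_impl.
    shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
        hidden_states, router_logits, layer_name)
    return shared_output + fused_output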