[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
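This change lets the fused-MoE modular kernel drive the shared experts itself (note `layer.shared_experts` being passed into `FusedMoEModularKernel` below), so the dense shared-expert MLP can execute while the expert-parallel all-to-all send/recv for the routed tokens is in flight. The sketch below is a minimal illustration of that overlap idea using CUDA streams; the function names are hypothetical placeholders, not vLLM's API, and the real modular kernel sequences the work internally rather than through this exact stream layout.

import torch

def moe_forward_with_overlap(hidden_states, dispatch_fn, routed_experts_fn,
                             combine_fn, shared_experts_fn):
    # Illustrative only: run the shared experts on a side stream while the
    # all-to-all dispatch/combine for the routed experts is in flight.
    side_stream = torch.cuda.Stream()

    # The side stream must observe hidden_states before reading it.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        shared_out = shared_experts_fn(hidden_states)

    # Main stream: dispatch (send/recv), routed-expert GEMMs, combine.
    dispatched = dispatch_fn(hidden_states)
    routed = routed_experts_fn(dispatched)
    fused_out = combine_fn(routed)

    # Re-join before anything consumes shared_out on the main stream.
    torch.cuda.current_stream().wait_stream(side_stream)
    return shared_out, fused_out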
@@ -4,7 +4,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from enum import Enum
from typing import Callable, Literal, Optional, overload
from typing import Callable, Literal, Optional, Union, overload

import torch
import torch.nn.functional as F
@@ -215,6 +215,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
layer.shared_experts,
)

def select_gemm_impl(
@@ -252,7 +253,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
raise NotImplementedError


@@ -409,7 +410,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
@@ -461,7 +462,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -547,7 +548,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
):
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
@@ -594,7 +595,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
):
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
@@ -633,7 +634,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
@@ -948,6 +949,10 @@ class FusedMoE(CustomOp):
dtype=moe.in_dtype,
device=torch.cuda.current_device())

@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return None

@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@@ -1400,6 +1405,7 @@ class FusedMoE(CustomOp):
return [
weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS
and not name.startswith("_shared_experts.")
]

def set_eplb_state(
@@ -1582,25 +1588,45 @@ class FusedMoE(CustomOp):
else:
return tensor_model_parallel_all_reduce(final_hidden_states)

def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
hidden_states = F.pad(hidden_states,
(0, self.hidden_size - og_hidden_states),
mode='constant',
value=0.0)
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(
hidden_states, router_logits,
self.layer_name)[..., :og_hidden_states]

def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
if self.shared_experts is None:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
fused_output = self.forward_impl(hidden_states, router_logits)
assert not isinstance(fused_output, tuple)
else:
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, self.layer_name)
return fused_output[..., :og_hidden_states]
else:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, self.layer_name)
return (shared_output[..., :og_hidden_states],
fused_output[..., :og_hidden_states])

def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
@@ -1611,7 +1637,10 @@ class FusedMoE(CustomOp):
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))

full_final_hidden_states = torch.empty_like(full_hidden_states)
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(
full_hidden_states)

def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
@@ -1652,9 +1681,21 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count,
)

assert self.shared_experts is None or isinstance(
final_hidden_states, tuple)

if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states, non_blocking=True)
if self.shared_experts is None:
full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states,
non_blocking=True)
else:
full_shared_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[0],
non_blocking=True)
full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[1],
non_blocking=True)

ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
@@ -1675,10 +1716,17 @@ class FusedMoE(CustomOp):
chunk_end,
skip_result_store=chunk_start_ >= num_tokens)

return full_final_hidden_states
if self.shared_experts is None:
return full_fused_final_hidden_states
else:
return (full_shared_final_hidden_states,
full_fused_final_hidden_states)

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def forward_impl(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
@@ -1698,6 +1746,15 @@ class FusedMoE(CustomOp):
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)

# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if (not isinstance(self.quant_method.fused_experts,
FusedMoEModularKernel)
and self.shared_experts is not None):
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@@ -1722,14 +1779,30 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count,
)

if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
if shared_output is not None:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
final_hidden_states = (
shared_output,
final_hidden_states,
)

return final_hidden_states
def reduce_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states)

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
states = self.maybe_all_reduce_tensor_model_parallel(states)

return states

if self.shared_experts is None:
return reduce_output(final_hidden_states)
else:
return (
reduce_output(final_hidden_states[0]),
reduce_output(final_hidden_states[1]),
)

@classmethod
def make_expert_params_mapping(
@@ -1784,17 +1857,22 @@ class FusedMoE(CustomOp):
return s


def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
def moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None

assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits)


def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
def moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1807,6 +1885,37 @@ direct_register_custom_op(
tags=(torch.Tag.needs_fixed_stride_order, ),
)


def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits)


def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
return shared_out, fused_out


direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)

# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
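With this change, FusedMoE.forward returns either a single tensor (no shared experts) or a (shared_output, fused_output) tuple. A hypothetical caller-side sketch of handling the new contract; how the two outputs are combined is model-specific and shown only for illustration:

out = moe_layer(hidden_states, router_logits)
if isinstance(out, tuple):
    shared_out, fused_out = out
    hidden_states = shared_out + fused_out  # model-specific combination
else:
    hidden_states = out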