[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
self.inplace = inplace
|
||||
|
||||
# prefer an explicit FusedMoEParallelConfig when available (from
|
||||
# FusedMoE layers / tests).
|
||||
@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||
the layer.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
if inplace and self.shared_experts is None and not disable_inplace():
|
||||
if self.inplace:
|
||||
assert self.shared_experts is None
|
||||
assert not disable_inplace()
|
||||
output = hidden_states
|
||||
else:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user