[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
bnellnm
2026-02-05 13:07:18 -05:00
committed by GitHub
parent 1ee95841bd
commit a57c8228ff
37 changed files with 132 additions and 109 deletions


@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         moe_parallel_config: FusedMoEParallelConfig | None = None,
+        inplace: bool = False,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
+        self.inplace = inplace
         # prefer an explicit FusedMoEParallelConfig when available (from
         # FusedMoE layers / tests).
@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        inplace: bool = False,
         activation: str = "silu",
         global_num_experts: int = -1,
         expert_map: torch.Tensor | None = None,
@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         - topk_weights (torch.Tensor): The topk weights applied at the end of
           the layer.
         - topk_ids (torch.Tensor): A map of row to expert id.
-        - inplace (bool): If True, perform the operation in-place.
-          Defaults to False.
         - activation (str): The activation function to apply after the first
           MoE layer.
         - global_num_experts (int): The total number of experts in the global
@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         - torch.Tensor: The output tensor after applying the MoE layer.
         """
-        if inplace and self.shared_experts is None and not disable_inplace():
+        if self.inplace:
+            assert self.shared_experts is None
+            assert not disable_inplace()
             output = hidden_states
         else:
             output = torch.zeros_like(hidden_states)
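
Note on usage after this change: in-place execution is now decided once, at construction time, instead of per forward call. A minimal sketch is below, assuming the constructor shape shown in this diff; the prepare/finalize and experts objects (my_prepare_finalize, my_fused_experts) and the tensors are illustrative placeholders, not names from the vLLM codebase.

```python
# Sketch only: construct the modular kernel with inplace fixed up front.
mk = FusedMoEModularKernel(
    prepare_finalize=my_prepare_finalize,  # hypothetical FusedMoEPrepareAndFinalize impl
    fused_experts=my_fused_experts,        # hypothetical FusedMoEPermuteExpertsUnpermute impl
    shared_experts=None,                   # must be None when inplace=True (asserted in forward)
    inplace=True,                          # flag now lives on the kernel, not the call site
)

# The forward call no longer accepts an `inplace` keyword; when self.inplace
# is True, the kernel uses `hidden_states` itself as the output buffer.
out = mk(
    hidden_states, w1, w2, topk_weights, topk_ids,
    activation="silu",
    global_num_experts=num_experts,
)
# With inplace=True, `out` is expected to alias `hidden_states`.
```

Callers that previously passed `inplace=True` per call must also ensure `disable_inplace()` is not in effect, since that combination is now rejected by assertion rather than silently falling back to an out-of-place buffer.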