[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
bnellnm
2025-08-15 14:46:00 -04:00
committed by GitHub
parent 48b01fd4d4
commit 8ad7285ea2
54 changed files with 2010 additions and 1293 deletions

View File

@@ -105,7 +105,8 @@ class DeviceCommunicatorBase:
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep
self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@@ -246,7 +247,7 @@ class DeviceCommunicatorBase:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
if not self.is_ep_communicator:
return
moe_modules = [
@@ -254,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config)
module.quant_method.init_prepare_finalize()
def dispatch(
self, hidden_states: torch.Tensor,