[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -105,7 +105,8 @@ class DeviceCommunicatorBase:
|
||||
# we initialize the all2all manager used in expert parallel.
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
|
||||
self.use_all2all = "ep" in unique_name and use_ep
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_manager: Optional[All2AllManagerBase] = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
@@ -246,7 +247,7 @@ class DeviceCommunicatorBase:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
"""
|
||||
if not self.use_all2all:
|
||||
if not self.is_ep_communicator:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
@@ -254,7 +255,7 @@ class DeviceCommunicatorBase:
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module.moe_config)
|
||||
module.quant_method.init_prepare_finalize()
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user