[Performance][B200] Fix deepgemm prologue (#27897)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-11-12 16:13:03 -05:00
committed by GitHub
parent 478ee511de
commit 74a9a9faad
6 changed files with 163 additions and 48 deletions

View File

@@ -149,6 +149,15 @@ class FusedMoEPrepareAndFinalize(ABC):
described above.
"""
def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
"""
Initialize FusedMoEPrepareAndFinalize settings that depend on
FusedMoEPermuteExpertsUnpermute experts object.
The FusedMoEPrepareAndFinalize implementations that have such
dependencies may choose to override this function.
"""
return
@abstractmethod
def prepare(
self,
@@ -503,6 +512,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def supports_packed_ue8m0_act_scales(self) -> bool:
"""
A flag indicating whether or not this class can process packed ue8m0
activation scales.
"""
return False
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
"""
Workspace type: The dtype to use for the workspace tensors.
@@ -698,6 +714,8 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self._post_init_setup()
assert (
prepare_finalize.activation_format == fused_experts.activation_formats[0]
), (
@@ -707,6 +725,13 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.activation_formats[0]}"
)
def _post_init_setup(self):
"""
Resolve any leftover setup dependencies between self.prepare_finalize
and self.fused_experts here.
"""
self.prepare_finalize.post_init_setup(self.fused_experts)
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.