[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
@@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
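For context, a rough sketch of how a concrete subclass might implement the widened signature. The class name, workspace shapes, and return convention are illustrative assumptions, not this PR's actual kernel code; the import path is inferred from vLLM's layout, and the other abstract methods of the base class are omitted.

# Hypothetical implementer of the updated abstract method (partial sketch).
from typing import Tuple

import torch

# Assumed import path for the base class in vLLM.
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute)


class ExampleExperts(FusedMoEPermuteExpertsUnpermute):

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,  # new parameter: the quantized activations
        M: int,
        N: int,
        K: int,
        top_k: int,
        num_experts: int,
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], torch.dtype]:
        # Illustrative sizes only: a quantized kernel can now consult aq
        # (e.g. aq.dtype or its element size) when sizing its buffers,
        # while the workspace dtype is still taken from the unquantized a.
        workspace13 = (M * top_k, max(N, K))
        workspace2 = (M * top_k, N // 2)
        return workspace13, workspace2, a.dtype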
@@ -309,7 +310,7 @@ class FusedMoEModularKernel(torch.nn.Module):
 
         # Use a1 here to decipher the correct workspace datatype
         workspace13_shape, workspace2_shape, workspace_dtype = (
-            self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
+            self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
                                                 global_num_experts))
 
         # We can reuse the memory between cache1 and cache3 because by the time
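A hedged sketch of what the updated call site does with the returned shapes follows; only the workspace_shapes call mirrors the hunk above, and the allocation lines are assumptions added for illustration.

        # Illustrative continuation of the call site above; the allocations
        # are assumed for demonstration, not taken from the diff.
        workspace13_shape, workspace2_shape, workspace_dtype = (
            self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
                                                global_num_experts))

        # Both activation workspaces share the dtype derived from a1.
        workspace13 = torch.empty(workspace13_shape,
                                  device=a1.device,
                                  dtype=workspace_dtype)
        workspace2 = torch.empty(workspace2_shape,
                                 device=a1.device,
                                 dtype=workspace_dtype)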