[BugFix]: Fix Batched DeepGemm Experts (#19515)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath
2025-06-12 22:43:02 -04:00
committed by GitHub
parent e6aab5de29
commit e3b12667d4
9 changed files with 52 additions and 32 deletions


@@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
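The new signature threads both expert counts through to implementations. A minimal sketch (not the vLLM implementation; the per-local-expert batched layout and the dimension choices below are assumptions for illustration) of how a batched-experts subclass might size its buffers:

import torch

class BatchedExpertsSketch:
    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        # Batched buffers hold one row-block per expert resident on this
        # rank, so they scale with local_num_experts, not the global count.
        workspace13 = (local_num_experts, M, max(N, K))
        workspace2 = (local_num_experts, M, N // 2)
        fused_out = (local_num_experts, M, K)
        return workspace13, workspace2, fused_out, a.dtype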
@@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)
+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts

         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
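Deriving local_num_experts from w1.size(0) matters under expert parallelism, where each rank holds only a slice of the expert weights. A toy illustration (the helper function and tensor shapes below are assumptions, not vLLM code):

import torch

def expert_counts(w1: torch.Tensor, global_num_experts: int) -> tuple[int, int]:
    # Mirrors the diff: the local count always comes from the weight
    # tensor actually resident on this rank.
    local_num_experts = w1.size(0)
    if global_num_experts == -1:
        global_num_experts = local_num_experts
    return global_num_experts, local_num_experts

# 8 experts sharded over 4 ranks -> 2 experts per rank (sizes illustrative).
w1_shard = torch.empty(2, 256, 128)
assert expert_counts(w1_shard, 8) == (8, 2)
# Without sharding, the -1 sentinel resolves to the local count.
assert expert_counts(w1_shard, -1) == (2, 2)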
@@ -408,16 +410,19 @@ class FusedMoEModularKernel(torch.nn.Module):
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                a1, a1q, M, N, K, top_k, global_num_experts)
+                a1, a1q, M, N, K, top_k, global_num_experts,
+                local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, global_num_experts))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
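Condensed, the chunked path makes two calls: one with the full M to size the final output, and one with CHUNK_SIZE to size the scratch workspaces reused across chunks. A sketch of that pattern (the free function below is illustrative, not the vLLM code):

import math

def chunked_workspace_shapes(fused_experts, a1, a1q, M, N, K, top_k,
                             global_num_experts, local_num_experts,
                             CHUNK_SIZE):
    if math.ceil(M / CHUNK_SIZE) == 1:
        return fused_experts.workspace_shapes(
            a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
    # The fused output must cover all M tokens ...
    _, _, fused_out_shape, _ = fused_experts.workspace_shapes(
        a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
    # ... while the scratch workspaces only need to fit one chunk at a time.
    workspace13_shape, workspace2_shape, _, workspace_dtype = (
        fused_experts.workspace_shapes(
            a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
            local_num_experts))
    return workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype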