[Bugfix] Fused MoE Modular Kernel chunking loop (#20392)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Varun Sundar Rabindranath
Date: 2025-07-10 16:31:10 -04:00
Committed by: GitHub
Parent: 41060c6e08
Commit: fdadb6f43a
4 changed files with 404 additions and 107 deletions

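This change replaces the inline chunking loop in FusedMoEModularKernel.forward with two helpers, _do_fused_experts and _maybe_chunk_fused_experts, so that large batches are split along the token dimension M into VLLM_FUSED_MOE_CHUNK_SIZE-sized chunks and each chunk recomputes its own per-expert token counts. The sketch below is illustrative only: chunked_moe, run_chunk and ceil_div are hypothetical stand-ins (ceil_div mirrors vllm.utils.cdiv), not vLLM APIs. It shows the slicing scheme the new helper uses: input rows, scales and topk_ids are sliced by token index, while output slices are scaled by the ratio of the output's leading dimension to M.

from typing import Callable

import torch


def ceil_div(a: int, b: int) -> int:
    # Mirrors vllm.utils.cdiv.
    return (a + b - 1) // b


def chunked_moe(a1q: torch.Tensor, topk_ids: torch.Tensor,
                fused_out: torch.Tensor, chunk_size: int,
                run_chunk: Callable[[torch.Tensor, torch.Tensor, torch.Tensor],
                                    None]) -> None:
    """Run a per-chunk MoE callable over CHUNK_SIZE-sized slices of the batch."""
    M = a1q.size(0)
    # The output's leading dim may be a multiple of M (e.g. one row per
    # (token, expert) pair), so output slices are scaled by that factor.
    assert fused_out.size(0) % M == 0
    factor = fused_out.size(0) // M
    for chunk_idx in range(ceil_div(M, chunk_size)):
        s = chunk_idx * chunk_size
        e = min(s + chunk_size, M)
        run_chunk(a1q[s:e], topk_ids[s:e], fused_out[s * factor:e * factor])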

@@ -10,7 +10,8 @@ import torch
 import vllm.envs as envs
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
+    _resize_cache, count_expert_num_tokens)
 from vllm.utils import cdiv
 #
@@ -421,6 +422,177 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
+    def _do_fused_experts(
+            self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
+            a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
+            topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+            local_num_experts: int, expert_map: Optional[torch.Tensor],
+            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
+            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
+            a1q_scale: Optional[torch.Tensor],
+            a2_scale: Optional[torch.Tensor],
+            expert_tokens_meta: Optional[ExpertTokensMetadata]
+    ) -> torch.Tensor:
+
+        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+
+        (workspace13_shape, workspace2_shape, fused_out_shape,
+         workspace_dtype) = self.fused_experts.workspace_shapes(
+             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
+
+        # We can reuse the memory between cache1 and cache3 because by the
+        # time we need cache3, we're done with cache1.
+        workspace13 = torch.empty(prod(workspace13_shape),
+                                  device=a1.device,
+                                  dtype=workspace_dtype)
+        workspace2 = torch.empty(prod(workspace2_shape),
+                                 device=a1.device,
+                                 dtype=workspace_dtype)
+
+        assert fused_out is None or fused_out.shape == fused_out_shape, (
+            f"fused_out {fused_out.shape} but expected {fused_out_shape}")
+
+        if fused_out is None:
+            # reuse workspace13 for the output
+            fused_out = _resize_cache(workspace13, fused_out_shape)
+
+        self.fused_experts.apply(fused_out,
+                                 a1q,
+                                 w1,
+                                 w2,
+                                 topk_ids=topk_ids,
+                                 activation=activation,
+                                 global_num_experts=global_num_experts,
+                                 expert_map=expert_map,
+                                 w1_scale=w1_scale,
+                                 w2_scale=w2_scale,
+                                 w1_zp=w1_zp,
+                                 w2_zp=w2_zp,
+                                 a1q_scale=a1q_scale,
+                                 a2_scale=a2_scale,
+                                 workspace13=workspace13,
+                                 workspace2=workspace2,
+                                 expert_tokens_meta=expert_tokens_meta)
+
+        return fused_out
+
+    def _maybe_chunk_fused_experts(
+            self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor,
+            w2: torch.Tensor, topk_ids: torch.Tensor, activation: str,
+            global_num_experts: int, local_num_experts: int,
+            expert_map: Optional[torch.Tensor],
+            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
+            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
+            a1q_scale: Optional[torch.Tensor],
+            a2_scale: Optional[torch.Tensor],
+            expert_tokens_meta: Optional[ExpertTokensMetadata]
+    ) -> torch.Tensor:
+
+        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+
+        CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+        num_chunks = cdiv(M, CHUNK_SIZE)
+
+        if not self.fused_experts.supports_chunking() or num_chunks == 1:
+            return self._do_fused_experts(
+                fused_out=None,
+                a1=a1,
+                a1q=a1q,
+                w1=w1,
+                w2=w2,
+                topk_ids=topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                local_num_experts=local_num_experts,
+                expert_map=expert_map,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                w1_zp=w1_zp,
+                w2_zp=w2_zp,
+                a1q_scale=a1q_scale,
+                a2_scale=a2_scale,
+                expert_tokens_meta=expert_tokens_meta)
+
+        # Chunking required case
+        assert num_chunks > 1
+
+        # Construct the entire output that can then be processed in chunks.
+        (_, _, fused_out_shape,
+         _) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
+                                                  global_num_experts,
+                                                  local_num_experts)
+        fused_out = torch.empty(fused_out_shape,
+                                device=a1q.device,
+                                dtype=a1.dtype)
+
+        def slice_input_tensors(
+            chunk_idx: int
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+                   Optional[torch.Tensor], torch.Tensor]:
+            s = chunk_idx * CHUNK_SIZE
+            e = min(s + CHUNK_SIZE, M)
+            return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
+                    _chunk_scales(a2_scale, s, e), topk_ids[s:e])
+
+        def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
+            assert fused_out.size(0) % M == 0, (
+                f"fused_out shape {fused_out.shape} vs M {M}")
+            factor = fused_out.size(0) // M
+            out_chunk_size = CHUNK_SIZE * factor
+            s = chunk_idx * out_chunk_size
+            e = min(s + out_chunk_size, fused_out.size(0))
+            return fused_out[s:e]
+
+        def slice_expert_tokens_metadata(
+                full_expert_tokens_meta: ExpertTokensMetadata,
+                chunk_topk_ids: torch.Tensor, local_num_experts: int,
+                expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
+            # The existing expert_num_tokens is for the entire a1q
+            # input. Chunking forces recomputation of the number
+            # of tokens assigned to each expert.
+            c_expert_num_tokens = count_expert_num_tokens(
+                chunk_topk_ids, local_num_experts, expert_map)
+
+            c_expert_num_tokens_cpu = None
+            need_expert_num_tokens_cpu = (
+                full_expert_tokens_meta.expert_num_tokens_cpu is not None)
+            if need_expert_num_tokens_cpu:
+                c_expert_num_tokens_cpu = c_expert_num_tokens.to(
+                    "cpu", non_blocking=True)
+
+            return ExpertTokensMetadata(
+                expert_num_tokens=c_expert_num_tokens,
+                expert_num_tokens_cpu=c_expert_num_tokens_cpu)
+
+        for chunk_idx in range(num_chunks):
+            c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = (
+                slice_input_tensors(chunk_idx))
+
+            c_expert_tokens_meta = None
+            if expert_tokens_meta is not None:
+                c_expert_tokens_meta = slice_expert_tokens_metadata(
+                    expert_tokens_meta, c_topk_ids, local_num_experts,
+                    expert_map)
+
+            self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx),
+                                   a1=a1,
+                                   a1q=c_a1q,
+                                   w1=w1,
+                                   w2=w2,
+                                   topk_ids=c_topk_ids,
+                                   activation=activation,
+                                   global_num_experts=global_num_experts,
+                                   local_num_experts=local_num_experts,
+                                   expert_map=expert_map,
+                                   w1_scale=w1_scale,
+                                   w2_scale=w2_scale,
+                                   w1_zp=w1_zp,
+                                   w2_zp=w2_zp,
+                                   a1q_scale=c_a1q_scale,
+                                   a2_scale=c_a2_scale,
+                                   expert_tokens_meta=c_expert_tokens_meta)
+
+        return fused_out
 
     def forward(
         self,
         hidden_states: torch.Tensor,
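A note on the slice_expert_tokens_metadata helper added in the hunk above: the precomputed ExpertTokensMetadata.expert_num_tokens describes the entire a1q batch, so each chunk has to recount how many of its own tokens map to each local expert before the experts run. The snippet below is an assumed, illustrative equivalent of that recount, not vLLM code; the real implementation is count_expert_num_tokens (imported in the first hunk), and recount_expert_tokens is a hypothetical name.

from typing import Optional

import torch


def recount_expert_tokens(chunk_topk_ids: torch.Tensor,
                          local_num_experts: int,
                          expert_map: Optional[torch.Tensor]) -> torch.Tensor:
    # Flatten the (num_chunk_tokens, top_k) id matrix into one id per slot.
    ids = chunk_topk_ids.reshape(-1)
    if expert_map is not None:
        # expert_map maps global expert ids to local ids; ids this rank does
        # not own map to -1 and are dropped from the count.
        ids = expert_map[ids]
        ids = ids[ids >= 0]
    # One count per local expert, zeros included for idle experts.
    return torch.bincount(ids, minlength=local_num_experts)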
@@ -512,110 +684,23 @@ class FusedMoEModularKernel(torch.nn.Module):
             # and can never run into the tensor.numel() == 0 case.
             fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
         else:
-            _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
-
-            if self.fused_experts.enable_chunking():
-                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
-                num_chunks = cdiv(M, CHUNK_SIZE)
-            else:
-                CHUNK_SIZE = M
-                num_chunks = 1
-
-            if num_chunks == 1:
-                (workspace13_shape, workspace2_shape, fused_out_shape,
-                 workspace_dtype) = self.fused_experts.workspace_shapes(
-                     a1, a1q, M, N, K, top_k, global_num_experts,
-                     local_num_experts)
-            else:
-                # Use the full M to get the final output shape.
-                _, _, fused_out_shape, _ = (
-                    self.fused_experts.workspace_shapes(
-                        a1, a1q, M, N, K, top_k, global_num_experts,
-                        local_num_experts))
-                # Use the CHUNK_SIZE to get the workspace shapes.
-                workspace13_shape, workspace2_shape, _, workspace_dtype = (
-                    self.fused_experts.workspace_shapes(
-                        a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
-                        local_num_experts))
-
-            # We can reuse the memory between cache1 and cache3 because by the
-            # time we need cache3, we're done with cache1.
-            workspace13 = torch.empty(prod(workspace13_shape),
-                                      device=a1.device,
-                                      dtype=workspace_dtype)
-            workspace2 = torch.empty(prod(workspace2_shape),
-                                     device=a1.device,
-                                     dtype=workspace_dtype)
-
-            if num_chunks == 1:
-                fused_out = _resize_cache(workspace13, fused_out_shape)
-
-                self.fused_experts.apply(
-                    fused_out,
-                    a1q,
-                    w1,
-                    w2,
-                    topk_ids,
-                    activation=activation,
-                    global_num_experts=global_num_experts,
-                    expert_map=expert_map,
-                    w1_scale=w1_scale,
-                    w2_scale=w2_scale,
-                    w1_zp=w1_zp,
-                    w2_zp=w2_zp,
-                    a1q_scale=a1q_scale,
-                    a2_scale=a2_scale,
-                    workspace13=workspace13,
-                    workspace2=workspace2,
-                    expert_tokens_meta=expert_tokens_meta,
-                )
-            else:
-                # The leading output dimension may not be equal to M, so
-                # we compute output indices separately.
-                M_out = fused_out_shape[0]
-                assert M_out >= M
-                factor = M_out // M
-                assert factor > 0
-                OUT_CHUNK_SIZE = CHUNK_SIZE * factor
-
-                fused_out = torch.empty(fused_out_shape,
-                                        device=a1q.device,
-                                        dtype=workspace_dtype)
-
-                assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
-                    f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")
-
-                for chunk in range(num_chunks):
-                    begin_chunk_idx = chunk * CHUNK_SIZE
-                    end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
-                    begin_out_idx = chunk * OUT_CHUNK_SIZE
-                    end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
-                    curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
-                    curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
-                                                   end_chunk_idx)
-                    curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
-                                                  end_chunk_idx)
-                    curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
-
-                    self.fused_experts.apply(
-                        fused_out[begin_out_idx:end_out_idx],
-                        curr_a1q,
-                        w1,
-                        w2,
-                        curr_topk_ids,
-                        activation=activation,
-                        global_num_experts=global_num_experts,
-                        expert_map=expert_map,
-                        w1_scale=w1_scale,
-                        w2_scale=w2_scale,
-                        w1_zp=w1_zp,
-                        w2_zp=w2_zp,
-                        a1q_scale=curr_a1q_scale,
-                        a2_scale=curr_a2_scale,
-                        workspace13=workspace13,
-                        workspace2=workspace2,
-                        expert_tokens_meta=expert_tokens_meta,
-                    )
+            fused_out = self._maybe_chunk_fused_experts(
+                a1=a1,
+                a1q=a1q,
+                w1=w1,
+                w2=w2,
+                topk_ids=topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                local_num_experts=local_num_experts,
+                expert_map=expert_map,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                w1_zp=w1_zp,
+                w2_zp=w2_zp,
+                a1q_scale=a1q_scale,
+                a2_scale=a2_scale,
+                expert_tokens_meta=expert_tokens_meta)
 
         self.prepare_finalize.finalize(output, fused_out, topk_weights,
                                        topk_ids,
                                        apply_router_weight_on_input)
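As a footnote on _do_fused_experts above: workspace13 is a single flat allocation that first serves as the intermediate cache and, when the caller passes fused_out=None, is re-viewed as the output tensor via _resize_cache. A minimal sketch of that reuse pattern, under the assumption that _resize_cache simply views a prefix of the flat buffer with the requested shape (resize_cache_sketch below is a hypothetical stand-in, not the vLLM helper):

from math import prod

import torch


def resize_cache_sketch(buf: torch.Tensor,
                        shape: tuple[int, ...]) -> torch.Tensor:
    # Assumed semantics: reinterpret a prefix of a flat workspace buffer as a
    # tensor of the requested shape, without allocating new storage.
    assert prod(shape) <= buf.numel(), "workspace too small for requested view"
    return buf.flatten()[:prod(shape)].view(*shape)


# One flat buffer backs both the intermediate cache and (when no preallocated
# fused_out is supplied) the kernel output.
workspace13 = torch.empty(1024, dtype=torch.float16)
fused_out = resize_cache_sketch(workspace13, (8, 16))  # 128 elements, same storage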