[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. (#19717)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-06-25 02:22:58 -04:00
committed by GitHub
parent f59fc60fb3
commit 015fab8c2f
14 changed files with 379 additions and 238 deletions

View File

@@ -18,8 +18,8 @@ try:
except ImportError:
has_pplx = False
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
@@ -163,29 +163,6 @@ def batched_moe(
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference mixture-of-experts forward pass with precomputed routing.

    Each row of ``a`` is duplicated once per selected expert, pushed through
    that expert's two weight matrices (with ``SiluAndMul`` applied between
    them), and the per-expert results are combined using ``topk_weight``.

    Args:
        a: 2-D input of shape ``(M, K)``.
        w1: Stacked first-stage expert weights; ``w1.shape[0]`` is the number
            of experts and ``w1[e]`` is multiplied as ``x @ w1[e].T``.
        w2: Stacked second-stage expert weights; ``w2.shape[1]`` is the
            output width and ``w2[e]`` is multiplied as ``h @ w2[e].T``.
        topk_weight: Routing weights, viewed as ``(M, topk, 1)``.
        topk_ids: Expert indices; ``topk_ids.shape[1]`` is ``topk``.

    Returns:
        Tensor of shape ``(M, w2.shape[1])`` holding the weighted sum of the
        selected experts' outputs for every input row.
    """
    num_rows, hidden = a.shape
    topk = topk_ids.shape[1]
    out_width = w2.shape[1]

    # One copy of every input row per routed expert: (M * topk, K).
    expanded = a.view(num_rows, -1, hidden)
    expanded = expanded.repeat(1, topk, 1)
    expanded = expanded.reshape(-1, hidden)

    result = torch.zeros(
        num_rows * topk,
        out_width,
        dtype=expanded.dtype,
        device=expanded.device,
    )

    for expert in range(w1.shape[0]):
        # Flattened selector lines up with the rows of `expanded`.
        selected = (topk_ids == expert).view(-1)
        if not selected.sum():
            # No row was routed to this expert.
            continue
        first_stage = expanded[selected] @ w1[expert].transpose(0, 1)
        result[selected] = SiluAndMul()(first_stage) @ w2[expert].transpose(
            0, 1)

    # Scale each expert's contribution and reduce over the topk axis.
    per_expert = result.view(num_rows, -1, out_width)
    scales = topk_weight.view(num_rows, -1, 1).to(result.dtype)
    return (per_expert * scales).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@@ -209,7 +186,7 @@ def test_fused_moe_batched_experts(
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
@@ -409,7 +386,7 @@ def pplx_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
use_compile: bool = True,
use_compile: bool = False,
use_cudagraphs: bool = True,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
@@ -470,10 +447,16 @@ def pplx_moe(
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
torch._dynamo.mark_dynamic(a_chunk, 0)
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
else:
_fused_experts = fused_experts
@@ -576,7 +559,7 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
a, w1, w2, topk_weight, topk_ids)
# TODO (bnell): fix + re-enable