[Kernels] Overlap shared experts with send/recv (#23273)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-03 12:35:18 -04:00
committed by GitHub
parent fa4311d85f
commit e9b92dcd89
32 changed files with 885 additions and 227 deletions

View File

@@ -4,10 +4,11 @@
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import copy
import itertools
import textwrap
import traceback
from typing import Callable, Optional
from typing import Callable, Optional, Union
import pytest
import torch
@@ -21,7 +22,10 @@ try:
except ImportError:
has_pplx = False
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
_set_vllm_config)
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
naive_batched_moe)
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
@@ -511,7 +515,8 @@ def pplx_moe(
block_shape: Optional[list[int]] = None,
use_compile: bool = False,
use_cudagraphs: bool = True,
) -> torch.Tensor:
shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
@@ -546,6 +551,7 @@ def pplx_moe(
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
@@ -586,7 +592,11 @@ def pplx_moe(
global_num_experts=num_experts)
if use_cudagraphs:
out.fill_(0)
if isinstance(out, tuple):
out[0].fill_(0)
out[1].fill_(0)
else:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
@@ -626,6 +636,7 @@ def _pplx_moe(
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_internode: bool = False,
shared_experts: Optional[torch.nn.Module] = None,
):
try:
if use_internode:
@@ -666,6 +677,11 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
if shared_experts is not None:
shared_output = shared_experts(a)
else:
shared_output = None
torch_output = torch_experts(
a,
w1,
@@ -696,7 +712,7 @@ def _pplx_moe(
block_shape=block_shape,
)
pplx_output = pplx_moe(
pplx_outputs = pplx_moe(
group_name,
rank,
world_size,
@@ -713,8 +729,24 @@ def _pplx_moe(
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
shared_experts=shared_experts,
)
if shared_experts is None:
pplx_shared_output = None
pplx_output = pplx_outputs
assert isinstance(pplx_output, torch.Tensor)
else:
pplx_shared_output, pplx_output = pplx_outputs
if shared_output is not None:
assert pplx_shared_output is not None
chunked_shared_output = chunk_by_rank(
shared_output, pgi.rank,
pgi.world_size).to(pplx_shared_output.device)
else:
chunked_shared_output = None
chunked_batch_output = chunk_by_rank(
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
@@ -727,6 +759,15 @@ def _pplx_moe(
chunked_batch_output,
atol=3e-2,
rtol=3e-2)
if shared_experts is not None:
assert chunked_shared_output is not None
assert pplx_shared_output is not None
torch.testing.assert_close(pplx_shared_output,
chunked_shared_output,
atol=3e-2,
rtol=3e-2)
finally:
if use_internode:
nvshmem_finalize()
@@ -788,7 +829,8 @@ def test_pplx_moe_slow(
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
make_weights: bool, test_fn: Callable):
use_shared_experts: bool, make_weights: bool,
test_fn: Callable):
def format_result(msg, ex=None):
if ex is not None:
@@ -803,6 +845,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
else:
print(f"PASSED {msg}")
if use_shared_experts:
# Note: this config is only needed for the non-naive shared experts.
new_vllm_config = copy.deepcopy(vllm_config)
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
new_vllm_config.parallel_config.enable_expert_parallel = True
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
pgi.local_rank)
current_platform.seed_everything(7)
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
[False, True], [None, [128, 128]])
@@ -819,9 +869,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
use_fp8_w8a8 = False
quant_dtype = None
test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
f"block_shape={block_shape}")
test_desc = (
f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
f"block_shape={block_shape}, use_internode={use_internode}, "
f"use_shared_experts={use_shared_experts}")
if not use_fp8_w8a8 and (per_act_token_quant
or block_shape is not None):
@@ -852,6 +904,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
args["w1_s"] = w1_s
args["w2_s"] = w2_s
if use_shared_experts:
args["shared_experts"] = make_shared_experts(
n,
k,
in_dtype=a.dtype,
quant_dtype=quant_dtype,
)
try:
test_fn(
pgi=pgi,
@@ -891,18 +951,20 @@ def test_pplx_prepare_finalize(
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
use_internode, False, _pplx_prepare_finalize)
use_internode, False, False, _pplx_prepare_finalize)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.parametrize("use_shared_experts", [False, True])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe(
world_dp_size: tuple[int, int],
use_internode: bool,
use_shared_experts: bool,
):
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True,
_pplx_moe)
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
use_shared_experts, True, _pplx_moe)