[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
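For context on the title: shared experts operate on every token locally, so their GEMMs can run while the expert-parallel all-to-all send/recv is still in flight. The sketch below only illustrates that general pattern; `dispatch`, `routed_experts`, and `combine` are hypothetical callables, not vLLM APIs, and the real wiring lives in `FusedMoEModularKernel`.

```python
import torch

def overlapped_moe_forward(hidden: torch.Tensor,
                           shared_experts: torch.nn.Module,
                           dispatch, routed_experts, combine):
    """Illustrative only: overlap shared experts with the all-to-all."""
    side_stream = torch.cuda.Stream()
    # Start sending routed tokens to their experts on the current stream.
    handle = dispatch(hidden)
    with torch.cuda.stream(side_stream):
        # Shared experts see every token locally, so they need no
        # communication and can run while the send/recv is outstanding.
        shared_out = shared_experts(hidden)
    routed_out = combine(routed_experts(handle))
    # Make later work on the main stream wait for the shared-experts GEMMs.
    torch.cuda.current_stream().wait_stream(side_stream)
    return shared_out, routed_out
```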
@@ -4,10 +4,11 @@
 Run `pytest tests/kernels/test_pplx_moe.py`.
 """

+import copy
 import itertools
 import textwrap
 import traceback
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 import pytest
 import torch
@@ -21,7 +22,10 @@ try:
 except ImportError:
     has_pplx = False

-from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
+from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
+    _set_vllm_config)
+from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
+                                     naive_batched_moe)
 from tests.kernels.quant_utils import dequant
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig, set_current_vllm_config
@@ -511,7 +515,8 @@ def pplx_moe(
     block_shape: Optional[list[int]] = None,
     use_compile: bool = False,
     use_cudagraphs: bool = True,
-) -> torch.Tensor:
+    shared_experts: Optional[torch.nn.Module] = None,
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

     num_tokens, hidden_dim = a.shape
     num_experts = w1.shape[0]
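With the signature above, `pplx_moe` now returns either a bare tensor or a `(shared_output, fused_output)` pair when a `shared_experts` module is supplied. A small self-contained helper (hypothetical, not part of the test file) showing how a caller can normalize both shapes:

```python
from typing import Optional, Union

import torch

def split_moe_output(
    out: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
) -> tuple[Optional[torch.Tensor], torch.Tensor]:
    # Convention used in this change: (shared_output, fused_output) when
    # shared experts ran, otherwise just the fused output tensor.
    if isinstance(out, tuple):
        shared_out, fused_out = out
        return shared_out, fused_out
    return None, out

# Example: both call shapes normalize to the same structure.
fused = torch.randn(4, 8)
assert split_moe_output(fused)[0] is None
assert split_moe_output((fused.clone(), fused))[0] is not None
```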
@@ -546,6 +551,7 @@ def pplx_moe(
     fused_experts = FusedMoEModularKernel(
         prepare_finalize,
         experts,
+        shared_experts,
     )

     # Note: workers with the same dp_rank must use the exact same inputs.
@@ -586,7 +592,11 @@ def pplx_moe(
         global_num_experts=num_experts)

     if use_cudagraphs:
-        out.fill_(0)
+        if isinstance(out, tuple):
+            out[0].fill_(0)
+            out[1].fill_(0)
+        else:
+            out.fill_(0)
         stream = torch.cuda.Stream()
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
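The hunk above zeroes the output buffer(s) before CUDA graph capture so that a successful replay provably rewrites them, covering both the tensor and tuple return shapes. A minimal, generic capture/replay sketch of that idea (assumes a CUDA device; not the test's actual kernel launch):

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(16, 16, device="cuda")
    out = torch.empty_like(x)

    out.copy_(x @ x)               # eager warm-up run before capture
    out.fill_(0)                   # clear the buffer that capture will target
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # record the matmul into the graph
        out.copy_(x @ x)

    out.fill_(0)                   # stale values would mask a no-op replay
    graph.replay()                 # replay rewrites `out` with the captured kernels
    torch.cuda.synchronize()
    assert not torch.all(out == 0)
```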
@@ -626,6 +636,7 @@ def _pplx_moe(
     per_act_token_quant: bool = False,
     block_shape: Optional[list[int]] = None,
     use_internode: bool = False,
+    shared_experts: Optional[torch.nn.Module] = None,
 ):
     try:
         if use_internode:
@@ -666,6 +677,11 @@ def _pplx_moe(
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

+        if shared_experts is not None:
+            shared_output = shared_experts(a)
+        else:
+            shared_output = None
+
         torch_output = torch_experts(
             a,
             w1,
@@ -696,7 +712,7 @@ def _pplx_moe(
             block_shape=block_shape,
         )

-        pplx_output = pplx_moe(
+        pplx_outputs = pplx_moe(
             group_name,
             rank,
             world_size,
@@ -713,8 +729,24 @@ def _pplx_moe(
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
+            shared_experts=shared_experts,
         )

+        if shared_experts is None:
+            pplx_shared_output = None
+            pplx_output = pplx_outputs
+            assert isinstance(pplx_output, torch.Tensor)
+        else:
+            pplx_shared_output, pplx_output = pplx_outputs
+
+        if shared_output is not None:
+            assert pplx_shared_output is not None
+            chunked_shared_output = chunk_by_rank(
+                shared_output, pgi.rank,
+                pgi.world_size).to(pplx_shared_output.device)
+        else:
+            chunked_shared_output = None
+
         chunked_batch_output = chunk_by_rank(
             batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)

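The comparison above relies on `chunk_by_rank` to slice the full-size reference outputs down to this rank's tokens before `assert_close`. A minimal sketch of the slicing behavior assumed here (the real helper lives in the pplx test utilities and may differ in padding or uneven splits):

```python
import torch

def chunk_by_rank_sketch(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Assumed behavior: split dim 0 (tokens) evenly and return this rank's slice.
    chunk = t.shape[0] // world_size
    return t[rank * chunk:(rank + 1) * chunk]

ref = torch.arange(8).reshape(8, 1)
assert torch.equal(chunk_by_rank_sketch(ref, 1, 2), ref[4:8])
```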
@@ -727,6 +759,15 @@ def _pplx_moe(
                                    chunked_batch_output,
                                    atol=3e-2,
                                    rtol=3e-2)
+
+        if shared_experts is not None:
+            assert chunked_shared_output is not None
+            assert pplx_shared_output is not None
+            torch.testing.assert_close(pplx_shared_output,
+                                       chunked_shared_output,
+                                       atol=3e-2,
+                                       rtol=3e-2)
+
     finally:
         if use_internode:
             nvshmem_finalize()
@@ -788,7 +829,8 @@ def test_pplx_moe_slow(


 def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
-                    make_weights: bool, test_fn: Callable):
+                    use_shared_experts: bool, make_weights: bool,
+                    test_fn: Callable):

     def format_result(msg, ex=None):
         if ex is not None:
@@ -803,6 +845,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
         else:
             print(f"PASSED {msg}")

+    if use_shared_experts:
+        # Note: this config is only needed for the non-naive shared experts.
+        new_vllm_config = copy.deepcopy(vllm_config)
+        new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
+        new_vllm_config.parallel_config.enable_expert_parallel = True
+        _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
+                         pgi.local_rank)
+
     current_platform.seed_everything(7)
     combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
                                [False, True], [None, [128, 128]])
@@ -819,9 +869,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
             use_fp8_w8a8 = False
             quant_dtype = None

-        test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
-                     f"dtype={dtype}, per_act_token={per_act_token_quant}, "
-                     f"block_shape={block_shape}")
+        test_desc = (
+            f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
+            f"dtype={dtype}, per_act_token={per_act_token_quant}, "
+            f"block_shape={block_shape}, use_internode={use_internode}, "
+            f"use_shared_experts={use_shared_experts}")

         if not use_fp8_w8a8 and (per_act_token_quant
                                  or block_shape is not None):
@@ -852,6 +904,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
             args["w1_s"] = w1_s
             args["w2_s"] = w2_s

+        if use_shared_experts:
+            args["shared_experts"] = make_shared_experts(
+                n,
+                k,
+                in_dtype=a.dtype,
+                quant_dtype=quant_dtype,
+            )
+
         try:
             test_fn(
                 pgi=pgi,
@@ -891,18 +951,20 @@ def test_pplx_prepare_finalize(
     current_platform.seed_everything(7)
     world_size, dp_size = world_dp_size
     parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
-                    use_internode, False, _pplx_prepare_finalize)
+                    use_internode, False, False, _pplx_prepare_finalize)


 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @pytest.mark.parametrize("use_internode", [False])
+@pytest.mark.parametrize("use_shared_experts", [False, True])
 @requires_pplx
 @multi_gpu_test(num_gpus=2)
 def test_pplx_moe(
     world_dp_size: tuple[int, int],
     use_internode: bool,
+    use_shared_experts: bool,
 ):
     current_platform.seed_everything(7)
     world_size, dp_size = world_dp_size
-    parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True,
-                    _pplx_moe)
+    parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
+                    use_shared_experts, True, _pplx_moe)