[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@@ -18,18 +18,20 @@ try:
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
|
||||
get_default_config)
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
from .utils import ProcessGroupInfo, parallel_launch
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
requires_pplx = pytest.mark.skipif(
|
||||
not has_pplx,
|
||||
@@ -144,25 +146,6 @@ def torch_batched_moe(
|
||||
return torch_finalize(out, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1,
|
||||
rank=0),
|
||||
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||
@@ -188,7 +171,7 @@ def test_fused_moe_batched_experts(
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
torch.testing.assert_close(baseline_output,
|
||||
torch_output,
|
||||
@@ -226,7 +209,6 @@ def pplx_prepare_finalize(
|
||||
|
||||
topk = topk_ids.shape[1]
|
||||
num_tokens, hidden_dim = a.shape
|
||||
block_size = 128
|
||||
device = pgi.device
|
||||
rank = pgi.rank
|
||||
world_size = pgi.world_size
|
||||
@@ -241,9 +223,7 @@ def pplx_prepare_finalize(
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
||||
((hidden_dim + block_size - 1) // block_size *
|
||||
torch.float32.itemsize)),
|
||||
hidden_dim_scale_bytes=0,
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
@@ -260,7 +240,6 @@ def pplx_prepare_finalize(
|
||||
world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
a.dtype,
|
||||
)
|
||||
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
@@ -276,6 +255,7 @@ def pplx_prepare_finalize(
|
||||
num_experts,
|
||||
None,
|
||||
False,
|
||||
FusedMoEQuantConfig(),
|
||||
)
|
||||
|
||||
b_a = b_a * 1.5
|
||||
@@ -350,6 +330,7 @@ def _pplx_prepare_finalize(
|
||||
# TODO (bnell): this test point does not work for odd M due to how the test is
|
||||
# written, not due to limitations of the pplx kernels. The pplx_moe
|
||||
# test below is able to deal with odd M.
|
||||
# TODO (bnell) add fp8 tests
|
||||
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@@ -386,18 +367,31 @@ def pplx_moe(
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
qtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_compile: bool = False,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
|
||||
|
||||
device = torch.device("cuda", rank)
|
||||
hidden_dim = a.shape[1]
|
||||
num_experts = w1.shape[0]
|
||||
block_size = 128
|
||||
topk = topk_ids.shape[1]
|
||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
||||
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
|
||||
|
||||
hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||
max_num_tokens,
|
||||
hidden_dim,
|
||||
a.dtype,
|
||||
qtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
args = dict(
|
||||
max_num_tokens=max_num_tokens,
|
||||
@@ -407,10 +401,8 @@ def pplx_moe(
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
||||
((hidden_dim + block_size - 1) // block_size *
|
||||
torch.float32.itemsize)),
|
||||
hidden_dim_bytes=hidden_dim_bytes,
|
||||
hidden_dim_scale_bytes=scale_bytes,
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
@@ -429,9 +421,11 @@ def pplx_moe(
|
||||
dp_size,
|
||||
)
|
||||
|
||||
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
|
||||
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size)
|
||||
dp_size=dp_size,
|
||||
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@@ -447,6 +441,13 @@ def pplx_moe(
|
||||
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
||||
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
||||
|
||||
if w1_scale is not None:
|
||||
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
|
||||
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
|
||||
else:
|
||||
w1_scale_chunk = None
|
||||
w2_scale_chunk = None
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
@@ -465,6 +466,8 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if use_cudagraphs:
|
||||
@@ -477,6 +480,8 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
experts = BatchedExperts(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1)
|
||||
experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@@ -539,7 +544,12 @@ def _pplx_moe(
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
use_internode: bool,
|
||||
w1_s: Optional[torch.Tensor] = None,
|
||||
w2_s: Optional[torch.Tensor] = None,
|
||||
qtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_internode: bool = False,
|
||||
):
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
@@ -557,11 +567,28 @@ def _pplx_moe(
|
||||
|
||||
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
||||
|
||||
device = torch.device("cuda", pgi.rank)
|
||||
a = a.to(device)
|
||||
w1 = w1.to(device)
|
||||
w2 = w2.to(device)
|
||||
w1_s = w1_s.to(device) if w1_s is not None else None
|
||||
w2_s = w2_s.to(device) if w2_s is not None else None
|
||||
|
||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
torch_output = torch_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape)
|
||||
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
||||
a, w1, w2, topk_weight, topk_ids)
|
||||
a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
|
||||
qtype, per_act_token_quant, block_shape)
|
||||
# TODO (bnell): fix + re-enable
|
||||
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
||||
# topk_ids)
|
||||
@@ -581,6 +608,8 @@ def _pplx_moe(
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@requires_pplx
|
||||
def test_pplx_moe(
|
||||
@@ -589,15 +618,33 @@ def test_pplx_moe(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
use_internode: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
m, n, k = mnk
|
||||
world_size, dp_size = world_dp_size
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
use_fp8_w8a8 = True
|
||||
quant_dtype = dtype
|
||||
else:
|
||||
use_fp8_w8a8 = False
|
||||
quant_dtype = None
|
||||
|
||||
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
|
||||
pytest.skip("Skip quantization test for non-quantized type")
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
|
||||
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
|
||||
use_internode)
|
||||
|
||||
Reference in New Issue
Block a user