[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep
 
-from .utils import ProcessGroupInfo, parallel_launch
+from .parallel_utils import ProcessGroupInfo, parallel_launch
 
 if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
@@ -31,7 +31,7 @@ if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
         DeepEPLLPrepareAndFinalize)
 
-    from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
+    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
 requires_deep_ep = pytest.mark.skipif(
     not has_deep_ep(),
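For readers unfamiliar with the gating above: requires_deep_ep is a standard pytest skip marker built from an availability probe, so the whole module's tests are skipped on machines without DeepEP. A minimal self-contained sketch of the pattern (the probe body and reason string here are assumptions, not vLLM's exact implementation):

    import importlib.util

    import pytest


    def has_deep_ep() -> bool:
        # Probe for the optional deep_ep package without importing it.
        return importlib.util.find_spec("deep_ep") is not None


    requires_deep_ep = pytest.mark.skipif(
        not has_deep_ep(),
        reason="Requires deep_ep kernels",  # assumed reason string
    )


    @requires_deep_ep
    def test_needs_deep_ep():
        # Runs only where deep_ep is installed; skipped elsewhere.
        import deep_ep  # noqa: F401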
@@ -102,10 +102,6 @@ class TestTensors:
         rank_tokens = torch.randn(
             (config.m, config.k), device="cuda", dtype=token_dtype) / 10
         rank_token_scales = None
-        if config.dtype == torch.float8_e4m3fn:
-            # low_latency_mode kernels dont support per-token quant.
-            _, rank_token_scales = ops.scaled_fp8_quant(
-                rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
 
         topk = torch.randint(low=0,
                              high=config.num_experts,
@@ -121,11 +117,18 @@ class TestTensors:
             config=config)
 
 
-def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                        low_latency_mode: bool, hidden_size: int, dp_size: int,
-                        num_experts: int, num_local_experts: int,
-                        q_dtype: Optional[torch.dtype],
-                        use_fp8_dispatch: bool) -> FusedMoEModularKernel:
+def make_modular_kernel(
+    pg: ProcessGroup,
+    pgi: ProcessGroupInfo,
+    low_latency_mode: bool,
+    hidden_size: int,
+    dp_size: int,
+    num_experts: int,
+    num_local_experts: int,
+    q_dtype: Optional[torch.dtype],
+    use_fp8_dispatch: bool,
+    per_act_token_quant: bool,
+) -> FusedMoEModularKernel:
 
     is_quantized = q_dtype is not None
 
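The new per_act_token_quant parameter threaded through this signature selects between one quantization scale for the whole activation tensor and one scale per token row. A minimal sketch of the distinction for fp8 e4m3 activations (illustrative only; the real kernels quantize inside the fused path):

    import torch


    def quant_fp8(a: torch.Tensor, per_act_token_quant: bool):
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        if per_act_token_quant:
            # One scale per token: reduce absolute max over the hidden dim.
            scale = a.abs().amax(dim=-1, keepdim=True).float() / fp8_max
        else:
            # One scale shared by the entire activation tensor.
            scale = a.abs().amax().float() / fp8_max
        scale = scale.clamp(min=1e-12)  # guard against all-zero inputs
        aq = (a / scale).to(torch.float8_e4m3fn)
        return aq, scale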
@@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                               deepep_ll_args=ll_args)
 
     if low_latency_mode:
+        assert not per_act_token_quant, "not supported in ll mode"
         fused_experts = BatchedTritonExperts(
             max_num_tokens=MAX_TOKENS_PER_RANK,
             world_size=pgi.world_size,
@@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
             use_fp8_w8a8=is_quantized,
             use_int8_w8a8=False,
             use_int8_w8a16=False,
-            use_int4_w4a16=False)
+            use_int4_w4a16=False,
+            per_act_token_quant=False,
+        )
     else:
-        fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
-                                      use_int8_w8a8=False,
-                                      use_int8_w8a16=False,
-                                      use_int4_w4a16=False,
-                                      per_channel_quant=False)
+        fused_experts = TritonExperts(
+            use_fp8_w8a8=is_quantized,
+            use_int8_w8a8=False,
+            use_int8_w8a16=False,
+            use_int4_w4a16=False,
+            per_act_token_quant=per_act_token_quant,
+        )
 
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk
 
 
-def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                     low_latency_mode: bool, dp_size: int,
-                     test_tensors: TestTensors, w1: torch.Tensor,
-                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
-                     w2_scale: Optional[torch.Tensor], num_experts: int,
-                     use_fp8_dispatch: bool) -> torch.Tensor:
+def deep_ep_moe_impl(
+    pg: ProcessGroup,
+    pgi: ProcessGroupInfo,
+    low_latency_mode: bool,
+    dp_size: int,
+    test_tensors: TestTensors,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: Optional[torch.Tensor],
+    w2_scale: Optional[torch.Tensor],
+    num_experts: int,
+    use_fp8_dispatch: bool,
+    per_act_token_quant: bool,
+) -> torch.Tensor:
 
     num_local_experts = w1.size(0)
 
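The construction above follows the modular-kernel split: a prepare/finalize object handles all-to-all token dispatch and combine, while the experts object handles the grouped expert GEMMs. A rough sketch of the composition (class and method names are illustrative, not vLLM's exact interfaces):

    import torch


    class PrepareFinalize:
        # Dispatch tokens to the ranks owning their experts; later combine.
        def prepare(self, a, topk_ids, topk_weights):
            ...

        def finalize(self, out, expert_out, topk_ids, topk_weights):
            ...


    class Experts:
        # Grouped expert GEMMs (w1, activation, w2) on dispatched tokens.
        def apply(self, a, w1, w2):
            ...


    class ModularKernel:
        def __init__(self, prepare_finalize, fused_experts):
            self.prepare_finalize = prepare_finalize
            self.fused_experts = fused_experts

        def forward(self, a, w1, w2, topk_ids, topk_weights):
            dispatched = self.prepare_finalize.prepare(a, topk_ids,
                                                       topk_weights)
            expert_out = self.fused_experts.apply(dispatched, w1, w2)
            out = torch.empty_like(a)
            return self.prepare_finalize.finalize(out, expert_out, topk_ids,
                                                  topk_weights)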
@@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
         q_dtype = torch.float8_e4m3fn
 
     # Make modular kernel
-    mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
-                                                    hidden_size, dp_size,
-                                                    num_experts,
-                                                    num_local_experts, q_dtype,
-                                                    use_fp8_dispatch)
+    mk: FusedMoEModularKernel = make_modular_kernel(
+        pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
+        num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
 
     out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
     total_num_tokens = test_tensors.rank_tokens.size(0)
@@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
     return out_hidden_states
 
 
-def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
-                   w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
-                   w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
+def torch_moe_impl(
+    test_tensors: TestTensors,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: Optional[torch.Tensor],
+    w2_scale: Optional[torch.Tensor],
+    using_fp8_dispatch: bool,
+    per_act_token_quant: bool,
+):
 
     a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
                                  test_tensors.topk_weights)
@@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
         # The DeepEP implementation is requested to dispatch using FP8.
         # For numerical stability for testing, emulate the fp8 dispatch by
         # blockwise quant and de-quant.
+        assert not per_act_token_quant
         a = test_tensors.rank_tokens
         aq, aq_scale = per_token_group_quant_fp8(a, 128)
         a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
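The new assert guards the fp8-dispatch emulation path, which is blockwise (group size 128), not per-token. A minimal sketch of the quant/de-quant roundtrip the reference path performs; groupwise_quant_fp8 here is an illustrative stand-in for per_token_group_quant_fp8, not its actual implementation:

    import torch


    def groupwise_quant_fp8(a: torch.Tensor, group: int = 128):
        # One scale per contiguous group of `group` elements along dim -1.
        m, k = a.shape
        g = a.view(m, k // group, group)
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = g.abs().amax(dim=-1, keepdim=True).float() / fp8_max
        scale = scale.clamp(min=1e-12)  # guard against all-zero groups
        aq = (g / scale).to(torch.float8_e4m3fn).view(m, k)
        return aq, scale.view(m, k // group)


    def dequant(aq: torch.Tensor, scale: torch.Tensor, group: int = 128):
        # Invert the quantization; the residual rounding error is what
        # emulates the precision loss of a real fp8 dispatch.
        m, k = aq.shape
        g = aq.view(m, k // group, group).to(torch.float32)
        return (g * scale.view(m, k // group, 1)).view(m, k)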
@@ -310,6 +331,7 @@ def _deep_ep_moe(
     w1_scale: Optional[torch.Tensor],
     w2_scale: Optional[torch.Tensor],
     use_fp8_dispatch: bool,
+    per_act_token_quant: bool,
 ):
 
     if not low_latency_mode:
@@ -331,7 +353,8 @@ def _deep_ep_moe(
     with set_current_vllm_config(VllmConfig()):
         # Reference
         torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
-                                        w2_scale, use_fp8_dispatch)
+                                        w2_scale, use_fp8_dispatch,
+                                        per_act_token_quant)
 
     # Splice experts for this rank.
     num_local_experts = config.num_experts // pgi.world_size
@@ -356,6 +379,7 @@ def _deep_ep_moe(
         w2_scale_ep,
         config.num_experts,
         use_fp8_dispatch,
+        per_act_token_quant,
     )
 
     torch.testing.assert_close(
@@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
 @pytest.mark.parametrize("num_experts", [32])
 @pytest.mark.parametrize("topk", [6])
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
 @requires_deep_ep
-def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
-                     num_experts: int, topk: int, world_dp_size: tuple[int,
-                                                                       int]):
+def test_deep_ep_moe(
+    dtype: torch.dtype,
+    mnk: tuple[int, int, int],
+    num_experts: int,
+    topk: int,
+    world_dp_size: tuple[int, int],
+    per_act_token_quant: bool,
+):
     low_latency_mode = False
     use_fp8_dispatch = False
     m, n, k = mnk
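Note that stacked pytest.mark.parametrize decorators take a Cartesian product, so the added per_act_token_quant axis doubles the generated test matrix. A tiny illustration of the mechanism:

    import pytest


    @pytest.mark.parametrize("dtype", ["bf16", "fp8"])
    @pytest.mark.parametrize("per_act_token_quant", [False, True])
    def test_matrix(dtype, per_act_token_quant):
        # pytest generates 2 x 2 = 4 cases from the two decorators.
        assert dtype in ("bf16", "fp8")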
@@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
     w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
 
     parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
-                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
+                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
+                    per_act_token_quant)
 
 
 MNKs = [
@@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
     w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
 
     parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
-                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
+                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
+                    False)