[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
Author: bnellnm
Date: 2025-07-02 09:08:27 -04:00
Committed by: GitHub
parent b95877509b
commit c1909e7e8c
36 changed files with 2698 additions and 1584 deletions


@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep

-from .utils import ProcessGroupInfo, parallel_launch
+from .parallel_utils import ProcessGroupInfo, parallel_launch

 if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
@@ -31,7 +31,7 @@ if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
         DeepEPLLPrepareAndFinalize)
-    from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
+    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

 requires_deep_ep = pytest.mark.skipif(
     not has_deep_ep(),
@@ -102,10 +102,6 @@ class TestTensors:
         rank_tokens = torch.randn(
             (config.m, config.k), device="cuda", dtype=token_dtype) / 10

         rank_token_scales = None
-        if config.dtype == torch.float8_e4m3fn:
-            # low_latency_mode kernels dont support per-token quant.
-            _, rank_token_scales = ops.scaled_fp8_quant(
-                rank_tokens, use_per_token_if_dynamic=not low_latency_mode)

         topk = torch.randint(low=0,
                              high=config.num_experts,
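The four lines removed above pre-computed activation scales with `ops.scaled_fp8_quant`; after this refactor `rank_token_scales` stays `None` and quantization is left to the kernels. For readers unfamiliar with that op, a minimal sketch of its two dynamic-scaling modes (illustrative shapes; assumes a CUDA device and vLLM's custom ops):

```python
import torch
from vllm import _custom_ops as ops

tokens = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)

# One scale per token row -- what the removed code computed when
# low_latency_mode was off.
q_tok, per_token_scales = ops.scaled_fp8_quant(
    tokens, use_per_token_if_dynamic=True)

# A single per-tensor scale -- the low-latency fallback, since the
# low_latency_mode kernels don't support per-token quant.
q_all, per_tensor_scale = ops.scaled_fp8_quant(
    tokens, use_per_token_if_dynamic=False)
```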
@@ -121,11 +117,18 @@ class TestTensors:
                            config=config)


-def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                        low_latency_mode: bool, hidden_size: int, dp_size: int,
-                        num_experts: int, num_local_experts: int,
-                        q_dtype: Optional[torch.dtype],
-                        use_fp8_dispatch: bool) -> FusedMoEModularKernel:
+def make_modular_kernel(
+    pg: ProcessGroup,
+    pgi: ProcessGroupInfo,
+    low_latency_mode: bool,
+    hidden_size: int,
+    dp_size: int,
+    num_experts: int,
+    num_local_experts: int,
+    q_dtype: Optional[torch.dtype],
+    use_fp8_dispatch: bool,
+    per_act_token_quant: bool,
+) -> FusedMoEModularKernel:

     is_quantized = q_dtype is not None
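`per_act_token_quant` is the new knob threaded through this file: it asks the experts to compute one activation scale per token row instead of a single per-tensor scale. A rough pure-torch sketch of the distinction, assuming fp8 e4m3 (max representable value 448); the helper names are made up for illustration:

```python
import torch

FP8_E4M3_MAX = 448.0  # max representable value of torch.float8_e4m3fn

def per_token_scales(a: torch.Tensor) -> torch.Tensor:
    # per_act_token_quant=True: amax over the hidden dim -> [num_tokens, 1]
    return a.abs().amax(dim=-1, keepdim=True).float() / FP8_E4M3_MAX

def per_tensor_scale(a: torch.Tensor) -> torch.Tensor:
    # per_act_token_quant=False: one global amax -> a single scale
    return (a.abs().amax().float() / FP8_E4M3_MAX).reshape(1)
```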
@@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                                       deepep_ll_args=ll_args)

     if low_latency_mode:
+        assert not per_act_token_quant, "not supported in ll mode"
         fused_experts = BatchedTritonExperts(
             max_num_tokens=MAX_TOKENS_PER_RANK,
             world_size=pgi.world_size,
@@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
             use_fp8_w8a8=is_quantized,
             use_int8_w8a8=False,
             use_int8_w8a16=False,
-            use_int4_w4a16=False)
+            use_int4_w4a16=False,
+            per_act_token_quant=False,
+        )
     else:
-        fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
-                                      use_int8_w8a8=False,
-                                      use_int8_w8a16=False,
-                                      use_int4_w4a16=False,
-                                      per_channel_quant=False)
+        fused_experts = TritonExperts(
+            use_fp8_w8a8=is_quantized,
+            use_int8_w8a8=False,
+            use_int8_w8a16=False,
+            use_int4_w4a16=False,
+            per_act_token_quant=per_act_token_quant,
+        )

     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)

     return mk


-def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                     low_latency_mode: bool, dp_size: int,
-                     test_tensors: TestTensors, w1: torch.Tensor,
-                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
-                     w2_scale: Optional[torch.Tensor], num_experts: int,
-                     use_fp8_dispatch: bool) -> torch.Tensor:
+def deep_ep_moe_impl(
+    pg: ProcessGroup,
+    pgi: ProcessGroupInfo,
+    low_latency_mode: bool,
+    dp_size: int,
+    test_tensors: TestTensors,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: Optional[torch.Tensor],
+    w2_scale: Optional[torch.Tensor],
+    num_experts: int,
+    use_fp8_dispatch: bool,
+    per_act_token_quant: bool,
+) -> torch.Tensor:

     num_local_experts = w1.size(0)
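For orientation: `make_modular_kernel` pairs a DeepEP all-to-all prepare/finalize object with a fused-experts implementation and wraps both in a `FusedMoEModularKernel`. A simplified sketch of the control flow that composition implies (hypothetical class, not the real implementation in `vllm/model_executor/layers/fused_moe`):

```python
class ModularKernelSketch:
    """Sketch of the prepare -> experts -> finalize pipeline."""

    def __init__(self, prepare_finalize, fused_experts):
        self.prepare_finalize = prepare_finalize  # e.g. DeepEP HT/LL all-to-all
        self.fused_experts = fused_experts        # e.g. TritonExperts

    def __call__(self, hidden_states, w1, w2, topk_weights, topk_ids):
        # 1. Quantize activations and dispatch each token to the ranks
        #    owning its routed experts.
        dispatched = self.prepare_finalize.prepare(hidden_states, topk_ids)
        # 2. Run the local experts on the dispatched tokens.
        expert_out = self.fused_experts(dispatched, w1, w2)
        # 3. Combine expert outputs back to the source ranks, applying
        #    the top-k routing weights.
        return self.prepare_finalize.finalize(expert_out, topk_weights)
```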
@@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
         q_dtype = torch.float8_e4m3fn

     # Make modular kernel
-    mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
-                                                    hidden_size, dp_size,
-                                                    num_experts,
-                                                    num_local_experts, q_dtype,
-                                                    use_fp8_dispatch)
+    mk: FusedMoEModularKernel = make_modular_kernel(
+        pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
+        num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)

     out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
     total_num_tokens = test_tensors.rank_tokens.size(0)
@@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
     return out_hidden_states


-def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
-                   w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
-                   w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
+def torch_moe_impl(
+    test_tensors: TestTensors,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: Optional[torch.Tensor],
+    w2_scale: Optional[torch.Tensor],
+    using_fp8_dispatch: bool,
+    per_act_token_quant: bool,
+):

     a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
                                  test_tensors.topk_weights)
@@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
         # The DeepEP implementation is requested to dispatch using FP8.
         # For numerical stability for testing, emulate the fp8 dispatch by
         # blockwise quant and de-quant.
+        assert not per_act_token_quant
        a = test_tensors.rank_tokens
         aq, aq_scale = per_token_group_quant_fp8(a, 128)
         a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
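The reference path above emulates fp8 dispatch by quantizing activations in groups of 128 and immediately dequantizing, so the torch baseline sees the same rounding the real dispatch applies. A self-contained sketch of that round trip (illustrative only; the test uses vLLM's `per_token_group_quant_fp8`, and this assumes the hidden size is a multiple of the group size):

```python
import torch

def fp8_group_roundtrip(a: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    g = a.view(-1, group_size)
    # One scale per group; the clamp guards against all-zero groups.
    scale = (g.abs().amax(dim=-1, keepdim=True).float() / 448.0).clamp(min=1e-12)
    q = (g.float() / scale).to(torch.float8_e4m3fn)    # quantize
    return (q.to(torch.float32) * scale).view_as(a)    # dequantize
```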
@@ -310,6 +331,7 @@ def _deep_ep_moe(
     w1_scale: Optional[torch.Tensor],
     w2_scale: Optional[torch.Tensor],
     use_fp8_dispatch: bool,
+    per_act_token_quant: bool,
 ):

     if not low_latency_mode:
@@ -331,7 +353,8 @@ def _deep_ep_moe(
     with set_current_vllm_config(VllmConfig()):
         # Reference
         torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
-                                        w2_scale, use_fp8_dispatch)
+                                        w2_scale, use_fp8_dispatch,
+                                        per_act_token_quant)

         # Splice experts for this rank.
         num_local_experts = config.num_experts // pgi.world_size
@@ -356,6 +379,7 @@ def _deep_ep_moe(
             w2_scale_ep,
             config.num_experts,
             use_fp8_dispatch,
+            per_act_token_quant,
         )

         torch.testing.assert_close(
@@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
 @pytest.mark.parametrize("num_experts", [32])
 @pytest.mark.parametrize("topk", [6])
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
 @requires_deep_ep
-def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
-                     num_experts: int, topk: int, world_dp_size: tuple[int,
-                                                                       int]):
+def test_deep_ep_moe(
+    dtype: torch.dtype,
+    mnk: tuple[int, int, int],
+    num_experts: int,
+    topk: int,
+    world_dp_size: tuple[int, int],
+    per_act_token_quant: bool,
+):
     low_latency_mode = False
     use_fp8_dispatch = False
     m, n, k = mnk
@@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
     w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

     parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
-                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
+                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
+                    per_act_token_quant)


 MNKs = [
@@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
     w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

     parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
-                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
+                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
+                    False)
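The hard-coded `False` is deliberate: only the high-throughput test exercises per-act-token quant, because the low-latency branch of `make_modular_kernel` guards against it (quoted from the hunk above):

```python
if low_latency_mode:
    assert not per_act_token_quant, "not supported in ll mode"
```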