[Kernel] Integrate batched/masked deepgemm kernel (#19111)

Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-06-04 17:59:18 -04:00
committed by GitHub
parent ef3f98b59f
commit c3fd4d669a
6 changed files with 472 additions and 51 deletions

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
"""
import dataclasses
@@ -33,10 +35,14 @@ except ImportError:
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
@@ -53,6 +59,13 @@ requires_deep_gemm = pytest.mark.skipif(
P = ParamSpec("P")
def next_power_of_2(x):
import math
if x == 0:
return 1
return 2**math.ceil(math.log2(x))
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
@@ -126,6 +139,9 @@ class TestConfig:
n: int
num_experts: int
block_size: list[int]
# configs for testing low-latency kernels
low_latency: bool
use_fp8_dispatch: Optional[bool] = False
@dataclasses.dataclass
@@ -170,9 +186,43 @@ class TestTensors:
config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, q_dtype: Optional[torch.dtype],
block_shape: list[int]) -> FusedMoEModularKernel:
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
max_tokens_per_rank: int, dp_size: int,
hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:
assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None
a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
pg=pg,
pgi=pgi,
dp_size=dp_size,
deepep_ht_args=None,
deepep_ll_args=DeepEPLLArgs(
max_tokens_per_rank=max_tokens_per_rank,
hidden_size=hidden_size,
num_experts=test_config.num_experts,
use_fp8_dispatch=test_config.use_fp8_dispatch),
q_dtype=q_dtype,
block_shape=test_config.block_size)
fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
world_size=pgi.world_size,
dp_size=dp_size,
block_shape=test_config.block_size)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, num_local_experts: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:
assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None
a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
pg=pg,
@@ -181,7 +231,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
deepep_ll_args=None,
q_dtype=q_dtype,
block_shape=block_shape)
block_shape=test_config.block_size)
fused_experts = DeepGemmExperts()
mk = FusedMoEModularKernel(prepare_finalize=a2a,
@@ -189,12 +239,42 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
return mk
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
num_experts: int) -> torch.Tensor:
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int,
test_tensors: TestTensors) -> FusedMoEModularKernel:
q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config
mk: FusedMoEModularKernel
# Make modular kernel
if test_config.low_latency:
max_tokens_per_rank = max(
64, next_power_of_2(test_tensors.rank_tokens.size(0)))
hidden_size = test_tensors.rank_tokens.size(-1)
mk = make_ll_modular_kernel(pg=pg,
pgi=pgi,
max_tokens_per_rank=max_tokens_per_rank,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config)
else:
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
q_dtype, test_config)
return mk
def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, test_tensors: TestTensors,
w1: torch.Tensor, w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor]) -> torch.Tensor:
test_config = test_tensors.config
num_experts = test_config.num_experts
num_local_experts = w1.size(0)
def build_expert_map():
@@ -208,14 +288,17 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, dp_size, num_local_experts, q_dtype,
test_tensors.config.block_size)
pg=pg,
pgi=pgi,
dp_size=dp_size,
num_local_experts=num_local_experts,
test_tensors=test_tensors)
a1_scale = test_tensors.rank_token_scales
# Low-Latency kernels can't dispatch scales.
a1_scale = (None
if test_config.low_latency else test_tensors.rank_token_scales)
out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1,
@@ -258,7 +341,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
allow_deep_gemm=False)
def _deep_ep_moe(
def _test_deepep_deepgemm_moe(
pgi: ProcessGroupInfo,
dp_size: int,
config: TestConfig,
@@ -302,7 +385,7 @@ def _deep_ep_moe(
w1_scale_ep = w1_scale[e_start:e_end]
w2_scale_ep = w2_scale[e_start:e_end]
deepep_moe = deep_ep_moe_impl(
deepep_moe = deepep_deepgemm_moe_impl(
pg,
pgi,
dp_size,
@@ -311,7 +394,6 @@ def _deep_ep_moe(
w2_ep,
w1_scale_ep,
w2_scale_ep,
config.num_experts,
)
torch.testing.assert_close(
@@ -335,15 +417,21 @@ MNKs = [
(222, 1024, 2048),
]
TOPKS = [2, 6]
NUM_EXPERTS = [32]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
world_dp_size: tuple[int, int]):
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
m, n, k = mnk
current_platform.seed_everything(7)
@@ -354,6 +442,58 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
world_size, dp_size = world_dp_size
config = TestConfig(topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)
MNKs = [
(1, 128, 2560),
(2, 128, 2560),
(3, 1024, 2560),
(32, 128, 2560),
(45, 512, 2560),
(64, 1024, 2560),
(222, 1024, 2560),
]
# Fix tests for USE_FP8_DISPATCH=True
USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
int], num_experts: int, topk: int,
use_fp8_dispatch: bool, block_size: list[int],
world_dp_size: tuple[int, int]):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
"""
m, n, k = mnk
current_platform.seed_everything(7)
if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
world_size, dp_size = world_dp_size
config = TestConfig(
topk=topk,
@@ -362,10 +502,12 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
n=n,
num_experts=num_experts,
block_size=block_size,
low_latency=True,
use_fp8_dispatch=use_fp8_dispatch,
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)
parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
w1_scale, w2_scale)
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)