[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """
-Test DeepEP + DeepGEMM integration
+Test DeepEP + DeepGEMM integration
 DeepGEMM are gemm kernels specialized for the
 fp8 block-quantized case.
 """
@@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm
 
-from .utils import ProcessGroupInfo, parallel_launch
+from .parallel_utils import ProcessGroupInfo, parallel_launch
+from .utils import make_test_weights
 
 if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
@@ -30,10 +29,9 @@ if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
         DeepEPLLPrepareAndFinalize)
 
-    from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
+    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
 if has_deep_gemm():
-    import deep_gemm
 
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts)
@@ -60,25 +58,6 @@ def next_power_of_2(x):
     return 2**math.ceil(math.log2(x))
-
-
-def per_block_cast_to_fp8(
-        x: torch.Tensor,
-        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros(
-        (deep_gemm.ceil_div(m, 128) * 128,
-         deep_gemm.ceil_div(n, block_size_n) * block_size_n),
-        dtype=x.dtype,
-        device=x.device)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-    return x_scaled_sub, scales
 
 
 def make_block_quant_fp8_weights(
     e: int,
     n: int,
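For reference, the deleted per_block_cast_to_fp8 helper quantizes each (128 x block_size_n) tile independently: the tile's absolute maximum is mapped to the float8_e4m3fn limit (448.0) and the inverse scale is kept for later dequantization. Below is a minimal plain-PyTorch sketch of that per-tile round trip, illustrative only and without the deep_gemm.ceil_div padding the removed code performed:

import torch

def roundtrip_block_fp8(tile: torch.Tensor) -> torch.Tensor:
    # Quantize one (128, 128) tile with a single scale derived from its amax,
    # mirroring the amax / 448.0 scaling used by the removed helper.
    assert tile.shape == (128, 128)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    amax = tile.abs().float().amax().clamp(min=1e-4)
    scale = amax / fp8_max
    quantized = (tile / scale).to(torch.float8_e4m3fn)  # quantize
    return quantized.to(torch.float32) * scale          # dequantize

tile = torch.randn(128, 128)
max_err = (roundtrip_block_fp8(tile) - tile).abs().max()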
@@ -86,43 +65,11 @@ def make_block_quant_fp8_weights(
     block_size: list[int],
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
+    Return weights w1q, w2q, w1_scale, w2_scale
     """
-    dtype = torch.bfloat16
-
-    fp8_info = torch.finfo(torch.float8_e4m3fn)
-    fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
-    w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
-    w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
-
-    w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
-    w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
-
-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
-    k_tiles_w1 = (k + block_k - 1) // block_k
-    n_tiles_w2 = (k + block_n - 1) // block_n
-    k_tiles_w2 = (n + block_k - 1) // block_k
-
-    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
-    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
-
-    w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
-                       device="cuda",
-                       dtype=torch.float32)
-    w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
-                       device="cuda",
-                       dtype=torch.float32)
-
-    assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
-    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
-
-    for i in range(e):
-        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
-        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
-
-    return w1, w2, w1_s, w2_s
+    w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
+        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
+    return w1q, w2q, w1_scale, w2_scale
 
 
 @dataclasses.dataclass
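The hand-rolled weight construction is now delegated to the shared make_test_weights helper, and the test-local function returns only the quantized weights and their block scales. The deleted asserts encoded the expected scale shapes: one fp32 scale per (block_n, block_k) tile of w1 (shape (e, 2n, k)) and w2 (shape (e, k, n)). A small self-contained sketch of that shape arithmetic, using hypothetical example sizes:

def expected_block_scale_shapes(e: int, n: int, k: int,
                                block_size: list[int]) -> tuple[tuple, tuple]:
    # One scale per (block_n, block_k) tile: w1 is (e, 2n, k), w2 is (e, k, n).
    block_n, block_k = block_size
    w1_scale = (e, (2 * n + block_n - 1) // block_n, (k + block_k - 1) // block_k)
    w2_scale = (e, (k + block_n - 1) // block_n, (n + block_k - 1) // block_k)
    return w1_scale, w2_scale

# 8 experts, n=128, k=256, 128x128 quantization blocks:
assert expected_block_scale_shapes(8, 128, 256, [128, 128]) == ((8, 2, 2), (8, 2, 1))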
@@ -132,6 +79,7 @@ class TestConfig:
     k: int
     n: int
     num_experts: int
+    per_act_token_quant: bool
     block_size: list[int]
     # configs for testing low-latency kernels
     low_latency: bool
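The new per_act_token_quant field selects per-token (row-wise) activation scales rather than a single shared scale; the DeepEP + DeepGEMM tests below always pass False. A rough plain-PyTorch illustration of the difference (not vLLM's actual quantization code):

import torch

def fp8_activation_scales(a: torch.Tensor, per_act_token_quant: bool) -> torch.Tensor:
    # a: (num_tokens, hidden_dim) activations.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    if per_act_token_quant:
        # One dynamic scale per token (row).
        amax = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    else:
        # A single scale shared by the whole tensor.
        amax = a.abs().amax().clamp(min=1e-4)
    return amax.float() / fp8_max

a = torch.randn(4, 16)
assert fp8_activation_scales(a, True).shape == (4, 1)
assert fp8_activation_scales(a, False).shape == ()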
@@ -150,8 +98,7 @@ class TestTensors:
     def make(config: TestConfig, rank) -> "TestTensors":
 
         dtype = torch.bfloat16
-        topk, m, k, block_size = (config.topk, config.m, config.k,
-                                  config.block_size)
+        topk, m, k = (config.topk, config.m, config.k)
 
         fp8_info = torch.finfo(torch.float8_e4m3fn)
         fp8_max, fp8_min = fp8_info.max, fp8_info.min
@@ -159,9 +106,7 @@ class TestTensors:
         rank_tokens = torch.randn(
             (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
         rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
-
-        block_k = block_size[1]
-        _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
+        rank_token_scales = None
 
         topk_ids = torch.randint(
             low=0,
@@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
         q_dtype=q_dtype,
         block_shape=test_config.block_size)
 
-    fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
-                                           world_size=pgi.world_size,
-                                           dp_size=dp_size,
-                                           block_shape=test_config.block_size)
+    fused_experts = BatchedDeepGemmExperts(
+        max_num_tokens=max_tokens_per_rank,
+        world_size=pgi.world_size,
+        dp_size=dp_size,
+        block_shape=test_config.block_size,
+        per_act_token_quant=test_config.per_act_token_quant)
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk
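FusedMoEModularKernel composes two pieces: a prepare/finalize object (here the DeepEP all-to-all built by make_deepep_a2a) that dispatches tokens to the ranks owning their experts and later combines the results, and a fused-experts object (here BatchedDeepGemmExperts) that runs the grouped expert GEMMs. A toy sketch of that control flow, with hypothetical method names rather than the real interfaces in vllm.model_executor.layers.fused_moe.modular_kernel:

class ModularMoEKernelSketch:
    """Toy illustration of the prepare -> experts -> finalize split."""

    def __init__(self, prepare_finalize, fused_experts):
        self.prepare_finalize = prepare_finalize  # e.g. a DeepEP all-to-all wrapper
        self.fused_experts = fused_experts        # e.g. batched DeepGEMM experts

    def forward(self, hidden_states, topk_ids, topk_weights):
        # 1) Quantize and dispatch tokens to the ranks that own their experts.
        dispatched = self.prepare_finalize.prepare(hidden_states, topk_ids)
        # 2) Run the batched/grouped expert GEMMs on the local expert shard.
        expert_out = self.fused_experts.apply(dispatched)
        # 3) Combine the expert outputs back into the original token order.
        return self.prepare_finalize.finalize(expert_out, topk_weights)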
@@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
     """
     Tests for High-Throughput DeepEP + DeepGemm integration.
     """
+    import deep_gemm
 
     m, n, k = mnk
     current_platform.seed_everything(7)
@@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                         k=k,
                         n=n,
                         num_experts=num_experts,
+                        per_act_token_quant=False,
                         block_size=block_size,
                         low_latency=False,
                         use_fp8_dispatch=None)
@@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False]
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
-                                           int], num_experts: int, topk: int,
-                                use_fp8_dispatch: bool, block_size: list[int],
-                                world_dp_size: tuple[int, int]):
+def test_ll_deepep_deepgemm_moe(
+    mnk: tuple[int, int, int],
+    num_experts: int,
+    topk: int,
+    use_fp8_dispatch: bool,
+    block_size: list[int],
+    world_dp_size: tuple[int, int],
+):
     """
     Tests for Low-Latency DeepEP + DeepGemm integration.
     """
@@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
                         k=k,
                         n=n,
                         num_experts=num_experts,
+                        per_act_token_quant=False,
                         block_size=block_size,
                         low_latency=True,
                         use_fp8_dispatch=use_fp8_dispatch,