Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -16,10 +16,11 @@ from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
|
||||
@@ -30,18 +31,19 @@ from .utils import make_test_weights
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
if has_deep_gemm():
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts)
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
|
||||
requires_deep_ep = pytest.mark.skipif(
|
||||
not has_deep_ep(),
|
||||
@@ -58,9 +60,10 @@ P = ParamSpec("P")
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
if x == 0:
|
||||
return 1
|
||||
return 2**math.ceil(math.log2(x))
|
||||
return 2 ** math.ceil(math.log2(x))
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
@@ -72,13 +75,9 @@ def make_block_quant_fp8_weights(
|
||||
"""
|
||||
Return weights w1q, w2q, w1_scale, w2_scale
|
||||
"""
|
||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
|
||||
_) = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
block_shape=block_size)
|
||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights(
|
||||
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
|
||||
)
|
||||
return w1q, w2q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@@ -106,15 +105,15 @@ class TestTensors:
|
||||
|
||||
@staticmethod
|
||||
def make(config: TestConfig, rank) -> "TestTensors":
|
||||
|
||||
dtype = torch.bfloat16
|
||||
topk, m, k = (config.topk, config.m, config.k)
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
rank_tokens = torch.randn(
|
||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
rank_tokens = (
|
||||
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
)
|
||||
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
||||
rank_token_scales = None
|
||||
|
||||
@@ -122,25 +121,32 @@ class TestTensors:
|
||||
low=0,
|
||||
high=config.num_experts,
|
||||
size=(m, topk),
|
||||
device=torch.cuda.current_device()).to(dtype=torch.int64)
|
||||
device=torch.cuda.current_device(),
|
||||
).to(dtype=torch.int64)
|
||||
|
||||
topk_weights = torch.randn(topk_ids.shape,
|
||||
dtype=torch.float32,
|
||||
device=torch.cuda.current_device())
|
||||
topk_weights = torch.randn(
|
||||
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
return TestTensors(rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
config=config)
|
||||
return TestTensors(
|
||||
rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def make_ll_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
|
||||
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
max_tokens_per_rank: int,
|
||||
dp_size: int,
|
||||
hidden_size: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
assert test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is not None
|
||||
|
||||
@@ -153,26 +159,30 @@ def make_ll_modular_kernel(
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=test_config.num_experts,
|
||||
use_fp8_dispatch=test_config.use_fp8_dispatch),
|
||||
use_fp8_dispatch=test_config.use_fp8_dispatch,
|
||||
),
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size)
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
fused_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_tokens_per_rank,
|
||||
num_dispatchers=pgi.world_size // dp_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_ht_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
assert not test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is None
|
||||
|
||||
@@ -183,76 +193,82 @@ def make_ht_modular_kernel(
|
||||
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
|
||||
deepep_ll_args=None,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size)
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
fused_experts = DeepGemmExperts(quant_config)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
num_local_experts: int,
|
||||
test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
test_config = test_tensors.config
|
||||
|
||||
mk: FusedMoEModularKernel
|
||||
# Make modular kernel
|
||||
if test_config.low_latency:
|
||||
max_tokens_per_rank = max(
|
||||
64, next_power_of_2(test_tensors.rank_tokens.size(0)))
|
||||
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
|
||||
hidden_size = test_tensors.rank_tokens.size(-1)
|
||||
|
||||
mk = make_ll_modular_kernel(pg=pg,
|
||||
pgi=pgi,
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
dp_size=dp_size,
|
||||
hidden_size=hidden_size,
|
||||
q_dtype=q_dtype,
|
||||
test_config=test_config,
|
||||
quant_config=quant_config)
|
||||
mk = make_ll_modular_kernel(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
dp_size=dp_size,
|
||||
hidden_size=hidden_size,
|
||||
q_dtype=q_dtype,
|
||||
test_config=test_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
mk = make_ht_modular_kernel(pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
test_config,
|
||||
quant_config=quant_config)
|
||||
mk = make_ht_modular_kernel(
|
||||
pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
test_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
return mk
|
||||
|
||||
|
||||
def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
dp_size: int, test_tensors: TestTensors,
|
||||
w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
|
||||
def deepep_deepgemm_moe_impl(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
test_config = test_tensors.config
|
||||
num_experts = test_config.num_experts
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
def build_expert_map():
|
||||
num_local_experts = w1.size(0)
|
||||
expert_map = torch.full((num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
# Low-Latency kernels can't dispatch scales.
|
||||
a1_scale=(None if test_config.low_latency else
|
||||
test_tensors.rank_token_scales),
|
||||
a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
@@ -263,26 +279,35 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
dp_size=dp_size,
|
||||
num_local_experts=num_local_experts,
|
||||
test_tensors=test_tensors,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False)
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor, block_shape: list[int]):
|
||||
|
||||
def triton_impl(
|
||||
a: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor,
|
||||
block_shape: list[int],
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -300,7 +325,8 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
quant_config=quant_config,
|
||||
# Make sure this is set to False so we
|
||||
# don't end up comparing the same implementation.
|
||||
allow_deep_gemm=False)
|
||||
allow_deep_gemm=False,
|
||||
)
|
||||
|
||||
|
||||
def _test_deepep_deepgemm_moe(
|
||||
@@ -321,22 +347,21 @@ def _test_deepep_deepgemm_moe(
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, pgi.rank)
|
||||
block_shape = [
|
||||
w1.size(1) // w1_scale.size(1),
|
||||
w1.size(2) // w1_scale.size(2)
|
||||
]
|
||||
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
triton_moe = triton_impl(a=test_tensors.rank_tokens,
|
||||
topk_ids=test_tensors.topk,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=test_tensors.rank_token_scales,
|
||||
block_shape=block_shape)
|
||||
triton_moe = triton_impl(
|
||||
a=test_tensors.rank_tokens,
|
||||
topk_ids=test_tensors.topk,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=test_tensors.rank_token_scales,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Slice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
@@ -390,10 +415,15 @@ NUM_EXPERTS = [32]
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
topk: int, world_dp_size: tuple[int, int]):
|
||||
@pytest.mark.skipif(
|
||||
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
|
||||
)
|
||||
def test_ht_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
):
|
||||
"""
|
||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||
"""
|
||||
@@ -409,21 +439,32 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
block_size = [block_m, block_m]
|
||||
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=False,
|
||||
use_fp8_dispatch=None)
|
||||
config = TestConfig(
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=False,
|
||||
use_fp8_dispatch=None,
|
||||
)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
num_experts, n, k, block_size)
|
||||
num_experts, n, k, block_size
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
|
||||
w2, w1_scale, w2_scale)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_test_deepep_deepgemm_moe,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
@@ -448,8 +489,9 @@ USE_FP8_DISPATCH = [False]
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
@pytest.mark.skipif(
|
||||
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
|
||||
)
|
||||
def test_ll_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
@@ -482,7 +524,16 @@ def test_ll_deepep_deepgemm_moe(
|
||||
)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
num_experts, n, k, block_size)
|
||||
num_experts, n, k, block_size
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
|
||||
w2, w1_scale, w2_scale)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_test_deepep_deepgemm_moe,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user