[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-17 19:43:31 -04:00
Committed by: GitHub
Parent: e6585ddb45
Commit: 5963b98b46

68 changed files with 2698 additions and 2526 deletions
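
This refactor replaces the loose per-call quantization keywords (use_fp8_w8a8, w1_scale, w2_scale, a1_scale, block_shape, ...) with a single FusedMoEQuantConfig that is constructed once, via a factory such as fp8_w8a8_moe_quant_config, and handed to the expert kernels up front. A minimal sketch of the new calling convention, using the factory imported in the diff below; the tensor shapes and block size here are illustrative placeholders, not the layout of any particular checkpoint:

import torch

from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config)

# Hypothetical per-expert weight scales; real shapes depend on the
# block-quantized fp8 layout of the model under test.
num_experts = 8
w1_scale = torch.ones(num_experts, 1, 1, dtype=torch.float32)
w2_scale = torch.ones(num_experts, 1, 1, dtype=torch.float32)

# Build the quantization config once...
quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    a1_scale=None,            # activation scale is optional
    block_shape=[128, 128],   # illustrative fp8 block granularity
)

# ...then pass it to the kernel at construction or call time, e.g.
# DeepGemmExperts(quant_config) or
# fused_experts(..., quant_config=quant_config), instead of threading
# each scale through as a separate keyword argument.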


@@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
 from typing_extensions import ParamSpec
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
@@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
     Return weights w1q, w2q, w1_scale, w2_scale
     """
     (_, w1q, w1_scale, _), (_, w2q, w2_scale,
-                            _) = make_test_weights(e, n, k, torch.bfloat16,
+                            _) = make_test_weights(e,
+                                                   n,
+                                                   k,
+                                                   torch.bfloat16,
                                                    torch.float8_e4m3fn,
-                                                   block_size)
+                                                   block_shape=block_size)
     return w1q, w2q, w1_scale, w2_scale
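
Note that make_test_weights is now invoked with block_shape as an explicit keyword rather than positionally; presumably this keeps the call unambiguous as the helper's parameter list grows.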
@@ -130,10 +135,11 @@ class TestTensors:
                            config=config)


-def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                           max_tokens_per_rank: int, dp_size: int,
-                           hidden_size: int, q_dtype: Optional[torch.dtype],
-                           test_config: TestConfig) -> FusedMoEModularKernel:
+def make_ll_modular_kernel(
+        pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
+        dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
+        test_config: TestConfig,
+        quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:

     assert test_config.low_latency
     assert test_config.use_fp8_dispatch is not None
@@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
     fused_experts = BatchedDeepGemmExperts(
         max_num_tokens=max_tokens_per_rank,
         num_dispatchers=pgi.world_size // dp_size,
-        block_shape=test_config.block_size,
-        per_act_token_quant=test_config.per_act_token_quant)
+        quant_config=quant_config,
+    )
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk


-def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                           dp_size: int, num_local_experts: int,
-                           q_dtype: Optional[torch.dtype],
-                           test_config: TestConfig) -> FusedMoEModularKernel:
+def make_ht_modular_kernel(
+        pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
+        num_local_experts: int, q_dtype: Optional[torch.dtype],
+        test_config: TestConfig,
+        quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:

     assert not test_config.low_latency
     assert test_config.use_fp8_dispatch is None
@@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
         q_dtype=q_dtype,
         block_shape=test_config.block_size)
-    fused_experts = DeepGemmExperts()
+    fused_experts = DeepGemmExperts(quant_config)
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk


-def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                        num_local_experts: int,
-                        test_tensors: TestTensors) -> FusedMoEModularKernel:
+def make_modular_kernel(
+        pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
+        num_local_experts: int, test_tensors: TestTensors,
+        quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:

     q_dtype = torch.float8_e4m3fn
     test_config = test_tensors.config
@@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
             dp_size=dp_size,
             hidden_size=hidden_size,
             q_dtype=q_dtype,
-            test_config=test_config)
+            test_config=test_config,
+            quant_config=quant_config)
     else:
-        mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
-                                    q_dtype, test_config)
+        mk = make_ht_modular_kernel(pg,
+                                    pgi,
+                                    dp_size,
+                                    num_local_experts,
+                                    q_dtype,
+                                    test_config,
+                                    quant_config=quant_config)

     return mk
@@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
         return expert_map.to(device=torch.cuda.current_device(),
                              dtype=torch.int32)

+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        # Low-Latency kernels can't dispatch scales.
+        a1_scale=(None if test_config.low_latency else
+                  test_tensors.rank_token_scales),
+        block_shape=test_config.block_size,
+    )
+
     # Make modular kernel
     mk: FusedMoEModularKernel = make_modular_kernel(
         pg=pg,
         pgi=pgi,
         dp_size=dp_size,
         num_local_experts=num_local_experts,
-        test_tensors=test_tensors)
-
-    # Low-Latency kernels can't dispatch scales.
-    a1_scale = (None
-                if test_config.low_latency else test_tensors.rank_token_scales)
+        test_tensors=test_tensors,
+        quant_config=quant_config)

     out = mk.forward(hidden_states=test_tensors.rank_tokens,
                      w1=w1,
@@ -254,12 +274,6 @@
                      activation="silu",
                      global_num_experts=num_experts,
                      expert_map=build_expert_map(),
-                     w1_scale=w1_scale,
-                     w2_scale=w2_scale,
-                     w1_zp=None,
-                     w2_zp=None,
-                     a1_scale=a1_scale,
-                     a2_scale=None,
                      apply_router_weight_on_input=False)
     return out
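
With this hunk, mk.forward no longer accepts the six quantization keywords (w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale); those values now travel inside the FusedMoEQuantConfig supplied when the experts were constructed, so the forward call carries only routing and activation arguments.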
@@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
                 w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                 a1_scale: torch.Tensor, block_shape: list[int]):

+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        block_shape=block_shape,
+    )
+
     return fused_experts(
         hidden_states=a,
         w1=w1,
@@ -276,11 +297,7 @@
         topk_weights=topk_weights,
         topk_ids=topk_ids,
         inplace=False,
-        use_fp8_w8a8=True,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-        a1_scale=a1_scale,
-        block_shape=block_shape,
+        quant_config=quant_config,
         # Make sure this is set to False so we
         # don't end up comparing the same implementation.
         allow_deep_gemm=False)
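
Here quant_config subsumes the old use_fp8_w8a8=True flag along with the scales and block shape; allow_deep_gemm=False remains a per-call argument so the Triton reference path cannot dispatch to the DeepGemm implementation being compared against.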