[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
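The change in a nutshell: rather than threading loose quantization kwargs (use_fp8_w8a8, w1_scale, w2_scale, a1_scale, block_shape, ...) through every kernel call, callers now build a single FusedMoEQuantConfig up front and hand it to the kernel. A minimal sketch of the new convention, mirroring the triton_impl hunks below (run_fp8_moe is a hypothetical wrapper; the tensor arguments are whatever the caller already has):

```python
# Sketch only: the imports are the ones this diff adds to the test file.
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts


def run_fp8_moe(a, w1, w2, topk_weights, topk_ids, w1_scale, w2_scale,
                a1_scale, block_shape):
    # Bundle all fp8 quantization state into one config object...
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        block_shape=block_shape,
    )
    # ...and pass it once, replacing use_fp8_w8a8=True plus the loose
    # scale/block_shape kwargs used before this change.
    return fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
    )
```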
@@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
 from typing_extensions import ParamSpec
 
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
@@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
     Return weights w1q, w2q, w1_scale, w2_scale
     """
     (_, w1q, w1_scale, _), (_, w2q, w2_scale,
-                            _) = make_test_weights(e, n, k, torch.bfloat16,
+                            _) = make_test_weights(e,
+                                                   n,
+                                                   k,
+                                                   torch.bfloat16,
                                                    torch.float8_e4m3fn,
-                                                   block_size)
+                                                   block_shape=block_size)
     return w1q, w2q, w1_scale, w2_scale
@@ -130,10 +135,11 @@ class TestTensors:
                            config=config)
 
 
-def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                           max_tokens_per_rank: int, dp_size: int,
-                           hidden_size: int, q_dtype: Optional[torch.dtype],
-                           test_config: TestConfig) -> FusedMoEModularKernel:
+def make_ll_modular_kernel(
+        pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
+        dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
+        test_config: TestConfig,
+        quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
 
     assert test_config.low_latency
     assert test_config.use_fp8_dispatch is not None
@@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
     fused_experts = BatchedDeepGemmExperts(
         max_num_tokens=max_tokens_per_rank,
         num_dispatchers=pgi.world_size // dp_size,
-        block_shape=test_config.block_size,
-        per_act_token_quant=test_config.per_act_token_quant)
+        quant_config=quant_config,
+    )
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk
 
 
-def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
-                           dp_size: int, num_local_experts: int,
-                           q_dtype: Optional[torch.dtype],
-                           test_config: TestConfig) -> FusedMoEModularKernel:
+def make_ht_modular_kernel(
+        pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
+        num_local_experts: int, q_dtype: Optional[torch.dtype],
+        test_config: TestConfig,
+        quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
 
     assert not test_config.low_latency
     assert test_config.use_fp8_dispatch is None
@@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
         q_dtype=q_dtype,
         block_shape=test_config.block_size)
 
-    fused_experts = DeepGemmExperts()
+    fused_experts = DeepGemmExperts(quant_config)
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk
 
 
-def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                        num_local_experts: int,
-                        test_tensors: TestTensors) -> FusedMoEModularKernel:
+def make_modular_kernel(
+        pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
+        num_local_experts: int, test_tensors: TestTensors,
+        quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
 
     q_dtype = torch.float8_e4m3fn
     test_config = test_tensors.config
@@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
             dp_size=dp_size,
             hidden_size=hidden_size,
             q_dtype=q_dtype,
-            test_config=test_config)
+            test_config=test_config,
+            quant_config=quant_config)
     else:
-        mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
-                                    q_dtype, test_config)
+        mk = make_ht_modular_kernel(pg,
+                                    pgi,
+                                    dp_size,
+                                    num_local_experts,
+                                    q_dtype,
+                                    test_config,
+                                    quant_config=quant_config)
 
     return mk
@@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
         return expert_map.to(device=torch.cuda.current_device(),
                              dtype=torch.int32)
 
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        # Low-Latency kernels can't dispatch scales.
+        a1_scale=(None if test_config.low_latency else
+                  test_tensors.rank_token_scales),
+        block_shape=test_config.block_size,
+    )
+
     # Make modular kernel
     mk: FusedMoEModularKernel = make_modular_kernel(
         pg=pg,
         pgi=pgi,
         dp_size=dp_size,
         num_local_experts=num_local_experts,
-        test_tensors=test_tensors)
-
-    # Low-Latency kernels can't dispatch scales.
-    a1_scale = (None
-                if test_config.low_latency else test_tensors.rank_token_scales)
+        test_tensors=test_tensors,
+        quant_config=quant_config)
 
     out = mk.forward(hidden_states=test_tensors.rank_tokens,
                      w1=w1,
@@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
                      activation="silu",
                      global_num_experts=num_experts,
                      expert_map=build_expert_map(),
-                     w1_scale=w1_scale,
-                     w2_scale=w2_scale,
-                     w1_zp=None,
-                     w2_zp=None,
-                     a1_scale=a1_scale,
-                     a2_scale=None,
                      apply_router_weight_on_input=False)
     return out
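Note the data-flow change in the two hunks above: the scales the test used to pass to mk.forward (w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale) now travel inside the FusedMoEQuantConfig handed to the experts at construction time, leaving forward with only routing and activation arguments. A sketch of the assembly, following make_ht_modular_kernel above (the a2a prepare/finalize object and the DeepGemmExperts import are elided in these hunks and assumed here; variable names refer to the test's tensors):

```python
# Sketch only: quant state is injected when the kernel is built,
# not re-supplied on every forward call.
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)

quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    # Low-Latency kernels can't dispatch scales, hence None there.
    a1_scale=None if low_latency else rank_token_scales,
    block_shape=block_size,
)
fused = DeepGemmExperts(quant_config)  # experts now own their quant config
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused)
```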
@@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
                 w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                 a1_scale: torch.Tensor, block_shape: list[int]):
 
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        block_shape=block_shape,
+    )
+
     return fused_experts(
         hidden_states=a,
         w1=w1,
@@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
         inplace=False,
-        use_fp8_w8a8=True,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-        a1_scale=a1_scale,
-        block_shape=block_shape,
+        quant_config=quant_config,
         # Make sure this is set to False so we
         # don't end up comparing the same implementation.
         allow_deep_gemm=False)
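Taken together, both test paths now follow the same pattern: quantization parameters are bound once, where the expert implementation is constructed, instead of being re-supplied (and potentially mismatched) on each call. That is also why use_fp8_w8a8=True disappears here: the fp8 mode is implied by the FusedMoEQuantConfig built via fp8_w8a8_moe_quant_config.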