Modularize fused experts and integrate PPLX kernels (#15956)

This commit is contained in:
bnellnm
2025-05-14 16:11:54 -04:00
committed by GitHub
parent 418d2f8bfb
commit f9c069c85e
42 changed files with 3830 additions and 660 deletions

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import functools
import importlib.util
from typing import Any, Callable, Optional
@@ -9,6 +10,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -434,6 +436,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"""
def __init__(self, quant_config: Fp8Config):
from vllm.model_executor.layers.fused_moe import fused_experts
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
@@ -458,6 +461,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once(
"DeepGemm not supported on the current platform.")
self.fused_experts = functools.partial(
fused_experts,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@@ -783,6 +791,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> bool:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
if self.use_marlin or self.rocm_aiter_moe_enabled:
return False
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
self.fused_experts = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def apply(
self,
layer: torch.nn.Module,
@@ -801,10 +834,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -819,6 +848,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
@@ -835,8 +866,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size)
if self.use_marlin:
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
@@ -853,28 +883,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_type_id=scalar_types.float8_e4m3fn.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
else:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
class Fp8KVCacheMethod(BaseKVCacheMethod):