Modularize fused experts and integrate PPLX kernels (#15956)
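At a high level, this change makes `Fp8MoEMethod` own a swappable `self.fused_experts` callable: by default it is a `functools.partial` over the monolithic Triton `fused_experts` path, and the new `set_prepare_finalize` hook can replace it with an `mk.FusedMoEModularKernel` that pairs a prepare/finalize stage (the role the PPLX all-to-all kernels play) with a `TritonOrDeepGemmExperts` compute stage. The sketch below illustrates that composition pattern in isolation; the class and function names in it are illustrative stand-ins, not the vLLM API, and the diff that follows is the actual change.

import functools
from typing import Callable

# Illustrative stand-ins only -- not the vLLM classes.

def monolithic_fused_experts(x, block_shape=None, allow_deep_gemm=False):
    # Default path: a single kernel handles dispatch, expert compute and combine.
    return f"fused({x}, block_shape={block_shape}, deep_gemm={allow_deep_gemm})"

class PrepareAndFinalize:
    # Quantize/dispatch activations before expert compute, combine after
    # (the role played by the PPLX all-to-all kernels in this PR).
    def prepare(self, x):
        return f"dispatched({x})"

    def finalize(self, y):
        return f"combined({y})"

class Experts:
    # Per-expert compute stage (TritonOrDeepGemmExperts in this PR).
    def apply(self, x):
        return f"experts({x})"

class ModularKernel:
    # Composes the two stages behind the same callable interface as the
    # monolithic path, so apply() never needs to know which one is active.
    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, x, **kwargs):
        return self.prepare_finalize.finalize(
            self.experts.apply(self.prepare_finalize.prepare(x)))

class MoEMethod:
    def __init__(self):
        # Default: monolithic kernel, pre-bound with this layer's config.
        self.fused_experts: Callable = functools.partial(
            monolithic_fused_experts, block_shape=(128, 128))

    def set_prepare_finalize(self, prepare_finalize) -> bool:
        # Swap in the modular kernel; the real method returns False when the
        # active backend (e.g. Marlin) cannot use it.
        self.fused_experts = ModularKernel(prepare_finalize, Experts())
        return True

method = MoEMethod()
print(method.fused_experts("x"))                  # monolithic path
method.set_prepare_finalize(PrepareAndFinalize())
print(method.fused_experts("x"))                  # modular path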
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 import importlib.util
 from typing import Any, Callable, Optional
 
@@ -9,6 +10,7 @@ from torch.nn import Module
 from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -434,6 +436,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        from vllm.model_executor.layers.fused_moe import fused_experts
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
 
@@ -458,6 +461,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 logger.warning_once(
                     "DeepGemm not supported on the current platform.")
 
+        self.fused_experts = functools.partial(
+            fused_experts,
+            block_shape=self.quant_config.weight_block_size,
+            allow_deep_gemm=self.allow_deep_gemm)
+
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -783,6 +791,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
+    def set_prepare_finalize(
+        self,
+        dp_size: int,
+        world_size: int,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+    ) -> bool:
+        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+            TritonOrDeepGemmExperts)
+
+        if self.use_marlin or self.rocm_aiter_moe_enabled:
+            return False
+
+        experts = TritonOrDeepGemmExperts(
+            use_fp8_w8a8=True,
+            block_shape=self.quant_config.weight_block_size,
+            allow_deep_gemm=self.allow_deep_gemm,
+        )
+
+        self.fused_experts = mk.FusedMoEModularKernel(
+            prepare_finalize,
+            experts,
+        )
+
+        return True
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -801,10 +834,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -819,6 +848,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
 
         if self.rocm_aiter_moe_enabled:
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
+                rocm_aiter_fused_experts)
             return rocm_aiter_fused_experts(
                 x,
                 layer.w13_weight,
@@ -835,8 +866,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
                 block_shape=self.quant_config.weight_block_size)
-
-        if self.use_marlin:
+        elif self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
             assert not apply_router_weight_on_input, (
@@ -853,28 +883,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 quant_type_id=scalar_types.float8_e4m3fn.id,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            use_fp8_w8a8=True,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            w1_scale=(layer.w13_weight_scale_inv
-                      if self.block_quant else layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale_inv
-                      if self.block_quant else layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-            allow_deep_gemm=self.allow_deep_gemm,
-        )
+        else:
+            return self.fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                use_fp8_w8a8=True,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
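One detail worth noting in the last two hunks: `set_prepare_finalize` reports False instead of raising when Marlin or ROCm AITER is active, so the layer keeps the default `fused_experts` partial, and `apply()` dispatches to the ROCm AITER path first, then Marlin, then `self.fused_experts`. A toy sketch of that fallback contract, again with illustrative names rather than the vLLM call sites:

class QuantMethod:
    # Illustrative stand-in for Fp8MoEMethod: only the fallback contract matters here.
    def __init__(self, use_marlin: bool):
        self.use_marlin = use_marlin
        self.fused_experts = lambda x: f"default_triton({x})"

    def set_prepare_finalize(self, prepare_finalize) -> bool:
        if self.use_marlin:
            # Unsupported backend: keep the existing kernel and tell the caller.
            return False
        self.fused_experts = lambda x: f"modular({prepare_finalize!r}, {x})"
        return True

def maybe_enable_modular(method: QuantMethod, prepare_finalize) -> None:
    # Caller-side pattern: only swap kernels when the method accepts the hook.
    if not method.set_prepare_finalize(prepare_finalize):
        print("prepare/finalize rejected; keeping default fused_experts")

for use_marlin in (True, False):
    m = QuantMethod(use_marlin)
    maybe_enable_modular(m, "pplx_all_to_all")
    print(m.fused_experts("x"))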