[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -8,7 +8,8 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
@@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
from flashinfer.fused_moe import RoutingMethodType
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_scalar to be initialized")
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
@@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
||||
use_dp, a1_gscale=layer.w13_input_scale)
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
|
||||
|
||||
|
||||
def select_cutlass_fp8_gemm_impl(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
|
||||
if moe is not None:
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=moe.in_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
quant_config=quant_config,
|
||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
@@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl(
|
||||
assert out_dtype is not None, (
|
||||
"If moe config is None, out_dtype must be passed")
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=out_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8(
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
||||
assert quant_config is not None
|
||||
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
|
||||
layer=layer),
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
|
||||
select_cutlass_fp8_gemm_impl(moe=None,
|
||||
layer=layer,
|
||||
quant_config=quant_config,
|
||||
out_dtype=hidden_states.dtype))
|
||||
|
||||
return fused_experts(
|
||||
|
||||
Reference in New Issue
Block a user