[NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
This commit is contained in:
amirkl94
2025-08-20 01:01:53 +03:00
committed by GitHub
parent 21dce80ea9
commit a38b8af4c3
6 changed files with 613 additions and 139 deletions

View File

@@ -9,6 +9,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -23,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -145,7 +149,7 @@ class Fp8Config(QuantizationConfig):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self, layer.moe_config)
return Fp8MoEMethod(self, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
@@ -482,16 +486,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
super().__init__(moe)
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
self.flashinfer_moe_enabled = False
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
"Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
self.flashinfer_moe_enabled = True
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
)
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
@@ -531,6 +539,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
def maybe_make_prepare_finalize(
    self,
    moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    """Return the FlashInfer CUTLASS prepare/finalize kernel when that
    backend is selected; otherwise defer to the base-class choice.
    """
    # Only the CUTLASS backend supplies its own prepare/finalize stage.
    if self.flashinfer_moe_backend is not FlashinferMoeBackend.CUTLASS:
        return super().maybe_make_prepare_finalize(moe)

    pf_kernel = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
        moe,
        layer=self.layer,
    )
    logger.debug_once("%s", pf_kernel.__class__.__name__)
    return pf_kernel
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@@ -678,7 +700,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale_inv,
layer.w2_input_scale)
elif self.flashinfer_moe_enabled:
elif self.flashinfer_moe_backend is not None:
# NOTE: weights have to be swapped since the activation is
# applied on different half for flashinfer vs vllm
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
@@ -686,9 +708,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale_inv.data)
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
if not self.block_quant:
register_moe_scaling_factors(layer)
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
else:
w13_weight = layer.w13_weight.data
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -834,6 +853,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
if self.flashinfer_moe_backend is not None:
# NOTE: weights have to be swapped since the activation is
# applied on different half for flashinfer vs vllm
assert not self.block_quant
register_moe_scaling_factors(layer)
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
if self.flashinfer_moe_backend == \
FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
layer.w13_weight.data = w13_weight.data
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
@@ -892,6 +922,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
per_act_token_quant=False,
allow_deep_gemm=self.allow_deep_gemm,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
experts = select_cutlass_fp8_gemm_impl(
moe,
self.layer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
else:
logger.debug(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
@@ -930,25 +967,66 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
if not self.flashinfer_moe_enabled:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.block_quant:
assert (renormalize and use_grouped_topk
and custom_routing_function is None)
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.quant_config.weight_block_size,
routed_scaling=1.0,
)
else:
assert (not renormalize
and custom_routing_function is not None)
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
@@ -988,63 +1066,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
elif self.flashinfer_moe_enabled:
assert activation == 'silu'
assert scoring_func == 'sigmoid'
if self.block_quant:
assert (renormalize and use_grouped_topk
and custom_routing_function is None)
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert self.block_quant is None
assert (not renormalize and custom_routing_function is not None)
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.quant_config.weight_block_size,
routed_scaling=1.0,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
assert (not renormalize
and custom_routing_function is not None)
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.fused_experts is not None:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Any, Callable, Optional, Union
import torch
@@ -27,8 +26,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -49,11 +51,6 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
class FlashinferMoeBackend(Enum):
TENSORRT_LLM = "TensorRT-LLM"
CUTLASS = "CUTLASS"
class ModelOptFp8Config(QuantizationConfig):
"""Config class for ModelOpt FP8."""
@@ -179,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig):
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self, layer.moe_config)
return ModelOptFp8MoEMethod(self, layer)
return None
@@ -278,18 +275,49 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: ModelOptFp8Config,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> None:
super().__init__(moe)
super().__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_enabled = False
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
"Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
self.flashinfer_moe_enabled = True
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
)
def maybe_make_prepare_finalize(
    self,
    moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    """Return the FlashInfer CUTLASS prepare/finalize kernel when the
    CUTLASS backend is active and no fused-experts kernel was already
    built; otherwise defer to the base-class choice.
    """
    wants_cutlass = (self.fused_experts is None
                     and self.flashinfer_moe_backend
                     == FlashinferMoeBackend.CUTLASS)
    if not wants_cutlass:
        return super().maybe_make_prepare_finalize(moe)

    pf_kernel = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
        moe,
        layer=self.layer,
    )
    logger.debug_once("%s", pf_kernel.__class__.__name__)
    return pf_kernel
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Choose the CUTLASS fp8 experts (GEMM) implementation for this layer.

    ``prepare_finalize`` is part of the base-class contract but is not
    needed to pick the CUTLASS kernel here.
    """
    gemm_impl = select_cutlass_fp8_gemm_impl(moe, self.layer)
    logger.debug_once("Using %s", gemm_impl.__class__.__name__)
    return gemm_impl
def create_weights(
self,
@@ -433,11 +461,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
requires_grad=False)
if self.flashinfer_moe_enabled:
if self.flashinfer_moe_backend is not None:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
register_moe_scaling_factors(layer)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
def apply(
self,
@@ -461,14 +490,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
if self.flashinfer_moe_enabled:
assert activation == 'silu'
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
assert not renormalize
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
@@ -495,6 +523,36 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not renormalize
assert activation == 'silu', (
f"Expected 'silu' activation but got {activation}")
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
return fused_experts(
@@ -951,20 +1009,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_moe_backend == "throughput":
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
logger.info_once("Using FlashInfer CUTLASS kernels for "
"ModelOptNvFp4FusedMoE.")
elif flashinfer_moe_backend == "latency":
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
"ModelOptNvFp4FusedMoE.")
else:
allowed_backends = ["throughput", "latency"]
raise ValueError(
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
f" expected one of {allowed_backends}")
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for ModelOptNvFp4FusedMoE.")
def maybe_make_prepare_finalize(
self,

View File

@@ -1,9 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
logger = init_logger(__name__)
class FlashinferMoeBackend(Enum):
    """Which FlashInfer kernel family serves fused-MoE fp8 layers."""

    # Latency-oriented TensorRT-LLM kernels.
    TENSORRT_LLM = "TensorRT-LLM"
    # Throughput-oriented CUTLASS grouped-GEMM kernels.
    CUTLASS = "CUTLASS"
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
layer.register_parameter(
'output2_scales_scalar',
torch.nn.Parameter(output2_scales, requires_grad=False))
layer.register_parameter(
'w2_input_scale_inv',
torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False))
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
    moe: Optional[FusedMoEConfig],
    layer: torch.nn.Module,
) -> mk.FusedMoEPrepareAndFinalize:
    """Create a FlashInfer CUTLASS fused-MoE prepare/finalize kernel.

    When ``moe`` is None (the one-shot path), data parallelism is assumed
    to be disabled.
    """
    enable_dp = False
    if moe is not None:
        enable_dp = moe.moe_parallel_config.dp_size > 1
    return FlashInferCutlassMoEPrepareAndFinalize(
        enable_dp,
        a1_gscale=layer.w13_input_scale,
    )
def select_cutlass_fp8_gemm_impl(
    moe: Optional[FusedMoEConfig],
    layer: torch.nn.Module,
    out_dtype: Optional[torch.dtype] = None,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Return a GEMM *experts* implementation for fused-MoE layers.

    Args:
        moe: Optional fused-MoE config. When provided, its parallel layout
            (EP/TP ranks and sizes) and input dtype are forwarded to the
            experts kernel.
        layer: FusedMoE layer carrying the registered scaling factors.
        out_dtype: Output dtype; required when ``moe`` is None.

    Raises:
        AssertionError: if the layer does not use the Llama4 routing
            function, or if neither ``moe`` nor ``out_dtype`` is given.
    """
    # Imported locally to avoid a circular import with the model code.
    from vllm.model_executor.models.llama4 import Llama4MoE

    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
        "FusedMoE flashinfer kernels are only supported for Llama4"

    # Scale/quantization arguments are identical on both paths; build them
    # once instead of duplicating the keyword list.
    common_kwargs = dict(
        g1_alphas=layer.output1_scales_gate_scalar,
        g2_alphas=layer.output2_scales_scalar,
        a1_gscale=layer.w13_input_scale,
        a2_gscale=layer.w2_input_scale_inv,
        quant_dtype=torch.float8_e4m3fn,
    )

    if moe is not None:
        parallel = moe.moe_parallel_config
        return FlashInferExperts(
            out_dtype=moe.in_dtype,
            ep_rank=parallel.ep_rank,
            ep_size=parallel.ep_size,
            tp_rank=parallel.tp_rank,
            tp_size=parallel.tp_size,
            **common_kwargs,
        )

    assert out_dtype is not None, (
        "If moe config is None, out_dtype must be passed")
    return FlashInferExperts(out_dtype=out_dtype, **common_kwargs)
def flashinfer_cutlass_moe_fp8(
    hidden_states: torch.Tensor,
    layer: torch.nn.Module,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Run a one-shot FlashInfer CUTLASS fp8 fused-MoE forward pass.

    Assembles a modular kernel from the layer's registered scaling factors
    (no FusedMoEConfig needed) and invokes it immediately.
    """
    prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
        moe=None, layer=layer)
    gemm_experts = select_cutlass_fp8_gemm_impl(
        moe=None, layer=layer, out_dtype=hidden_states.dtype)
    kernel = mk.FusedMoEModularKernel(prepare_finalize, gemm_experts)
    return kernel(
        hidden_states,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
    """Map ``VLLM_FLASHINFER_MOE_BACKEND`` to a :class:`FlashinferMoeBackend`.

    Raises:
        ValueError: if the env var holds an unrecognized backend name.
    """
    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
    # Dispatch table instead of an if/elif chain.
    backends = {
        "throughput": FlashinferMoeBackend.CUTLASS,
        "latency": FlashinferMoeBackend.TENSORRT_LLM,
    }
    chosen = backends.get(flashinfer_moe_backend)
    if chosen is not None:
        return chosen
    allowed_backends = ["throughput", "latency"]
    raise ValueError(
        f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
        f" expected one of {allowed_backends}")