[NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Author: amirkl94
Date: 2025-08-20 01:01:53 +03:00
Committed by: GitHub
Parent: 21dce80ea9
Commit: a38b8af4c3

6 changed files with 613 additions and 139 deletions


@@ -1,9 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Optional

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
    FlashInferCutlassMoEPrepareAndFinalize)

logger = init_logger(__name__)


class FlashinferMoeBackend(Enum):
    TENSORRT_LLM = "TensorRT-LLM"
    CUTLASS = "CUTLASS"


def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
    layer.register_parameter(
        'output2_scales_scalar',
        torch.nn.Parameter(output2_scales, requires_grad=False))
    layer.register_parameter(
        'w2_input_scale_inv',
        torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False))


def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
    moe: Optional[FusedMoEConfig],
    layer: torch.nn.Module,
) -> mk.FusedMoEPrepareAndFinalize:
    """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
    use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
    return FlashInferCutlassMoEPrepareAndFinalize(
        use_dp, a1_gscale=layer.w13_input_scale)


def select_cutlass_fp8_gemm_impl(
    moe: Optional[FusedMoEConfig],
    layer: torch.nn.Module,
    out_dtype: Optional[torch.dtype] = None,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Return a GEMM *experts* implementation for fused-MoE layers"""
    from vllm.model_executor.models.llama4 import Llama4MoE
    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
        "FusedMoE flashinfer kernels are only supported for Llama4"

    if moe is not None:
        return FlashInferExperts(
            g1_alphas=layer.output1_scales_gate_scalar,
            g2_alphas=layer.output2_scales_scalar,
            a1_gscale=layer.w13_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            out_dtype=moe.in_dtype,
            quant_dtype=torch.float8_e4m3fn,
            ep_rank=moe.moe_parallel_config.ep_rank,
            ep_size=moe.moe_parallel_config.ep_size,
            tp_rank=moe.moe_parallel_config.tp_rank,
            tp_size=moe.moe_parallel_config.tp_size,
        )

    assert out_dtype is not None, (
        "If moe config is None, out_dtype must be passed")
    return FlashInferExperts(
        g1_alphas=layer.output1_scales_gate_scalar,
        g2_alphas=layer.output2_scales_scalar,
        a1_gscale=layer.w13_input_scale,
        a2_gscale=layer.w2_input_scale_inv,
        out_dtype=out_dtype,
        quant_dtype=torch.float8_e4m3fn,
    )
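

# The helper below wires the two factories above into a single
# mk.FusedMoEModularKernel: the prepare/finalize object handles activation
# preparation and output finalization, while FlashInferExperts performs the
# CUTLASS fp8 expert GEMMs.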
def flashinfer_cutlass_moe_fp8(
    hidden_states: torch.Tensor,
    layer: torch.nn.Module,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    fused_experts = mk.FusedMoEModularKernel(
        build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
                                                          layer=layer),
        select_cutlass_fp8_gemm_impl(moe=None,
                                     layer=layer,
                                     out_dtype=hidden_states.dtype))

    return fused_experts(
        hidden_states,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )


def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
    if flashinfer_moe_backend == "throughput":
        return FlashinferMoeBackend.CUTLASS
    elif flashinfer_moe_backend == "latency":
        return FlashinferMoeBackend.TENSORRT_LLM

    allowed_backends = ["throughput", "latency"]
    raise ValueError(
        f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
        f" expected one of {allowed_backends}")
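
For reference, a minimal usage sketch of the backend switch introduced above. The import path below and the exact way VLLM_FLASHINFER_MOE_BACKEND is picked up from the environment are assumptions for illustration, not part of this diff.

# Hypothetical usage sketch; the module path is assumed, not shown in this diff.
import os

# "throughput" selects the CUTLASS backend added in this commit;
# "latency" selects the TensorRT-LLM backend.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "throughput"

from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend, get_flashinfer_moe_backend)

backend = get_flashinfer_moe_backend()
assert backend is FlashinferMoeBackend.CUTLASS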