[MoE Refactor] Create MK for TRTLLM Kernels (#32564)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Robert Shaw <rshaw@neuralmagic.com> Signed-off-by: Robert Shaw <robertgshaw2@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
|
||||
|
||||
|
||||
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
    """Return the ``.value`` of the FlashInfer activation type for ``activation``.

    Thin wrapper around ``activation_to_flashinfer_type`` for callers that
    need the raw value instead of the enum member.
    """
    fi_activation = activation_to_flashinfer_type(activation)
    return fi_activation.value
|
||||
|
||||
|
||||
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
|
||||
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
|
||||
MoEActivation.GELU: ActivationType.Geglu,
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation].value
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation]
|
||||
|
||||
|
||||
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
)
|
||||
|
||||
|
||||
def register_scales_for_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> None:
    """Attach the scale tensors used by the FlashInfer TRTLLM FP8 MoE kernel.

    The following attributes are set on ``layer``:

    * ``w2_input_scale_inv``: ``1.0 / w2_input_scale``.
    * ``output1_scales_gate_scalar``: the first alpha tensor returned by
      ``make_fp8_moe_alpha_scales_for_fi``.
    * ``output1_scales_scalar``: ``w2_input_scale_inv`` multiplied by that
      alpha tensor when ``layer.activation.is_gated`` is true, otherwise by
      a tensor of ones with the same shape.
    * ``output2_scales_scalar``: the second alpha tensor.

    Args:
        layer: Module that receives the attributes; ``layer.activation`` is
            read to choose between the gated and non-gated form.
        w13_scale: Weight scale forwarded to the alpha computation.
        w13_input_scale: Input scale forwarded to the alpha computation.
        w2_scale: Weight scale forwarded to the alpha computation.
        w2_input_scale: Input scale; its reciprocal is stored on the layer.
    """
    gate_alphas, down_alphas = make_fp8_moe_alpha_scales_for_fi(
        w13_scale=w13_scale,
        w13_input_scale=w13_input_scale,
        w2_scale=w2_scale,
        w2_input_scale=w2_input_scale,
    )

    layer.w2_input_scale_inv = 1.0 / w2_input_scale
    layer.output1_scales_gate_scalar = gate_alphas

    # Non-gated activations use a unit multiplier in place of the gate alphas.
    multiplier = (
        gate_alphas if layer.activation.is_gated else torch.ones_like(gate_alphas)
    )
    layer.output1_scales_scalar = multiplier * layer.w2_input_scale_inv

    layer.output2_scales_scalar = down_alphas
|
||||
|
||||
|
||||
def apply_fi_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    top_k: int,
    num_expert_group: int | None,
    topk_group: int | None,
    global_num_experts: int,
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
    """Invoke ``torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe`` for ``layer``.

    Weights, input scale, routing method, expert-parallel placement and the
    precomputed output scales are read from attributes of ``layer``; the
    remaining kernel arguments are forwarded from this function's parameters.

    Args:
        layer: MoE module holding the weights and the scale attributes set
            by ``register_scales_for_trtllm_fp8_per_tensor_moe``.
        hidden_states: Passed to the op as ``hidden_states``.
        router_logits: Passed to the op as ``routing_logits``.
        routing_bias: Passed to the op as ``routing_bias``; may be ``None``.
        top_k: Passed to the op as ``top_k``.
        num_expert_group: Passed to the op as ``num_expert_group``.
        topk_group: Passed to the op as ``topk_group``.
        global_num_experts: Passed to the op as ``num_experts``.
        apply_router_weight_on_input: Passed to the op as
            ``use_routing_scales_on_input``.

    Returns:
        Whatever the custom op returns.

    Raises:
        AssertionError: If the scale attributes are missing from ``layer``,
            or the routing configuration fails the checks below.
    """
    from flashinfer.fused_moe import RoutingMethodType

    # Imported without binding a name we use (hence F401) -- presumably this
    # registers the custom op called at the end; confirm against that module.
    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
    from vllm.model_executor.models.llama4 import Llama4MoE

    # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
    required_scale_attrs = (
        "output1_scales_scalar",
        "output1_scales_gate_scalar",
        "output2_scales_scalar",
    )
    assert all(hasattr(layer, attr_name) for attr_name in required_scale_attrs)

    uses_llama4_routing = layer.routing_method_type == RoutingMethodType.Llama4
    if not uses_llama4_routing:
        assert layer.custom_routing_function is None, (
            "Custom routing function is only supported for Llama4"
        )
    else:
        assert (
            not layer.renormalize
            and layer.custom_routing_function == Llama4MoE.custom_routing_function
        ), (
            "FusedMoE flashinfer kernels with Llama4 routing method are only "
            "supported for Llama4"
        )

    fi_activation = activation_to_flashinfer_int(layer.activation)

    experts_on_rank = layer.local_num_experts
    first_local_expert = layer.ep_rank * experts_on_rank

    return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states,
        input_scale=layer.w13_input_scale,
        gemm1_weights=layer.w13_weight,
        gemm2_weights=layer.w2_weight,
        output1_scales_scalar=layer.output1_scales_scalar,
        output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
        output2_scales_scalar=layer.output2_scales_scalar,
        num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=first_local_expert,
        local_num_experts=experts_on_rank,
        use_routing_scales_on_input=apply_router_weight_on_input,
        routing_method_type=layer.routing_method_type,
        activation_type=fi_activation,
    )
|
||||
|
||||
|
||||
def make_fp8_moe_alpha_scales_for_fi(
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine weight and input scales into the two FlashInfer alpha tensors.

    Each alpha is the elementwise (broadcast) product of a weight scale and
    its matching input scale, with all size-1 dimensions squeezed away.

    Args:
        w13_scale: Weight scale for the w13 projection.
        w13_input_scale: Input scale for the w13 projection.
        w2_scale: Weight scale for the w2 projection.
        w2_input_scale: Input scale for the w2 projection.

    Returns:
        ``(g1_alphas, g2_alphas)`` where ``g1_alphas`` is derived from the
        w13 scales and ``g2_alphas`` from the w2 scales.
    """
    # NOTE(review): squeeze() without a dim drops *every* size-1 axis, so a
    # product of shape (1,) becomes 0-d -- confirm callers tolerate that.
    gate_up_alphas = torch.squeeze(w13_scale * w13_input_scale)
    down_alphas = torch.squeeze(w2_scale * w2_input_scale)
    return gate_up_alphas, down_alphas
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
backend_map = {
|
||||
"throughput": FlashinferMoeBackend.CUTLASS,
|
||||
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
min_alignment,
|
||||
)
|
||||
layer.intermediate_size_per_partition = new_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = new_intermediate
|
||||
|
||||
# FI kernels require W31 layout rather than W13.
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
||||
# and registration of alpha scales. Note that we do not register
|
||||
# as nn.Parameters since they are not needed for weight-reloading.
|
||||
# and registration of alpha scales.
|
||||
if is_trtllm and not block_quant:
|
||||
assert w13_input_scale is not None
|
||||
assert w2_input_scale is not None
|
||||
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
|
||||
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused
|
||||
|
||||
Reference in New Issue
Block a user