[Kernel] Support Flashinfer trtllm fused MoE non gated FP8 & NVFP4 (#33506)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
@@ -18,6 +19,20 @@ class FlashinferMoeBackend(Enum):
|
||||
CUTEDSL = "CUTEDSL"
|
||||
|
||||
|
||||
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
|
||||
ACTIVATION_TO_FI_ACTIVATION = {
|
||||
MoEActivation.SILU_NO_MUL: ActivationType.Silu,
|
||||
MoEActivation.GELU_NO_MUL: ActivationType.Gelu,
|
||||
MoEActivation.SILU: ActivationType.Swiglu,
|
||||
MoEActivation.GELU: ActivationType.Geglu,
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation].value
|
||||
|
||||
|
||||
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
return (
|
||||
x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
|
||||
@@ -25,7 +40,7 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
|
||||
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor, is_gated_activation: bool
|
||||
):
|
||||
"""Shuffle weights for for FI TRT-LLM Format"""
|
||||
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
|
||||
@@ -40,6 +55,8 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_fp8_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(gemm1_weights[i])
|
||||
if is_gated_activation
|
||||
else gemm1_weights[i]
|
||||
)
|
||||
|
||||
# Stack weights and scales for all experts
|
||||
@@ -86,7 +103,13 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
)
|
||||
layer.w2_input_scale_inv = 1.0 / w2_input_scale
|
||||
layer.output1_scales_gate_scalar = g1_alphas
|
||||
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
|
||||
|
||||
if layer.activation.is_gated:
|
||||
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
|
||||
else:
|
||||
layer.output1_scales_scalar = (
|
||||
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
|
||||
)
|
||||
layer.output2_scales_scalar = g2_alphas
|
||||
|
||||
|
||||
@@ -125,6 +148,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
assert layer.custom_routing_function is None, (
|
||||
"Custom routing function is only supported for Llama4"
|
||||
)
|
||||
activation_type = activation_to_flashinfer_int(layer.activation)
|
||||
|
||||
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits=router_logits,
|
||||
@@ -145,6 +169,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
local_num_experts=layer.local_num_experts,
|
||||
use_routing_scales_on_input=apply_router_weight_on_input,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
activation_type=activation_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -274,8 +299,64 @@ def convert_moe_weights_to_flashinfer_trtllm_block_layout(
|
||||
return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor
|
||||
|
||||
|
||||
def align_fp4_moe_weights_for_fi(
|
||||
w13: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
is_act_and_mul: bool,
|
||||
min_alignment: int = 16,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
||||
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
|
||||
|
||||
Some FlashInfer FP4 MoE kernels require the intermediate size
|
||||
used for GEMM to be divisible by a small alignment value. When this is
|
||||
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
|
||||
gate/up and down projection weights along the intermediate dim.
|
||||
"""
|
||||
|
||||
# Current local intermediate size (per partition) is the K dimension of
|
||||
# the down projection.
|
||||
num_experts, hidden_size, intermediate = w2.shape
|
||||
intermediate *= 2 # because of packed FP4
|
||||
|
||||
padded_intermediate = round_up(intermediate, min_alignment)
|
||||
|
||||
if padded_intermediate == intermediate:
|
||||
return w13, w13_scale, w2, w2_scale, intermediate
|
||||
|
||||
logger.info_once(
|
||||
"Padding intermediate size from %d to %d for up/down projection weights.",
|
||||
intermediate,
|
||||
padded_intermediate,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
up_mult = 2 if is_act_and_mul else 1
|
||||
padded_gate_up_dim = up_mult * padded_intermediate
|
||||
|
||||
# Pad w13 and w2 along its intermediate dimension.
|
||||
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size // 2))
|
||||
padded_w13[:, : w13.shape[1], :] = w13
|
||||
|
||||
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate // 2))
|
||||
padded_w2[:, :, : w2.shape[2]] = w2
|
||||
|
||||
padded_w13_scale = w13_scale.new_zeros(
|
||||
(num_experts, padded_gate_up_dim, hidden_size // 16)
|
||||
)
|
||||
padded_w13_scale[:, : w13_scale.shape[1], :] = w13_scale
|
||||
|
||||
padded_w2_scale = w2_scale.new_zeros(
|
||||
(num_experts, hidden_size, padded_intermediate // 16)
|
||||
)
|
||||
padded_w2_scale[:, :, : w2_scale.shape[2]] = w2_scale
|
||||
|
||||
return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate
|
||||
|
||||
|
||||
def align_fp8_moe_weights_for_fi(
|
||||
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
|
||||
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool, min_alignment: int = 16
|
||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
|
||||
|
||||
@@ -289,7 +370,6 @@ def align_fp8_moe_weights_for_fi(
|
||||
# the down projection.
|
||||
num_experts, hidden_size, intermediate = w2.shape
|
||||
|
||||
min_alignment = 16
|
||||
padded_intermediate = round_up(intermediate, min_alignment)
|
||||
|
||||
if padded_intermediate == intermediate:
|
||||
@@ -342,11 +422,14 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
|
||||
# Some FI MoE kernels require internal alignment of 16
|
||||
# for the gate-up proj. Pad the weights to respect this.
|
||||
is_gated = layer.activation.is_gated
|
||||
if not block_quant:
|
||||
min_alignment = 16 if is_gated else 128
|
||||
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
|
||||
w13,
|
||||
w2,
|
||||
layer.moe_config.is_act_and_mul,
|
||||
min_alignment,
|
||||
)
|
||||
layer.intermediate_size_per_partition = new_intermediate
|
||||
|
||||
@@ -363,7 +446,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
assert w13_input_scale is not None
|
||||
assert w2_input_scale is not None
|
||||
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2)
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer,
|
||||
w13_scale=w13_scale,
|
||||
|
||||
Reference in New Issue
Block a user