511 lines
18 KiB
Python
511 lines
18 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
|
|
from vllm import envs
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.math_utils import round_up
|
|
|
|
if TYPE_CHECKING:
|
|
from flashinfer.fused_moe.core import ActivationType
|
|
|
|
# Module-level logger used by the helpers below (e.g. for the one-time
# backend-fallback and weight-padding messages).
logger = init_logger(__name__)
|
|
|
|
|
|
class FlashinferMoeBackend(Enum):
    """FlashInfer fused-MoE backends that this module can select.

    The value of each member is the backend's display name.
    """

    # The value strings are runtime data; keep them verbatim.
    TENSORRT_LLM = "TensorRT-LLM"
    CUTLASS = "CUTLASS"
    CUTEDSL = "CUTEDSL"
|
|
|
|
|
|
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
    """Return the integer code of the FlashInfer activation for ``activation``.

    Looks the activation up with ``activation_to_flashinfer_type`` and
    returns the ``.value`` of the resulting enum member.
    """
    fi_activation = activation_to_flashinfer_type(activation)
    return fi_activation.value
|
|
|
|
|
|
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
    """Translate a ``MoEActivation`` into FlashInfer's ``ActivationType``.

    FlashInfer is imported inside the function, so it is only required when
    this function is actually called.

    Raises:
        KeyError: if ``activation`` has no entry in the mapping below.
    """
    from flashinfer.fused_moe.core import ActivationType as FiActivation

    # Plain SILU / GELU stand for the gated forms (SwiGLU / GeGLU); the
    # *_NO_MUL members map to the corresponding non-gated activations.
    fi_activation_for = {
        MoEActivation.SILU: FiActivation.Swiglu,
        MoEActivation.GELU: FiActivation.Geglu,
        MoEActivation.SILU_NO_MUL: FiActivation.Silu,
        MoEActivation.GELU_NO_MUL: FiActivation.Gelu,
        MoEActivation.RELU2_NO_MUL: FiActivation.Relu2,
    }
    return fi_activation_for[activation]
|
|
|
|
|
|
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the second-to-last dimension of ``x``.

    The rows ``[first_half; second_half]`` become ``[second_half;
    first_half]`` for every leading index. The result is a new tensor with
    the same shape as ``x``; ``x`` itself is not modified.

    Raises:
        ValueError: if the second-to-last dimension has an odd size. Without
            this check the reshape below can still succeed for some shapes
            (e.g. ``(2, 3, 4)``) and silently return scrambled data.
    """
    rows, cols = x.shape[-2], x.shape[-1]
    if rows % 2 != 0:
        raise ValueError(
            "swap_w13_to_w31 expects an even number of rows in dim -2, "
            f"got shape {tuple(x.shape)}"
        )
    if x.numel() == 0:
        # Nothing to swap; reshape(-1, ...) is ambiguous for empty tensors.
        return x.clone()

    half = rows // 2
    # View as (..., 2, half, cols), reverse the pair axis, restore the shape.
    return x.reshape(-1, 2, half, cols).flip(dims=[1]).reshape(x.shape)
|
|
|
|
|
|
def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
    gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor, is_gated_activation: bool
):
    """Shuffle weights for FI TRT-LLM Format

    Both tensors are updated in place through ``.data``; nothing is
    returned. The expert index is dimension 0 of both tensors.

    Args:
        gemm1_weights: first-GEMM expert weights. When
            ``is_gated_activation`` is true each expert's rows are first
            reordered with ``reorder_rows_for_gated_act_gemm``.
        gemm2_weights: second-GEMM expert weights.
        is_gated_activation: whether the gated row reordering is applied to
            ``gemm1_weights``.
    """
    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a

    epilogue_tile_m = 128
    num_experts = gemm1_weights.shape[0]

    gemm1_weights_fp8_shuffled: list[torch.Tensor] = []
    gemm2_weights_fp8_shuffled: list[torch.Tensor] = []
    for i in range(num_experts):
        expert_gemm1 = gemm1_weights[i]
        if is_gated_activation:
            # Reorder rows of W1 for fused gated activation.
            expert_gemm1 = reorder_rows_for_gated_act_gemm(expert_gemm1)

        # Shuffle weights for transposed mma output. The shuffle is done on
        # a uint8 reinterpretation of the data. No intermediate stack or
        # reshape is needed: each expert keeps its own row count, so this
        # no longer assumes the row count is 2 * intermediate_size (which
        # is only true for gated activations).
        gemm1_weights_fp8_shuffled.append(
            shuffle_matrix_a(expert_gemm1.view(torch.uint8), epilogue_tile_m)
        )
        gemm2_weights_fp8_shuffled.append(
            shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), epilogue_tile_m)
        )

    # Stack weights for all experts and reinterpret the bytes as FP8 again.
    gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view(
        torch.float8_e4m3fn
    )
    gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view(
        torch.float8_e4m3fn
    )
|
|
|
|
|
|
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
    """Resolve the FlashInfer MoE backend from ``VLLM_FLASHINFER_MOE_BACKEND``.

    Recognized values are ``"throughput"`` (CUTLASS), ``"latency"``
    (TensorRT-LLM) and ``"masked_gemm"`` (CUTEDSL). ``"latency"`` is
    downgraded to CUTLASS, with a one-time info log, when the current
    platform is not in device-capability family 100. An unrecognized value
    yields CUTLASS when the platform reports device capability 90.

    Raises:
        ValueError: for an unrecognized value on any other platform.
    """
    backend_map = {
        "throughput": FlashinferMoeBackend.CUTLASS,
        "latency": FlashinferMoeBackend.TENSORRT_LLM,
        "masked_gemm": FlashinferMoeBackend.CUTEDSL,
    }
    requested = envs.VLLM_FLASHINFER_MOE_BACKEND

    # Unrecognized setting: CUTLASS on capability 90, otherwise an error.
    if requested not in backend_map:
        if current_platform.is_device_capability(90):
            return FlashinferMoeBackend.CUTLASS
        raise ValueError(
            f"Unknown flashinfer moe backend: {requested!r}. "
            f"Expected one of {list(backend_map.keys())}."
        )

    wants_trtllm = requested == "latency"
    if wants_trtllm and not current_platform.is_device_capability_family(100):
        logger.info_once(
            "Flashinfer TRTLLM MOE backend is only supported on "
            "SM100 and later, using CUTLASS backend instead",
            scope="local",
        )
        return FlashinferMoeBackend.CUTLASS

    return backend_map[requested]
|
|
|
|
|
|
def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
    """Return True when ``backend`` is listed as supporting "global sf".

    "sf" presumably stands for scale factor (not confirmed by this file).
    ``None`` is never in the list, so it yields False.
    """
    if backend is None:
        return False

    # TODO(shuw@nvidia): Update when new backends are added.
    supported_backends = (
        FlashinferMoeBackend.CUTLASS,
        FlashinferMoeBackend.TENSORRT_LLM,
        FlashinferMoeBackend.CUTEDSL,
    )
    return backend in supported_backends
|
|
|
|
|
|
def convert_moe_weights_to_flashinfer_trtllm_block_layout(
    cache_permute_indices: dict[torch.Size, torch.Tensor],
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert expert weights to FlashInfer's block layout.

    This reorders W13 and W2 into the expected epilogue-tiled block layout and
    returns the shuffled weight tensors.

    Args:
        cache_permute_indices: cache handed to FlashInfer's permute-index
            helpers, so indices can be reused between calls.
        w13_weight: bfloat16 W13 expert weights, expert index in dim 0.
        w2_weight: bfloat16 W2 expert weights, expert index in dim 0.

    Returns:
        ``(w13, w2)`` as new bfloat16 tensors stacked over experts. The
        inputs are not modified.

    Raises:
        ValueError: if either weight tensor is not bfloat16.
    """
    if w13_weight.dtype != torch.bfloat16 or w2_weight.dtype != torch.bfloat16:
        raise ValueError(
            "Unquantized Moe Backend FlashInfer TRTLLM requires bfloat16 weights"
        )

    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w3_w1_permute_indices,
        convert_to_block_layout,
        get_w2_permute_indices_with_cache,
    )

    epilogue_tile_m = 128
    block_k = 128

    # Reorder rows of W13 and W2 for fused gated activation and convert to the
    # block layout expected by the FlashInfer kernel.
    num_experts = w13_weight.shape[0]
    device_w13 = w13_weight.device
    device_w2 = w2_weight.device

    w13_weights_shuffled: list[torch.Tensor] = []
    w2_weights_shuffled: list[torch.Tensor] = []

    for i in range(num_experts):
        # The helpers work on a byte (uint8) reinterpretation of the weights.
        w13_bytes = w13_weight[i].view(torch.uint8)
        w13_indices = _maybe_get_cached_w3_w1_permute_indices(
            cache_permute_indices,
            w13_bytes,
            epilogue_tile_m,
        )
        # Indexing with an index tensor already produces a fresh tensor, so
        # no clone of the source is needed before permuting.
        w13_permuted = w13_bytes[w13_indices.to(device_w13)].contiguous()

        w2_bytes = w2_weight[i].view(torch.uint8)
        w2_indices = get_w2_permute_indices_with_cache(
            cache_permute_indices,
            w2_bytes,
            epilogue_tile_m,
        )
        w2_permuted = w2_bytes[w2_indices.to(device_w2)].contiguous()

        w13_weights_shuffled.append(
            convert_to_block_layout(w13_permuted, block_k).view(torch.bfloat16)
        )
        w2_weights_shuffled.append(
            convert_to_block_layout(w2_permuted, block_k).view(torch.bfloat16)
        )

    # Stack weights for all experts and return as BF16 tensors.
    w13_weights_shuffled_tensor = torch.stack(w13_weights_shuffled).contiguous()
    w2_weights_shuffled_tensor = torch.stack(w2_weights_shuffled).contiguous()

    return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor
|
|
|
|
|
|
def align_fp4_moe_weights_for_fi(
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    is_act_and_mul: bool,
    min_alignment: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Pad intermediate size so FlashInfer kernels' alignment constraints hold.

    Some FlashInfer FP4 MoE kernels require the intermediate size
    used for GEMM to be divisible by a small alignment value. When this is
    not satisfied (e.g. with certain tensor-parallel sizes), we pad the
    gate/up and down projection weights along the intermediate dim.

    Args:
        w13: packed FP4 gate/up weights; its last dim is ``hidden_size // 2``.
        w13_scale: scales for ``w13``; its last dim is ``hidden_size // 16``.
        w2: packed FP4 down-projection weights of shape
            ``(num_experts, hidden_size, intermediate // 2)``.
        w2_scale: scales for ``w2``; its last dim is ``intermediate // 16``.
        is_act_and_mul: True when ``w13`` holds two stacked halves along
            dim 1 (the layout ``swap_w13_to_w31`` splits in the middle).
        min_alignment: required divisor of the intermediate size.

    Returns:
        ``(w13, w13_scale, w2, w2_scale, intermediate)``. The inputs are
        returned unchanged when no padding is needed; otherwise new
        zero-padded tensors and the padded intermediate size.
    """
    # Current local intermediate size (per partition) is the K dimension of
    # the down projection.
    num_experts, hidden_size, intermediate = w2.shape
    intermediate *= 2  # because of packed FP4

    padded_intermediate = round_up(intermediate, min_alignment)

    if padded_intermediate == intermediate:
        return w13, w13_scale, w2, w2_scale, intermediate

    logger.info_once(
        "Padding intermediate size from %d to %d for up/down projection weights.",
        intermediate,
        padded_intermediate,
        scope="local",
    )

    up_mult = 2 if is_act_and_mul else 1
    padded_gate_up_dim = up_mult * padded_intermediate

    def _pad_gate_up_rows(src: torch.Tensor, num_cols: int) -> torch.Tensor:
        """Zero-pad dim 1 of ``src``, padding each stacked half on its own."""
        dst = src.new_zeros((num_experts, padded_gate_up_dim, num_cols))
        if src.shape[1] == up_mult * intermediate:
            # Each of the ``up_mult`` halves must start at a multiple of
            # ``padded_intermediate``. Appending all padding after the
            # concatenated halves would shift the boundary between them.
            dst.view(num_experts, up_mult, padded_intermediate, num_cols)[
                :, :, :intermediate, :
            ] = src.reshape(src.shape[0], up_mult, intermediate, src.shape[2])
        else:
            # Unexpected row count: keep the previous tail-padding behavior.
            dst[:, : src.shape[1], :] = src
        return dst

    # Pad w13 and w2 along its intermediate dimension.
    padded_w13 = _pad_gate_up_rows(w13, hidden_size // 2)

    padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate // 2))
    padded_w2[:, :, : w2.shape[2]] = w2

    padded_w13_scale = _pad_gate_up_rows(w13_scale, hidden_size // 16)

    padded_w2_scale = w2_scale.new_zeros(
        (num_experts, hidden_size, padded_intermediate // 16)
    )
    padded_w2_scale[:, :, : w2_scale.shape[2]] = w2_scale

    return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate
|
|
|
|
|
|
def align_fp8_moe_weights_for_fi(
    w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool, min_alignment: int = 16
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pad intermediate size so FlashInfer kernels' alignment constraints hold.

    Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
    used for GEMM to be divisible by a small alignment value. When this is
    not satisfied (e.g. with certain tensor-parallel sizes), we pad the
    gate/up and down projection weights along the intermediate dim.

    Args:
        w13: gate/up weights; its last dim is ``hidden_size``.
        w2: down-projection weights of shape
            ``(num_experts, hidden_size, intermediate)``.
        is_act_and_mul: True when ``w13`` holds two stacked halves along
            dim 1 (the layout ``swap_w13_to_w31`` splits in the middle).
        min_alignment: required divisor of the intermediate size.

    Returns:
        ``(w13, w2, intermediate)``. The inputs are returned unchanged when
        no padding is needed; otherwise new zero-padded tensors and the
        padded intermediate size.
    """
    # Current local intermediate size (per partition) is the K dimension of
    # the down projection.
    num_experts, hidden_size, intermediate = w2.shape

    padded_intermediate = round_up(intermediate, min_alignment)

    if padded_intermediate == intermediate:
        return w13, w2, intermediate

    logger.info_once(
        "Padding intermediate size from %d to %d for up/down projection weights.",
        intermediate,
        padded_intermediate,
        scope="local",
    )

    up_mult = 2 if is_act_and_mul else 1
    padded_gate_up_dim = up_mult * padded_intermediate

    # Pad w13 and w2 along its intermediate dimension.
    padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
    if w13.shape[1] == up_mult * intermediate:
        # Each of the ``up_mult`` halves must start at a multiple of
        # ``padded_intermediate``. Appending all padding after the
        # concatenated halves would shift the boundary between them.
        padded_w13.view(num_experts, up_mult, padded_intermediate, hidden_size)[
            :, :, :intermediate, :
        ] = w13.reshape(w13.shape[0], up_mult, intermediate, w13.shape[2])
    else:
        # Unexpected row count: keep the previous tail-padding behavior.
        padded_w13[:, : w13.shape[1], :] = w13

    padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
    padded_w2[:, :, :intermediate] = w2

    return padded_w13, padded_w2, padded_intermediate
|
|
|
|
|
|
def _shuffle_deepseek_fp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Preprocess DeepSeek FP8 block-scale weights for the FlashInfer TRT-LLM
    kernel using the shuffle + BlockMajorK layout variant.

    Returns 4D weight tensors in BlockMajorK layout
    (E, K/block_k, Mn, block_k)
    """
    from flashinfer import shuffle_matrix_a
    from flashinfer.fused_moe import convert_to_block_layout

    epilogue_tile_m = 64
    block_k = 128
    num_experts = w13.shape[0]

    def _empty_block_major(weight: torch.Tensor) -> torch.Tensor:
        # Uninitialized uint8 buffer of shape (E, K / block_k, M, block_k).
        rows, k_dim = weight.shape[1], weight.shape[2]
        return torch.empty(
            num_experts,
            k_dim // block_k,
            rows,
            block_k,
            dtype=torch.uint8,
            device=weight.device,
        )

    def _to_block_major(expert_weight: torch.Tensor) -> torch.Tensor:
        # Shuffle the raw bytes, then convert to the block layout.
        shuffled = shuffle_matrix_a(expert_weight.view(torch.uint8), epilogue_tile_m)
        return convert_to_block_layout(shuffled, block_k)

    w13_out = _empty_block_major(w13)
    w2_out = _empty_block_major(w2)

    for expert in range(num_experts):
        w13_out[expert] = _to_block_major(w13[expert])
        w2_out[expert] = _to_block_major(w2[expert])

    # Reinterpret the filled byte buffers as FP8.
    return w13_out.view(torch.float8_e4m3fn), w2_out.view(torch.float8_e4m3fn)
|
|
|
|
|
|
def _shuffle_mxfp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    is_gated: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Preprocess MXFP8 weights and scales for the FlashInfer TRT-LLM kernel.

    Following flashinfer/tests/moe/test_trtllm_gen_fused_moe.py:
    1. reorder_rows_for_gated_act_gemm (interleave gate/up rows)
    2. shuffle_matrix_a (weight data layout shuffle)
    3. shuffle_matrix_sf_a (scale factor layout shuffle)

    Step 1 is applied only when ``is_gated`` is true. The returned scales
    are reshaped back to the shapes of the incoming scale tensors.
    """
    from flashinfer import (
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
    )

    epilogue_tile_m = 128
    num_experts = w13.shape[0]
    hidden_size = w13.shape[2]
    # Row count used for the 2-D views of W13 and its scales.
    w13_rows = 2 * (w13.shape[1] // 2)

    def _w13_for_expert(expert: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Interleave the rows of weight and scale for gated activations;
        # otherwise pass the expert slices through untouched.
        if not is_gated:
            return w13[expert], w13_scale[expert]
        weight = reorder_rows_for_gated_act_gemm(w13[expert].reshape(w13_rows, -1))
        scale = reorder_rows_for_gated_act_gemm(
            w13_scale[expert].reshape(w13_rows, -1)
        )
        return weight, scale

    w13_parts: list[torch.Tensor] = []
    w2_parts: list[torch.Tensor] = []
    w13_scale_parts: list[torch.Tensor] = []
    w2_scale_parts: list[torch.Tensor] = []

    for expert in range(num_experts):
        expert_w13, expert_w13_scale = _w13_for_expert(expert)

        w13_parts.append(shuffle_matrix_a(expert_w13.view(torch.uint8), epilogue_tile_m))
        w2_parts.append(shuffle_matrix_a(w2[expert].view(torch.uint8), epilogue_tile_m))

        w13_scale_2d = expert_w13_scale.view(torch.uint8).reshape(w13_rows, -1)
        w13_scale_parts.append(shuffle_matrix_sf_a(w13_scale_2d, epilogue_tile_m))

        w2_scale_2d = w2_scale[expert].view(torch.uint8).reshape(hidden_size, -1)
        w2_scale_parts.append(shuffle_matrix_sf_a(w2_scale_2d, epilogue_tile_m))

    return (
        torch.stack(w13_parts).view(torch.float8_e4m3fn),
        torch.stack(w2_parts).view(torch.float8_e4m3fn),
        torch.stack(w13_scale_parts).reshape(w13_scale.shape),
        torch.stack(w2_scale_parts).reshape(w2_scale.shape),
    )
|
|
|
|
|
|
def prepare_fp8_moe_layer_for_fi(
|
|
layer: torch.nn.Module,
|
|
w13: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
w13_scale: torch.Tensor,
|
|
w13_input_scale: torch.Tensor | None,
|
|
w2_scale: torch.Tensor,
|
|
w2_input_scale: torch.Tensor | None,
|
|
is_trtllm: bool = False,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Convert Fp8 MoE weights to flashinfer kernel format
|
|
|
|
Note that for trtllm we update the model state dict
|
|
with the scale format needed for these kernels.
|
|
|
|
Note that for per-tensor, we update the layer's
|
|
intermediate size if the weights needed padding.
|
|
"""
|
|
|
|
assert hasattr(layer.moe_config, "is_act_and_mul")
|
|
block_quant = (
|
|
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
|
|
)
|
|
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
|
|
is_deepseek_fp8 = block_quant and not is_mxfp8
|
|
is_gated = layer.activation.is_gated
|
|
|
|
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
|
|
if is_mxfp8 and is_trtllm:
|
|
# FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores
|
|
# [gate; up]. Swap both weights and scales before interleaving.
|
|
if layer.moe_config.is_act_and_mul:
|
|
w13 = swap_w13_to_w31(w13)
|
|
# Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight;
|
|
# reshape to 3D so swap_w13_to_w31 can flip the two halves,
|
|
# then flatten back.
|
|
if w13_scale.ndim == 2:
|
|
num_rows = w13.shape[1] # 2 * intermediate_size
|
|
w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1)
|
|
w13_scale = swap_w13_to_w31(w13_scale)
|
|
w13_scale = w13_scale.reshape(w13_scale.shape[0], -1)
|
|
else:
|
|
w13_scale = swap_w13_to_w31(w13_scale)
|
|
|
|
w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights(
|
|
w13, w2, w13_scale, w2_scale, is_gated
|
|
)
|
|
return w13, w2, w13_scale, w2_scale
|
|
|
|
# Some FI MoE kernels require internal alignment of 16
|
|
# for the gate-up proj. Pad the weights to respect this.
|
|
if not block_quant:
|
|
min_alignment = 16 if is_gated else 128
|
|
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
|
|
w13,
|
|
w2,
|
|
layer.moe_config.is_act_and_mul,
|
|
min_alignment,
|
|
)
|
|
layer.moe_config.intermediate_size_per_partition = new_intermediate
|
|
|
|
# FI kernels require W31 layout rather than W13.
|
|
if layer.moe_config.is_act_and_mul:
|
|
w13 = swap_w13_to_w31(w13)
|
|
if block_quant:
|
|
w13_scale = swap_w13_to_w31(w13_scale)
|
|
|
|
# DeepSeekFp8 TRT-LLM: shuffle weights into BlockMajorK layout.
|
|
if is_deepseek_fp8 and is_trtllm:
|
|
w13, w2 = _shuffle_deepseek_fp8_moe_weights(w13, w2)
|
|
|
|
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
|
# and registration of alpha scales.
|
|
if is_trtllm and not block_quant:
|
|
assert w13_input_scale is not None
|
|
assert w2_input_scale is not None
|
|
|
|
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
|
|
|
|
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
|
|
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused
|
|
# experts. The CUTLASS kernel doesn't handle these correctly on Hopper
|
|
# (SM 9.0), producing NaN instead of near-zero output. Clamping to a
|
|
# small minimum prevents this without affecting model accuracy since
|
|
# these experts' effective weights are already zero.
|
|
if block_quant:
|
|
_FI_CUTLASS_MIN_BLOCK_SCALE = 1e-10
|
|
w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
|
|
w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
|
|
|
|
return w13, w2, w13_scale, w2_scale
|