[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
def rotate_flashinfer_fp8_moe_weights(
|
||||
def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
|
||||
):
|
||||
"""Shuffle weights for for FI TRT-LLM Format"""
|
||||
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
|
||||
|
||||
epilogue_tile_m = 128
|
||||
@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights(
|
||||
|
||||
def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
w13_weight_scale: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_weight_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
) -> None:
|
||||
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w13_scale=w13_weight_scale,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_weight_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
layer.w2_input_scale_inv = 1.0 / w2_input_scale
|
||||
@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer.output2_scales_scalar = g2_alphas
|
||||
|
||||
|
||||
def apply_flashinfer_per_tensor_scale_fp8(
|
||||
def apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
|
||||
assert (
|
||||
hasattr(layer, "output1_scales_scalar")
|
||||
and hasattr(layer, "output1_scales_gate_scalar")
|
||||
and hasattr(layer, "output2_scales_scalar")
|
||||
)
|
||||
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
|
||||
assert (
|
||||
hasattr(layer, "output1_scales_scalar")
|
||||
and hasattr(layer, "output1_scales_gate_scalar")
|
||||
and hasattr(layer, "output2_scales_scalar")
|
||||
)
|
||||
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
|
||||
is_llama4 = layer.custom_routing_function == Llama4MoE.custom_routing_function
|
||||
assert is_llama4, "FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=hidden_states,
|
||||
@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl(
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_cutlass_moe_fp8(
|
||||
hidden_states: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
moe: FusedMoEConfig | None = None,
|
||||
) -> torch.Tensor:
|
||||
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
||||
assert quant_config is not None
|
||||
|
||||
# Construct modular kernel with block-scale support when requested.
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||
),
|
||||
select_cutlass_fp8_gemm_impl(
|
||||
moe=moe,
|
||||
quant_config=quant_config,
|
||||
out_dtype=hidden_states.dtype,
|
||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||
),
|
||||
moe_parallel_config=layer.moe_parallel_config,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
backend_map = {
|
||||
"throughput": FlashinferMoeBackend.CUTLASS,
|
||||
@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
|
||||
FlashinferMoeBackend.TENSORRT_LLM,
|
||||
)
|
||||
return backend in backends_supporting_global_sf
|
||||
|
||||
|
||||
def align_fp8_moe_weights_for_fi(
|
||||
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
|
||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
|
||||
|
||||
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
|
||||
used for GEMM to be divisible by a small alignment value. When this is
|
||||
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
|
||||
gate/up and down projection weights along the intermediate dim.
|
||||
"""
|
||||
|
||||
# Current local intermediate size (per partition) is the K dimension of
|
||||
# the down projection.
|
||||
num_experts, hidden_size, intermediate = w2.shape
|
||||
|
||||
min_alignment = 16
|
||||
padded_intermediate = round_up(intermediate, min_alignment)
|
||||
|
||||
if padded_intermediate == intermediate:
|
||||
return w13, w2, intermediate
|
||||
|
||||
logger.info_once(
|
||||
"Padding intermediate size from %d to %d for up/down projection weights.",
|
||||
intermediate,
|
||||
padded_intermediate,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
up_mult = 2 if is_act_and_mul else 1
|
||||
padded_gate_up_dim = up_mult * padded_intermediate
|
||||
|
||||
# Pad w13 and w2 along its intermediate dimension.
|
||||
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
|
||||
padded_w13[:, : w13.shape[1], :] = w13
|
||||
|
||||
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
|
||||
padded_w2[:, :, :intermediate] = w2
|
||||
|
||||
return padded_w13, padded_w2, padded_intermediate
|
||||
|
||||
|
||||
def prepare_fp8_moe_layer_for_fi(
|
||||
layer: torch.nn.Module,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor | None,
|
||||
is_trtllm: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert Fp8 MoE weights to flashinfer kernel format
|
||||
|
||||
Note that for trtllm we update the model state dict
|
||||
with the scale format needed for these kernels.
|
||||
|
||||
Note that for per-tensor, we update the layer's
|
||||
intermediate size if the weights needed padding.
|
||||
"""
|
||||
|
||||
assert hasattr(layer.moe_config, "is_act_and_mul")
|
||||
block_quant = (
|
||||
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
|
||||
)
|
||||
|
||||
# Some FI MoE kernels require internal alignment of 16
|
||||
# for the gate-up proj. Pad the weights to respect this.
|
||||
if not block_quant:
|
||||
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
|
||||
w13,
|
||||
w2,
|
||||
layer.moe_config.is_act_and_mul,
|
||||
)
|
||||
layer.intermediate_size_per_partition = new_intermediate
|
||||
|
||||
# FI kernels require W31 layout rather than W13.
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
w13 = swap_w13_to_w31(w13)
|
||||
if block_quant:
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
||||
# and registration of alpha scales. Note that we do not register
|
||||
# as nn.Parameters since they are not needed for weight-reloading.
|
||||
if is_trtllm and not block_quant:
|
||||
assert w13_input_scale is not None
|
||||
assert w2_input_scale is not None
|
||||
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
return w13, w2, w13_scale
|
||||
|
||||
Reference in New Issue
Block a user