[Perf] Change Trtllm fp8 MoE to use Shuffled Weights and BlockMajorK Layout (#38993)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -112,6 +112,24 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""Override to handle 4D BlockMajorK weights (E, K/bk, Mn, bk)."""
|
||||
if w1.dim() == 4:
|
||||
# BlockMajorK: (E, K/bk, Mn, bk)
|
||||
E = w1.shape[0]
|
||||
N = w1.shape[2]
|
||||
K = a1.size(-1)
|
||||
M = a1.size(0) if a1.dim() == 2 else a1.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
return E, M, N, K, topk
|
||||
return super().moe_problem_size(a1, w1, w2, topk_ids)
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
@@ -152,7 +170,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe import Fp8QuantizationType
|
||||
from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
|
||||
|
||||
# Pack topk ids and weights into format expected by the kernel.
|
||||
packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
|
||||
@@ -170,10 +188,12 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
if is_mxfp8:
|
||||
fp8_quant_type = Fp8QuantizationType.MxFp8
|
||||
use_shuffled_weight = True
|
||||
weight_layout = WeightLayout.MajorK
|
||||
hidden_states_scale = a1q_scale
|
||||
else:
|
||||
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
|
||||
use_shuffled_weight = False
|
||||
use_shuffled_weight = True
|
||||
weight_layout = WeightLayout.BlockMajorK
|
||||
hidden_states_scale = a1q_scale.t().contiguous()
|
||||
|
||||
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
|
||||
@@ -199,7 +219,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
use_shuffled_weight=use_shuffled_weight,
|
||||
weight_layout=0,
|
||||
weight_layout=weight_layout,
|
||||
fp8_quantization_type=fp8_quant_type,
|
||||
# output=output,
|
||||
)
|
||||
@@ -322,7 +342,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe import Fp8QuantizationType
|
||||
from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
|
||||
|
||||
assert not apply_router_weight_on_input
|
||||
assert activation == MoEActivation.SILU
|
||||
@@ -342,10 +362,12 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
if is_mxfp8:
|
||||
fp8_quant_type = Fp8QuantizationType.MxFp8
|
||||
use_shuffled_weight = True
|
||||
weight_layout = WeightLayout.MajorK
|
||||
hidden_states_scale = a1q_scale
|
||||
else:
|
||||
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
|
||||
use_shuffled_weight = False
|
||||
use_shuffled_weight = True
|
||||
weight_layout = WeightLayout.BlockMajorK
|
||||
hidden_states_scale = a1q_scale.t().contiguous()
|
||||
|
||||
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
|
||||
@@ -367,6 +389,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
use_shuffled_weight=use_shuffled_weight,
|
||||
weight_layout=weight_layout,
|
||||
fp8_quantization_type=fp8_quant_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -305,6 +305,39 @@ def align_fp8_moe_weights_for_fi(
|
||||
return padded_w13, padded_w2, padded_intermediate
|
||||
|
||||
|
||||
def _shuffle_deepseek_fp8_moe_weights(
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Preprocess DeepSeek FP8 block-scale weights for the FlashInfer TRT-LLM
|
||||
kernel using the shuffle + BlockMajorK layout variant.
|
||||
|
||||
Returns 4D weight tensors in BlockMajorK layout
|
||||
(E, K/block_k, Mn, block_k)
|
||||
"""
|
||||
from flashinfer import shuffle_matrix_a
|
||||
from flashinfer.fused_moe import convert_to_block_layout
|
||||
|
||||
epilogue_tile_m = 64
|
||||
block_k = 128
|
||||
num_experts = w13.shape[0]
|
||||
|
||||
w13_shuffled: list[torch.Tensor] = []
|
||||
w2_shuffled: list[torch.Tensor] = []
|
||||
for i in range(num_experts):
|
||||
t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
|
||||
t13 = convert_to_block_layout(t13, block_k)
|
||||
w13_shuffled.append(t13)
|
||||
|
||||
t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
|
||||
t2 = convert_to_block_layout(t2, block_k)
|
||||
w2_shuffled.append(t2)
|
||||
|
||||
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
|
||||
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
|
||||
return w13_out, w2_out
|
||||
|
||||
|
||||
def _shuffle_mxfp8_moe_weights(
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@@ -405,6 +438,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
|
||||
)
|
||||
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
|
||||
is_deepseek_fp8 = block_quant and not is_mxfp8
|
||||
is_gated = layer.activation.is_gated
|
||||
|
||||
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
|
||||
@@ -447,6 +481,10 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
if block_quant:
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
# DeepSeekFp8 TRT-LLM: shuffle weights into BlockMajorK layout.
|
||||
if is_deepseek_fp8 and is_trtllm:
|
||||
w13, w2 = _shuffle_deepseek_fp8_moe_weights(w13, w2)
|
||||
|
||||
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
||||
# and registration of alpha scales.
|
||||
if is_trtllm and not block_quant:
|
||||
|
||||
@@ -13,7 +13,7 @@ import vllm.envs as envs
|
||||
from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
@@ -168,14 +168,12 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
):
|
||||
return False
|
||||
|
||||
if not isinstance(module.quant_method, FusedMoEModularMethod):
|
||||
# modular kernels could invoke deep_gemm_moe_fp8
|
||||
return True
|
||||
moe_kernel = getattr(module.quant_method, "moe_kernel", None)
|
||||
if moe_kernel is None:
|
||||
return False
|
||||
|
||||
# Further check if the ModularKernel implementation uses the DeepGemmExperts
|
||||
return isinstance(
|
||||
module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
|
||||
)
|
||||
fused_experts = moe_kernel.impl.fused_experts
|
||||
return isinstance(fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
|
||||
|
||||
|
||||
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
|
||||
|
||||
Reference in New Issue
Block a user