[Perf] Change Trtllm fp8 MoE to use Shuffled Weights and BlockMajorK Layout (#38993)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Wei Zhao
2026-04-05 10:54:31 -04:00
committed by GitHub
parent 228023b3a5
commit 1af6f78ae5
3 changed files with 72 additions and 13 deletions

View File

@@ -305,6 +305,39 @@ def align_fp8_moe_weights_for_fi(
return padded_w13, padded_w2, padded_intermediate
def _shuffle_deepseek_fp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
    *,
    epilogue_tile_m: int = 64,
    block_k: int = 128,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Preprocess DeepSeek FP8 block-scale weights for the FlashInfer TRT-LLM
    kernel using the shuffle + BlockMajorK layout variant.

    Args:
        w13: Stacked gate/up projection weights, one slice per expert
            (leading dim E). Expected dtype float8_e4m3fn — shuffled as
            raw bytes via a uint8 view.
        w2: Stacked down projection weights, one slice per expert.
        epilogue_tile_m: Row-tile size used by ``shuffle_matrix_a``
            (keyword-only; defaults to the TRT-LLM kernel's tile of 64).
        block_k: K-blocking factor for the BlockMajorK layout
            (keyword-only; defaults to 128).

    Returns:
        4D weight tensors in BlockMajorK layout
        (E, K/block_k, Mn, block_k), viewed back as float8_e4m3fn.
    """
    # Imported lazily so this module stays importable without flashinfer.
    from flashinfer import shuffle_matrix_a
    from flashinfer.fused_moe import convert_to_block_layout

    def _shuffle_one(expert_weight: torch.Tensor) -> torch.Tensor:
        # shuffle_matrix_a operates on raw bytes, hence the uint8 view.
        shuffled = shuffle_matrix_a(expert_weight.view(torch.uint8), epilogue_tile_m)
        return convert_to_block_layout(shuffled, block_k)

    num_experts = w13.shape[0]
    w13_out = torch.stack([_shuffle_one(w13[i]) for i in range(num_experts)])
    w2_out = torch.stack([_shuffle_one(w2[i]) for i in range(num_experts)])
    return w13_out.view(torch.float8_e4m3fn), w2_out.view(torch.float8_e4m3fn)
def _shuffle_mxfp8_moe_weights(
w13: torch.Tensor,
w2: torch.Tensor,
@@ -405,6 +438,7 @@ def prepare_fp8_moe_layer_for_fi(
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
is_deepseek_fp8 = block_quant and not is_mxfp8
is_gated = layer.activation.is_gated
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
@@ -447,6 +481,10 @@ def prepare_fp8_moe_layer_for_fi(
if block_quant:
w13_scale = swap_w13_to_w31(w13_scale)
# DeepSeekFp8 TRT-LLM: shuffle weights into BlockMajorK layout.
if is_deepseek_fp8 and is_trtllm:
w13, w2 = _shuffle_deepseek_fp8_moe_weights(w13, w2)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales.
if is_trtllm and not block_quant: