[Perf] Change Trtllm fp8 MoE to use Shuffled Weights and BlockMajorK Layout (#38993)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -305,6 +305,39 @@ def align_fp8_moe_weights_for_fi(
|
||||
return padded_w13, padded_w2, padded_intermediate
|
||||
|
||||
|
||||
def _shuffle_deepseek_fp8_moe_weights(
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Preprocess DeepSeek FP8 block-scale weights for the FlashInfer TRT-LLM
|
||||
kernel using the shuffle + BlockMajorK layout variant.
|
||||
|
||||
Returns 4D weight tensors in BlockMajorK layout
|
||||
(E, K/block_k, Mn, block_k)
|
||||
"""
|
||||
from flashinfer import shuffle_matrix_a
|
||||
from flashinfer.fused_moe import convert_to_block_layout
|
||||
|
||||
epilogue_tile_m = 64
|
||||
block_k = 128
|
||||
num_experts = w13.shape[0]
|
||||
|
||||
w13_shuffled: list[torch.Tensor] = []
|
||||
w2_shuffled: list[torch.Tensor] = []
|
||||
for i in range(num_experts):
|
||||
t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
|
||||
t13 = convert_to_block_layout(t13, block_k)
|
||||
w13_shuffled.append(t13)
|
||||
|
||||
t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
|
||||
t2 = convert_to_block_layout(t2, block_k)
|
||||
w2_shuffled.append(t2)
|
||||
|
||||
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
|
||||
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
|
||||
return w13_out, w2_out
|
||||
|
||||
|
||||
def _shuffle_mxfp8_moe_weights(
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@@ -405,6 +438,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
|
||||
)
|
||||
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
|
||||
is_deepseek_fp8 = block_quant and not is_mxfp8
|
||||
is_gated = layer.activation.is_gated
|
||||
|
||||
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
|
||||
@@ -447,6 +481,10 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
if block_quant:
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
# DeepSeekFp8 TRT-LLM: shuffle weights into BlockMajorK layout.
|
||||
if is_deepseek_fp8 and is_trtllm:
|
||||
w13, w2 = _shuffle_deepseek_fp8_moe_weights(w13, w2)
|
||||
|
||||
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
||||
# and registration of alpha scales.
|
||||
if is_trtllm and not block_quant:
|
||||
|
||||
Reference in New Issue
Block a user