[Quant][Feature] Support online MXFP8 quantization for MoE and dense models (#35448)
Signed-off-by: EdalatiAli <aliedalati@cohere.com>
This commit is contained in:
@@ -305,6 +305,81 @@ def align_fp8_moe_weights_for_fi(
|
||||
return padded_w13, padded_w2, padded_intermediate
|
||||
|
||||
|
||||
def _shuffle_mxfp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    is_gated: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rearrange MXFP8 MoE weights and scales into the layout expected by
    the FlashInfer TRT-LLM fused-MoE kernel.

    Mirrors flashinfer/tests/moe/test_trtllm_gen_fused_moe.py:
      1. reorder_rows_for_gated_act_gemm -- interleave gate/up rows (gated only)
      2. shuffle_matrix_a                -- shuffle the weight data layout
      3. shuffle_matrix_sf_a             -- shuffle the scale-factor layout

    Returns the shuffled (w13, w2, w13_scale, w2_scale); weights come back
    viewed as float8_e4m3fn, scales keep their original shapes.
    """
    from flashinfer import (
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
    )

    epilogue_tile_m = 128
    # w13 is assumed [num_experts, 2 * intermediate_size, hidden_size]
    # (same assumption the shape indexing here encodes) -- confirm at caller.
    num_experts, two_intermediate, hidden_size = w13.shape
    intermediate_size = two_intermediate // 2

    # Phase 1: per expert, interleave gate/up rows of both the weight and
    # its scale factors. For non-gated activations this is an identity pass.
    gathered_w13: list[torch.Tensor] = []
    gathered_w13_scale: list[torch.Tensor] = []
    if is_gated:
        for expert_w, expert_scale in zip(w13, w13_scale):
            gathered_w13.append(
                reorder_rows_for_gated_act_gemm(
                    expert_w.reshape(2 * intermediate_size, -1)
                )
            )
            gathered_w13_scale.append(
                reorder_rows_for_gated_act_gemm(
                    expert_scale.reshape(2 * intermediate_size, -1)
                )
            )
    else:
        for expert_w, expert_scale in zip(w13, w13_scale):
            gathered_w13.append(expert_w)
            gathered_w13_scale.append(expert_scale)

    # Phase 2: per expert, shuffle the data layout of weights (as raw uint8
    # bytes) and the scale-factor layout for the TRT-LLM epilogue tiling.
    shuffled_w13: list[torch.Tensor] = []
    shuffled_w2: list[torch.Tensor] = []
    shuffled_w13_scale: list[torch.Tensor] = []
    shuffled_w2_scale: list[torch.Tensor] = []
    for inter_w, expert_w2, inter_scale, expert_w2_scale in zip(
        gathered_w13, w2, gathered_w13_scale, w2_scale
    ):
        shuffled_w13.append(
            shuffle_matrix_a(inter_w.view(torch.uint8), epilogue_tile_m)
        )
        shuffled_w2.append(
            shuffle_matrix_a(expert_w2.view(torch.uint8), epilogue_tile_m)
        )
        shuffled_w13_scale.append(
            shuffle_matrix_sf_a(
                inter_scale.view(torch.uint8).reshape(2 * intermediate_size, -1),
                epilogue_tile_m,
            )
        )
        shuffled_w2_scale.append(
            shuffle_matrix_sf_a(
                expert_w2_scale.view(torch.uint8).reshape(hidden_size, -1),
                epilogue_tile_m,
            )
        )

    # Re-stack per-expert results; reinterpret weight bytes as FP8 and
    # restore the scales' original shapes.
    return (
        torch.stack(shuffled_w13).view(torch.float8_e4m3fn),
        torch.stack(shuffled_w2).view(torch.float8_e4m3fn),
        torch.stack(shuffled_w13_scale).reshape(w13_scale.shape),
        torch.stack(shuffled_w2_scale).reshape(w2_scale.shape),
    )
|
||||
|
||||
|
||||
def prepare_fp8_moe_layer_for_fi(
|
||||
layer: torch.nn.Module,
|
||||
w13: torch.Tensor,
|
||||
@@ -314,7 +389,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor | None,
|
||||
is_trtllm: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert Fp8 MoE weights to flashinfer kernel format
|
||||
|
||||
@@ -329,10 +404,33 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
block_quant = (
|
||||
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
|
||||
)
|
||||
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
|
||||
is_gated = layer.activation.is_gated
|
||||
|
||||
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
|
||||
if is_mxfp8 and is_trtllm:
|
||||
# FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores
|
||||
# [gate; up]. Swap both weights and scales before interleaving.
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
w13 = swap_w13_to_w31(w13)
|
||||
# Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight;
|
||||
# reshape to 3D so swap_w13_to_w31 can flip the two halves,
|
||||
# then flatten back.
|
||||
if w13_scale.ndim == 2:
|
||||
num_rows = w13.shape[1] # 2 * intermediate_size
|
||||
w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1)
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
w13_scale = w13_scale.reshape(w13_scale.shape[0], -1)
|
||||
else:
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights(
|
||||
w13, w2, w13_scale, w2_scale, is_gated
|
||||
)
|
||||
return w13, w2, w13_scale, w2_scale
|
||||
|
||||
# Some FI MoE kernels require internal alignment of 16
|
||||
# for the gate-up proj. Pad the weights to respect this.
|
||||
is_gated = layer.activation.is_gated
|
||||
if not block_quant:
|
||||
min_alignment = 16 if is_gated else 128
|
||||
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
|
||||
@@ -369,4 +467,4 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
|
||||
w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
|
||||
|
||||
return w13, w2, w13_scale
|
||||
return w13, w2, w13_scale, w2_scale
|
||||
|
||||
Reference in New Issue
Block a user