[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    make_fp8_moe_alpha_scales_for_fi,
    register_scales_for_trtllm_fp8_per_tensor_moe,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
@@ -774,6 +775,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                f"activation function, but got {layer.activation}."
            )
        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if self.flashinfer_moe_backend is not None and dynamic_per_token:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )

    def create_weights(
        self,
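As context for the new check above, here is a minimal self-contained sketch of the same guard, with plain arguments standing in for self.block_quant, self.quant_config.activation_scheme, and self.flashinfer_moe_backend (the function name and signature below are illustrative, not vLLM API):

def check_flashinfer_fp8_activation_quant(
    block_quant: bool,
    activation_scheme: str,  # "static" or "dynamic"
    flashinfer_moe_backend: str | None,
) -> None:
    # Mirrors the condition added above: per-token (dynamic) activation
    # quantization only applies when there is no block quantization and no
    # static scheme, and the FlashInfer FP8 MoE backends reject that case.
    dynamic_per_token = not block_quant and activation_scheme != "static"
    if flashinfer_moe_backend is not None and dynamic_per_token:
        raise NotImplementedError(
            "FlashInfer FP8 MoE backend does not support dynamic per token "
            "activation quantization."
        )

# Rejected:  check_flashinfer_fp8_activation_quant(False, "dynamic", "cutlass")
# Accepted:  check_flashinfer_fp8_activation_quant(False, "static", "cutlass")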
@@ -905,6 +914,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        w2_weight: torch.Tensor,
        w13_weight_scale: torch.Tensor,
        w2_weight_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            assert self.block_quant
@@ -949,11 +960,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            if self.block_quant:
                w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
            else:
                # TODO(rob): this function is a hack that renames the scaling
                # factors in the Module. This is a hack we should clean up.
                register_moe_scaling_factors(layer)
                if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
                    register_scales_for_trtllm_fp8_per_tensor_moe(
                        layer=layer,
                        w13_weight_scale=w13_weight,
                        w13_input_scale=w13_input_scale,
                        w2_weight_scale=w2_weight,
                        w2_input_scale=w2_input_scale,
                    )

        elif self.fp8_backend == Fp8MoeBackend.AITER:
            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                w13_weight, w2_weight
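For readers unfamiliar with the register_* helpers above: "registering" scales usually means attaching the precomputed tensors to the torch.nn.Module so the kernel can read them at forward time. The sketch below shows only that generic PyTorch pattern; it is not vLLM's register_scales_for_trtllm_fp8_per_tensor_moe, and the helper name and dict layout are invented for illustration:

import torch

def attach_scales(layer: torch.nn.Module, scales: dict[str, torch.Tensor | None]) -> None:
    # Store each precomputed per-tensor scale as a non-persistent buffer on
    # the layer; None entries (e.g. missing input scales) are skipped.
    for name, scale in scales.items():
        if scale is None:
            continue
        layer.register_buffer(name, scale.to(torch.float32), persistent=False)

# Hypothetical usage mirroring the keyword arguments passed above:
# attach_scales(layer, {
#     "w13_weight_scale": w13_weight_scale,
#     "w13_input_scale": w13_input_scale,
#     "w2_weight_scale": w2_weight_scale,
#     "w2_input_scale": w2_input_scale,
# })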
@@ -990,6 +1006,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                AiterExperts,
            )

        # Flashinfer TRTLLM does not use the modular kernel abstraction.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return

        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None
        self.use_inplace = True
@@ -1087,7 +1107,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
            layer=layer,
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Setup modular kernel for TP case.
@@ -1182,6 +1208,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        # MARLIN uses mixed precision W8A16 config.
        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            return fp8_w8a16_moe_quant_config(
                w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
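Taken together with the next hunk, the selection logic in get_fused_moe_quant_config reduces to a small per-backend dispatch. A condensed, self-contained sketch follows; the Backend enum and the string return values are stand-ins for illustration, not vLLM types:

from enum import Enum, auto

class Backend(Enum):  # stand-in for Fp8MoeBackend
    FLASHINFER_TRTLLM = auto()
    FLASHINFER_CUTLASS = auto()
    MARLIN = auto()
    OTHER = auto()

def quant_config_kind(backend: Backend, block_quant: bool) -> str | None:
    if backend is Backend.FLASHINFER_TRTLLM:
        # TRTLLM bypasses the modular kernel, so no quant config is built.
        return None
    if backend is Backend.MARLIN:
        # Marlin runs fp8 weights against 16-bit activations (W8A16).
        return "w8a16"
    if backend is Backend.FLASHINFER_CUTLASS and not block_quant:
        # Per-tensor CUTLASS additionally folds alpha = w_scale * a_scale.
        return "w8a8 + alpha scales"
    return "w8a8"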
@@ -1189,11 +1220,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                block_shape=self.weight_block_size,
            )

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        # Flashinfer CUTLASS per-tensor uses single dq scale
        # (alpha = w_scale * a_scale) and inverse a2 scale.
        if (
            self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
            and not self.block_quant
        ):
            g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
                w1_scale,
                a1_scale,
                w2_scale,
                a2_scale,
            )
            return fp8_w8a8_moe_quant_config(
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=(1.0 / a2_scale),
                g1_alphas=g1_alphas,
                g2_alphas=g2_alphas,
            )

        # All other backends use normal config.
        return fp8_w8a8_moe_quant_config(
            w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
            w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )
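The comment in this hunk pins down the per-tensor math for the FlashInfer CUTLASS path: each grouped GEMM gets a single folded dequant scale. A minimal sketch of that folding, assuming per-tensor (scalar) scale tensors; the real make_fp8_moe_alpha_scales_for_fi may differ in dtype and shape handling:

import torch

def make_alpha_scales_sketch(
    w1_scale: torch.Tensor,  # per-tensor weight scale of the fused w13 GEMM
    a1_scale: torch.Tensor,  # per-tensor input scale of the first GEMM
    w2_scale: torch.Tensor,  # per-tensor weight scale of the w2 GEMM
    a2_scale: torch.Tensor,  # per-tensor input scale of the second GEMM
) -> tuple[torch.Tensor, torch.Tensor]:
    # alpha = w_scale * a_scale, per the comment above.
    g1_alphas = (w1_scale * a1_scale).float()
    g2_alphas = (w2_scale * a2_scale).float()
    return g1_alphas, g2_alphas

# Note that the config above also passes a2_scale inverted (1.0 / a2_scale),
# i.e. as a quantization scale rather than a dequantization scale.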
@@ -1414,7 +1472,13 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
            layer=layer,
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            w13_weight_scale=layer.w13_weight_scale,
            w2_weight_scale=layer.w2_weight_scale,
            w13_input_scale=None,
            w2_input_scale=None,
        )

        # Setup modular kernel for TP case.