diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 7c447c2a5..a4de4d709 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1099,8 +1099,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
         else:
-            from vllm.model_executor.layers.fused_moe import fused_experts
-            return fused_experts(
+            common_kwargs = dict(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -1117,11 +1116,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 if self.block_quant else layer.w2_weight_scale),
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
-                use_fp8_w8a8=True,
-                block_shape=self.quant_config.weight_block_size,
-                allow_deep_gemm=self.allow_deep_gemm,
-                allow_cutlass_block_scaled_grouped_gemm=(
-                    self.allow_cutlass_block_scaled_grouped_gemm))
+            )
+
+            if self.fused_experts is not None:
+                return self.fused_experts(**common_kwargs)
+            else:
+                from vllm.model_executor.layers.fused_moe import fused_experts
+                return fused_experts(
+                    **common_kwargs,
+                    use_fp8_w8a8=True,
+                    block_shape=self.quant_config.weight_block_size,
+                    allow_deep_gemm=self.allow_deep_gemm,
+                    allow_cutlass_block_scaled_grouped_gemm=(
+                        self.allow_cutlass_block_scaled_grouped_gemm),
+                )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):