[Bugfix] Fix FP8 MoE LoRA (#29890)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-12-05 02:17:49 +08:00
committed by GitHub
parent 6dcb07f676
commit 652ba93da3

View File

@@ -124,12 +124,16 @@ class Fp8MoeBackend(Enum):
def get_fp8_moe_backend(
block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
block_quant: bool,
moe_parallel_config: FusedMoEParallelConfig,
with_lora_support: bool,
) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
if with_lora_support:
return Fp8MoeBackend.TRITON
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
@@ -665,7 +669,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.fp8_backend = get_fp8_moe_backend(
self.block_quant, layer.moe_parallel_config
self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
)
self.marlin_input_dtype = None
@@ -1084,6 +1088,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import (
BatchedDeepGemmExperts,
BatchedTritonExperts,
TritonExperts,
TritonOrDeepGemmExperts,
)
@@ -1116,7 +1121,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
elif self.moe.is_lora_enabled:
return TritonExperts(quant_config=self.moe_quant_config)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
# Select GEMM experts with block-scale when weights are block-quantized
experts = select_cutlass_fp8_gemm_impl(