[Bugfix] Use latency MOE backend as default for Flashinfer and other misc fixes (#27439)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Author: Pavani Majety
Date: 2025-11-07 04:18:39 -08:00
Committed by: GitHub
Parent: e0919f331d
Commit: 72b1c2ae2c
7 changed files with 47 additions and 12 deletions

View File

@@ -50,6 +50,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
             self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
         if self.backend == "none":
             raise ValueError(
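
The hunk above keys the NVFP4 GEMM backend off the VLLM_NVFP4_GEMM_BACKEND environment variable. A minimal standalone sketch of that selection order, with the capability probes passed in as plain booleans so it runs without vLLM installed; the helper name is hypothetical:

import os

def resolve_nvfp4_gemm_backend(flashinfer_ok: bool, cutlass_ok: bool) -> str:
    # Hypothetical helper mirroring the diff's selection order; the real
    # logic lives in the scheme's constructor and reads vLLM's envs module.
    requested = os.environ.get("VLLM_NVFP4_GEMM_BACKEND", "none")
    if requested.startswith("flashinfer-"):
        assert flashinfer_ok, f"FlashInfer is required for {requested}"
        return requested
    if requested == "cutlass":
        assert cutlass_ok, f"Cutlass is required for {requested}"
        return requested
    # Falls through to the "none" error the hunk ends on.
    raise ValueError(f"No usable NVFP4 GEMM backend (got {requested!r})")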

View File

@@ -138,6 +138,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
             logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
             return Fp8MoeBackend.FLASHINFER_TRTLLM
         else:
+            if block_quant:
+                raise ValueError(
+                    "FlashInfer FP8 MoE throughput backend does not "
+                    "support block quantization. Please use "
+                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
+                    "instead."
+                )
             logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
             return Fp8MoeBackend.FLASHINFER_CUTLASS
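
This guard is the behavioral core of the commit: FlashInfer's CUTLASS (throughput) MoE path cannot serve block-quantized FP8 weights, so such models must take the TRTLLM path via VLLM_FLASHINFER_MOE_BACKEND=latency. A condensed sketch of the resulting decision, with a hypothetical helper name and strings standing in for the Fp8MoeBackend enum members:

def pick_flashinfer_fp8_moe(backend_env: str, block_quant: bool) -> str:
    # Hypothetical condensation of get_fp8_moe_backend's FlashInfer branch.
    if backend_env == "latency":
        return "FLASHINFER_TRTLLM"  # TRTLLM kernels handle block quant
    if block_quant:
        raise ValueError(
            "FlashInfer FP8 MoE throughput backend does not support "
            "block quantization. Please use "
            "VLLM_FLASHINFER_MOE_BACKEND=latency instead."
        )
    return "FLASHINFER_CUTLASS"

assert pick_flashinfer_fp8_moe("latency", block_quant=True) == "FLASHINFER_TRTLLM"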

View File

@@ -221,7 +221,10 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )

         if isinstance(layer, LinearBase):
             if self.is_layer_excluded(prefix):
@@ -230,7 +233,7 @@ class ModelOptFp8Config(QuantizationConfig):
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self, layer)
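
These hunks (and the matching NvFp4 ones below) widen the attention dispatch so MLA attention layers also receive the FP8 KV-cache method; the mechanism is simply isinstance with a class tuple. A toy sketch with stand-in classes rather than vLLM's real ones:

class Attention: ...
class MLAAttention: ...

def wants_fp8_kv_cache(layer: object) -> bool:
    # A tuple second argument makes isinstance match either class, so MLA
    # attention layers take the same KV-cache quantization path.
    return isinstance(layer, (Attention, MLAAttention))

assert wants_fp8_kv_cache(Attention())
assert wants_fp8_kv_cache(MLAAttention())  # newly covered by this commit
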
@@ -888,7 +891,10 @@ class ModelOptNvFp4Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )

         skip_layer = self.is_layer_excluded(prefix)
         if isinstance(layer, LinearBase):
@@ -898,7 +904,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             if skip_layer:
@@ -941,6 +947,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
             self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
         if self.backend == "none":
             raise ValueError(
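
Putting the pieces together, a hedged usage sketch: both environment variables touched by this commit would be set before engine construction; LLM is vLLM's real Python entry point, while the model id below is a placeholder:

import os

# Values come straight from the diff; set them before vLLM builds its layers.
os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "cutlass"
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"

from vllm import LLM

llm = LLM(model="org/nvfp4-quantized-model")  # placeholder model id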