diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f525ac871..77a908deb 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -120,7 +120,7 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: and current_platform.is_cuda() and has_flashinfer() and ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) # tp-dp combination broken: