diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 2f9450244..54b36d103 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -935,11 +935,9 @@ def enable_batch_invariant_mode(): _batch_invariant_MODE = True _batch_invariant_LIB = torch.library.Library("aten", "IMPL") - if ( - current_platform.is_device_capability_family(100) - or current_platform.is_device_capability(80) - or current_platform.is_device_capability(89) - ): + if current_platform.is_device_capability_family( + 100 + ) or current_platform.is_device_capability_family(80): # For PyTorch 2.9, B200 uses GEMV for bs=1 # Requires https://github.com/pytorch/pytorch/pull/166735 _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")