From 16a65e41736c5e8d27e9e843668f6e3f99d68d9a Mon Sep 17 00:00:00 2001
From: Yusuf Mohammad <79484377+YM2132@users.noreply.github.com>
Date: Thu, 2 Apr 2026 14:29:58 +0100
Subject: [PATCH] [Bugfix] Enable batch-invariant Triton matmul on all Ampere
 GPUs (SM 8x) (#38427)

Signed-off-by: yusuf
Signed-off-by: yusuf
Signed-off-by: Yusuf Mohammad <79484377+YM2132@users.noreply.github.com>
Signed-off-by: <>
Co-authored-by: Claude
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yusuf
---
 vllm/model_executor/layers/batch_invariant.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 2f9450244..54b36d103 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -935,11 +935,9 @@ def enable_batch_invariant_mode():
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
 
-    if (
-        current_platform.is_device_capability_family(100)
-        or current_platform.is_device_capability(80)
-        or current_platform.is_device_capability(89)
-    ):
+    if current_platform.is_device_capability_family(
+        100
+    ) or current_platform.is_device_capability_family(80):
         # For PyTorch 2.9, B200 uses GEMV for bs=1
         # Requires https://github.com/pytorch/pytorch/pull/166735
         _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
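
Note on the fix: the old condition matched compute capabilities 8.0 and 8.9
exactly, so SM 8.x parts such as 8.6 and 8.7 fell through and never got the
batch-invariant mm override. Switching to is_device_capability_family(80)
matches any SM 8.x device, per the patch title. Below is a minimal standalone
sketch of the semantics this relies on; the two functions mirror the names on
vLLM's current_platform but are illustrative assumptions here, not vLLM's
actual implementation (a capability like 8.6 is assumed to be encoded as the
integer 86, i.e. major * 10 + minor).

    # Sketch only: assumed semantics of the two capability checks.

    def is_device_capability(device_cap: int, query: int) -> bool:
        # Exact match: True for SM 8.0 when queried with 80,
        # False for SM 8.6 or 8.7, which is the bug being fixed.
        return device_cap == query

    def is_device_capability_family(device_cap: int, query: int) -> bool:
        # Major-version match: True for any SM 8.x when queried with 80.
        return device_cap // 10 == query // 10

    if __name__ == "__main__":
        for cap in (80, 86, 87, 89, 90, 100):
            print(cap,
                  is_device_capability(cap, 80),
                  is_device_capability_family(cap, 80))

Under these assumptions, the exact check is True only for cap == 80, while the
family check is True for 80, 86, 87, and 89, covering every SM 8.x GPU with a
single condition.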