[Bugfix] Enable batch-invariant Triton matmul on all Ampere GPUs (SM 8x) (#38427)

Signed-off-by: yusuf <yusufmohammad@live.com>
Signed-off-by: yusuf <yusuf@deeplearningmachine.mynet>
Signed-off-by: Yusuf Mohammad <79484377+YM2132@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yusuf <yusuf@deeplearningmachine.mynet>
Author: Yusuf Mohammad
Date: 2026-04-02 14:29:58 +01:00 (committed by GitHub)
Parent: c0817e4d39
Commit: 16a65e4173

@@ -935,11 +935,9 @@ def enable_batch_invariant_mode():
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    if (
-        current_platform.is_device_capability_family(100)
-        or current_platform.is_device_capability(80)
-        or current_platform.is_device_capability(89)
-    ):
+    if current_platform.is_device_capability_family(
+        100
+    ) or current_platform.is_device_capability_family(80):
         # For PyTorch 2.9, B200 uses GEMV for bs=1
         # Requires https://github.com/pytorch/pytorch/pull/166735
         _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
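The intent of the change can be sketched as the difference between an exact capability check and a family (major-version) check. The helpers below are a minimal standalone illustration, not vLLM's actual `current_platform` implementation; capabilities are encoded as `major * 10 + minor` (e.g. SM 8.6 → 86), matching the integer form used in the diff above.

```python
# Hypothetical re-implementations for illustration only (not vLLM code).

def is_device_capability(wanted: int, device_capability: int) -> bool:
    """Exact match: is_device_capability(80, ...) matches only SM 8.0."""
    return device_capability == wanted


def is_device_capability_family(family: int, device_capability: int) -> bool:
    """Major-version match: family 80 covers SM 8.0, 8.6, 8.9 (all SM 8x)."""
    return device_capability // 10 == family // 10


# Before the fix, only SM 8.0 and SM 8.9 were matched exactly, so SM 8.6
# GPUs (e.g. A10, RTX 3090) missed the batch-invariant matmul path.
def old_check(cap: int) -> bool:
    return is_device_capability(80, cap) or is_device_capability(89, cap)


# After the fix, the whole SM 8x family is matched.
def new_check(cap: int) -> bool:
    return is_device_capability_family(80, cap)


for cap in (80, 86, 89):
    print(cap, old_check(cap), new_check(cap))
# → 80 True True
# → 86 False True
# → 89 True True
```

The key behavioral difference is at SM 8.6, which the old pair of exact checks silently excluded even though it belongs to the same Ampere family.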