[Bugfix] Enable batch-invariant Triton matmul on all Ampere GPUs (SM 8x) (#38427)
Signed-off-by: yusuf <yusufmohammad@live.com>
Signed-off-by: yusuf <yusuf@deeplearningmachine.mynet>
Signed-off-by: Yusuf Mohammad <79484377+YM2132@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yusuf <yusuf@deeplearningmachine.mynet>
@@ -935,11 +935,9 @@ def enable_batch_invariant_mode():
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
 
-    if (
-        current_platform.is_device_capability_family(100)
-        or current_platform.is_device_capability(80)
-        or current_platform.is_device_capability(89)
-    ):
+    if current_platform.is_device_capability_family(
+        100
+    ) or current_platform.is_device_capability_family(80):
         # For PyTorch 2.9, B200 uses GEMV for bs=1
         # Requires https://github.com/pytorch/pytorch/pull/166735
         _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
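
For context, a minimal self-contained sketch of the pattern this hunk relies on: registering a CUDA override for aten::mm through torch.library when the GPU's compute capability matches. vLLM encodes capabilities as major*10+minor, so family 80 covers every SM 8.x Ampere/Ada part, including the SM 8.6 GPUs that the old capability(80)/capability(89) checks missed. The is_capability_family helper and the naive mm_batch_invariant below are illustrative stand-ins, not vLLM's actual helpers or its Triton kernel.

import torch

def is_capability_family(major: int) -> bool:
    # Hypothetical stand-in for current_platform.is_device_capability_family:
    # True when the CUDA device's SM major version matches (e.g. 8 -> all SM 8.x).
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == major

def mm_batch_invariant(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Naive reference only; vLLM's real override launches a Triton matmul with a
    # fixed reduction order. Computed as a broadcast multiply plus sum so it does
    # not re-dispatch to aten::mm (which would recurse back into this override).
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(dim=1)

# Keep a long-lived reference; the override stays active while this object exists.
_lib = torch.library.Library("aten", "IMPL")

# Mirror the hunk's condition: SM 10x (Blackwell, e.g. B200) or any SM 8x part.
if is_capability_family(10) or is_capability_family(8):
    _lib.impl("aten::mm", mm_batch_invariant, "CUDA")

After this runs on a matching GPU, torch.mm on CUDA tensors routes through the override for the "CUDA" dispatch key, which is why the diff's capability check directly controls whether the batch-invariant kernel is used.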