[Torch 2.11] Guard torch._C._cpu attribute checks for forward compatibility (#35673)

Signed-off-by: atalman <atalman@fb.com>
This commit is contained in:
Andrey Talman
2026-03-17 14:47:59 -04:00
committed by GitHub
parent c5030c439d
commit 68f783a727
9 changed files with 10 additions and 12 deletions

View File

@@ -27,7 +27,7 @@ def get_attn_isa(
else:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
return "neon"
elif torch._C._cpu._is_amx_tile_supported():
elif torch.cpu._is_amx_tile_supported():
return "amx"
else:
return "vec"

View File

@@ -24,7 +24,7 @@ except (ImportError, AttributeError) as e:
sys.exit(1)
# ISA selection following test_cpu_fused_moe.py pattern
ISA_CHOICES = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
ISA_CHOICES = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
@torch.inference_mode()

View File

@@ -48,7 +48,7 @@ def get_attn_isa(
else:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
return "neon"
elif torch._C._cpu._is_amx_tile_supported():
elif torch.cpu._is_amx_tile_supported():
return "amx"
else:
return "vec"
@@ -400,9 +400,7 @@ def test_varlen_with_paged_kv_normal_vec(
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["amx"])
@pytest.mark.skipif(
not torch._C._cpu._is_amx_tile_supported(), reason="no AMX support."
)
@pytest.mark.skipif(not torch.cpu._is_amx_tile_supported(), reason="no AMX support.")
def test_varlen_with_paged_kv_normal_amx(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],

View File

@@ -22,7 +22,7 @@ INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]

View File

@@ -119,7 +119,7 @@ class CPUWNA16LinearKernel(MPLinearKernel):
def _get_isa_hint(dtype: torch.dtype) -> str:
supports_amx = torch._C._cpu._is_amx_tile_supported()
supports_amx = torch.cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,):
return "amx"
else:

View File

@@ -280,7 +280,7 @@ class CPUFusedMOE:
if not (w13_output_size % 32 == 0 and w2_output_size % 32 == 0):
return False, "none"
supports_amx = torch._C._cpu._is_amx_tile_supported()
supports_amx = torch.cpu._is_amx_tile_supported()
if (
supports_amx

View File

@@ -292,7 +292,7 @@ class CPUAWQLinearMethod(LinearMethodBase):
def _get_isa_hint(dtype: torch.dtype) -> str:
supports_amx = torch._C._cpu._is_amx_tile_supported()
supports_amx = torch.cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,):
return "amx"
else:

View File

@@ -212,7 +212,7 @@ direct_register_custom_op(
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
return (
torch._C._cpu._is_amx_tile_supported()
torch.cpu._is_amx_tile_supported()
and (dtype in (torch.bfloat16, torch.int8))
and k % 32 == 0
and n % 16 == 0

View File

@@ -482,7 +482,7 @@ def _get_attn_isa(
) -> str:
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
return "vec16"
supports_amx = torch._C._cpu._is_amx_tile_supported()
supports_amx = torch.cpu._is_amx_tile_supported()
supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
supports_vxe = current_platform.get_cpu_architecture() == CpuArchEnum.S390X
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0: