[Torch 2.11] Guard torch._C._cpu attribute checks for forward compatibility (#35673)
Signed-off-by: atalman <atalman@fb.com>
This commit is contained in:
@@ -27,7 +27,7 @@ def get_attn_isa(
|
||||
else:
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||
return "neon"
|
||||
- elif torch._C._cpu._is_amx_tile_supported():
|
||||
+ elif torch.cpu._is_amx_tile_supported():
|
||||
return "amx"
|
||||
else:
|
||||
return "vec"
|
||||
|
||||
@@ -24,7 +24,7 @@ except (ImportError, AttributeError) as e:
|
||||
sys.exit(1)
|
||||
|
||||
# ISA selection following test_cpu_fused_moe.py pattern
|
||||
- ISA_CHOICES = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
+ ISA_CHOICES = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@@ -48,7 +48,7 @@ def get_attn_isa(
|
||||
else:
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||
return "neon"
|
||||
- elif torch._C._cpu._is_amx_tile_supported():
|
||||
+ elif torch.cpu._is_amx_tile_supported():
|
||||
return "amx"
|
||||
else:
|
||||
return "vec"
|
||||
@@ -400,9 +400,7 @@ def test_varlen_with_paged_kv_normal_vec(
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["amx"])
|
||||
- @pytest.mark.skipif(
|
||||
- not torch._C._cpu._is_amx_tile_supported(), reason="no AMX support."
|
||||
- )
|
||||
+ @pytest.mark.skipif(not torch.cpu._is_amx_tile_supported(), reason="no AMX support.")
|
||||
def test_varlen_with_paged_kv_normal_amx(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
|
||||
@@ -22,7 +22,7 @@ INTERMEDIATE_DIM = [128, 2880]
|
||||
BATCH_SIZE = [1, 64, 256]
|
||||
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
|
||||
USE_BIAS = [True, False]
|
||||
- ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
+ ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
|
||||
DTYPE = [torch.bfloat16]
|
||||
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ class CPUWNA16LinearKernel(MPLinearKernel):
|
||||
|
||||
|
||||
def _get_isa_hint(dtype: torch.dtype) -> str:
|
||||
- supports_amx = torch._C._cpu._is_amx_tile_supported()
|
||||
+ supports_amx = torch.cpu._is_amx_tile_supported()
|
||||
if supports_amx and dtype in (torch.bfloat16,):
|
||||
return "amx"
|
||||
else:
|
||||
|
||||
@@ -280,7 +280,7 @@ class CPUFusedMOE:
|
||||
if not (w13_output_size % 32 == 0 and w2_output_size % 32 == 0):
|
||||
return False, "none"
|
||||
|
||||
- supports_amx = torch._C._cpu._is_amx_tile_supported()
|
||||
+ supports_amx = torch.cpu._is_amx_tile_supported()
|
||||
|
||||
if (
|
||||
supports_amx
|
||||
|
||||
@@ -292,7 +292,7 @@ class CPUAWQLinearMethod(LinearMethodBase):
|
||||
|
||||
|
||||
def _get_isa_hint(dtype: torch.dtype) -> str:
|
||||
- supports_amx = torch._C._cpu._is_amx_tile_supported()
|
||||
+ supports_amx = torch.cpu._is_amx_tile_supported()
|
||||
if supports_amx and dtype in (torch.bfloat16,):
|
||||
return "amx"
|
||||
else:
|
||||
|
||||
@@ -212,7 +212,7 @@ direct_register_custom_op(
|
||||
|
||||
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
|
||||
return (
|
||||
- torch._C._cpu._is_amx_tile_supported()
|
||||
+ torch.cpu._is_amx_tile_supported()
|
||||
and (dtype in (torch.bfloat16, torch.int8))
|
||||
and k % 32 == 0
|
||||
and n % 16 == 0
|
||||
|
||||
@@ -482,7 +482,7 @@ def _get_attn_isa(
|
||||
) -> str:
|
||||
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
|
||||
return "vec16"
|
||||
- supports_amx = torch._C._cpu._is_amx_tile_supported()
|
||||
+ supports_amx = torch.cpu._is_amx_tile_supported()
|
||||
supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
|
||||
supports_vxe = current_platform.get_cpu_architecture() == CpuArchEnum.S390X
|
||||
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
|
||||
|
||||
Reference in New Issue
Block a user