[CPU] Add head sizes 80 and 112 with vec16 fallback (#31968)

Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
Author: R3hankhan
Date: 2026-01-09 19:44:46 +05:30
Committed by: GitHub
Parent: 7cdf7e2fe0
Commit: 8e27663b6a
4 changed files with 12 additions and 5 deletions

@@ -42,7 +42,7 @@ class CPUAttentionBackend(AttentionBackend):
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+        return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]

     @staticmethod
     def get_name() -> str:
@@ -137,7 +137,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
         if self.window_size is None:
             self.window_size = -1
         self.block_size = vllm_config.cache_config.block_size
-        self.isa = _get_attn_isa(self.dtype, self.block_size)
+        self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim)
         self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)

     def build(
@@ -484,7 +484,11 @@ def _make_sliding_window_bias(
     return attn_biases


-def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
+def _get_attn_isa(
+    dtype: torch.dtype, block_size: int, head_size: int | None = None
+) -> str:
+    if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
+        return "vec16"
     supports_amx = torch._C._cpu._is_amx_tile_supported()
     if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
         return "amx"