diff --git a/csrc/cpu/generate_cpu_attn_dispatch.py b/csrc/cpu/generate_cpu_attn_dispatch.py
index f1d08017f..bbcd6d85b 100644
--- a/csrc/cpu/generate_cpu_attn_dispatch.py
+++ b/csrc/cpu/generate_cpu_attn_dispatch.py
@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
 import os
 
 # Head dimensions divisible by 32 (support all ISAs)
-HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
+HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
 
 # Head dimensions divisible by 16 but not 32 (VEC16 only)
 HEAD_DIMS_16 = [80, 112]
diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md
index 47ac91464..6c0ed4672 100644
--- a/docs/design/attention_backends.md
+++ b/docs/design/attention_backends.md
@@ -165,7 +165,7 @@ Priority is **1 = highest** (tried first).
 
 | Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. |
 | ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
-| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
+| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
 | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
 | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
 | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
diff --git a/tests/kernels/attention/test_cpu_attn.py b/tests/kernels/attention/test_cpu_attn.py
index 7e3d77134..f7691a90e 100644
--- a/tests/kernels/attention/test_cpu_attn.py
+++ b/tests/kernels/attention/test_cpu_attn.py
@@ -25,7 +25,7 @@ NUM_HEADS = [
     (8, 2),
     (9, 3),
 ]
-HEAD_SIZES = [96, 128]
+HEAD_SIZES = [96, 128, 512]
 HEAD_SIZES_VEC16 = [96, 80, 112, 128]
 QTYPES = [torch.bfloat16, torch.half, torch.float32]
 SLIDING_WINDOWS = [None, 256]
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 1df0fe654..90151a251 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -38,7 +38,7 @@ class CPUAttentionBackend(AttentionBackend):
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
+        return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512]
 
     @staticmethod
     def get_name() -> str:
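
A minimal sketch of how the widened head-size support can be checked from Python, assuming a CPU build of vLLM with this patch applied; only `CPUAttentionBackend.get_supported_head_sizes` comes from the diff above, the assertions are illustrative:

```python
# Sketch, assuming a CPU build of vLLM with this patch applied.
from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend

supported = CPUAttentionBackend.get_supported_head_sizes()

# 512 is divisible by 32, so it takes the HEAD_DIMS_32 path generated
# for all ISAs; 80 and 112 remain on the VEC16-only path.
assert 512 in supported
assert all(d % 16 == 0 for d in supported)
```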