[CPU] Refactor CPU attention backend (#27954)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
@@ -8,7 +8,6 @@ import platform
 import subprocess
 import sys
 from dataclasses import dataclass
-from importlib.util import find_spec
 from typing import TYPE_CHECKING
 
 import regex as re
@@ -139,16 +138,15 @@ class CpuPlatform(Platform):
     ) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 
-        if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
+        if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:
             raise NotImplementedError("MLA is not supported on CPU.")
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on CPU.")
-        logger.info("Using Torch SDPA backend.")
         if not use_v1:
             raise ValueError("CPU backend only supports V1.")
-        return AttentionBackendEnum.TORCH_SDPA.get_path()
+        return AttentionBackendEnum.CPU_ATTN.get_path()
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
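With this hunk, CPU_ATTN replaces TORCH_SDPA as the only backend the CPU platform resolves: any other explicitly selected backend is logged and ignored, and the dedicated "Using Torch SDPA backend." message goes away. Below is a minimal, self-contained sketch of the resulting selection flow; the enum and its get_path() value are illustrative stand-ins, not vLLM's actual registry.

# Sketch only: AttentionBackendEnum here is a stand-in for
# vllm.attention.backends.registry.AttentionBackendEnum, and the path
# string returned by get_path() is hypothetical.
from enum import Enum


class AttentionBackendEnum(Enum):
    CPU_ATTN = "cpu_attn"
    FLASH_ATTN = "flash_attn"

    def get_path(self) -> str:
        # The real enum resolves members to backend class paths.
        return f"vllm.attention.backends.{self.value}"


def get_attn_backend_cls(
    selected_backend: AttentionBackendEnum | None,
    use_v1: bool = True,
    use_mla: bool = False,
    use_sparse: bool = False,
) -> str:
    # Any requested backend other than CPU_ATTN is ignored with a message;
    # CPU_ATTN is always the answer on CPU.
    if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
        print(f"Cannot use {selected_backend} backend on CPU.")
    if use_mla:
        raise NotImplementedError("MLA is not supported on CPU.")
    if use_sparse:
        raise NotImplementedError("Sparse Attention is not supported on CPU.")
    if not use_v1:
        raise ValueError("CPU backend only supports V1.")
    return AttentionBackendEnum.CPU_ATTN.get_path()


print(get_attn_backend_cls(AttentionBackendEnum.FLASH_ATTN))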
@@ -186,15 +184,13 @@ class CpuPlatform(Platform):
 
         cache_config = vllm_config.cache_config
 
-        ipex_available = find_spec("intel_extension_for_pytorch") is not None
-
-        if cache_config and cache_config.block_size is None:
-            cache_config.block_size = 128 if ipex_available else 16
-
-        if not ipex_available and cache_config.block_size != 16:
-            raise RuntimeError(
-                f"--block-size={cache_config.block_size} requires"
-                " intel_extension_for_pytorch"
-            )
+        if cache_config.block_size is None:
+            cache_config.block_size = 128
+
+        if cache_config.block_size % 32 != 0:
+            logger.warning(
+                "CPU backend prefers block_size is multiples of 32, "
+                "otherwise the performance is not optimized."
+            )
 
         scheduler_config = vllm_config.scheduler_config
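This hunk drops the intel_extension_for_pytorch probe entirely: the default block size is now 128 unconditionally, and a block size that is not a multiple of 32 produces a performance warning rather than the old RuntimeError. A runnable sketch of the new policy follows; resolve_cpu_block_size is an illustrative helper name, not vLLM API.

# Sketch only: illustrative helper, not vLLM API.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def resolve_cpu_block_size(block_size: int | None) -> int:
    # Default is 128 regardless of whether intel_extension_for_pytorch is
    # installed (the old ipex-dependent 128-vs-16 default is gone).
    if block_size is None:
        block_size = 128
    # Sizes that are not multiples of 32 are merely warned about, where the
    # old code raised RuntimeError for non-16 sizes without ipex.
    if block_size % 32 != 0:
        logger.warning(
            "CPU backend prefers block_size is multiples of 32, "
            "otherwise the performance is not optimized."
        )
    return block_size


print(resolve_cpu_block_size(None))  # -> 128
print(resolve_cpu_block_size(48))  # warns, returns 48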
@@ -207,22 +203,11 @@ class CpuPlatform(Platform):
                 "backend is not compatible with FP8 KV cache."
             )
 
-        if cache_config.cache_dtype == "fp8_e4m3":
-            cache_config.cache_dtype = "fp8_e5m2"
+        if cache_config.cache_dtype != "auto":
             logger.warning(
-                "CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
+                "CPU backend doesn't support KV cache quantization fallback to auto."
             )
+            cache_config.cache_dtype = "auto"
 
-        if (
-            cache_config.cache_dtype != "auto"
-            and model_config is not None
-            and model_config.dtype == torch.half
-        ):
-            logger.warning(
-                "FP8 KV cache on the CPU backend only does not"
-                " support fp16 for now, cast to bf16."
-            )
-            model_config.dtype = torch.bfloat16
-
         cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()
 
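The two old special cases (remapping fp8_e4m3 to fp8_e5m2, and casting fp16 models to bf16 for FP8 KV cache) collapse into a single rule: any non-"auto" KV cache dtype falls back to "auto" with a warning. A runnable sketch of that rule follows; the helper name is illustrative, not vLLM API.

# Sketch only: illustrative helper, not vLLM API.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def resolve_cpu_kv_cache_dtype(cache_dtype: str) -> str:
    # Anything other than "auto" (e.g. "fp8_e4m3", "fp8_e5m2") now falls
    # back to "auto"; no fp8_e5m2 remap and no fp16 -> bf16 model cast.
    if cache_dtype != "auto":
        logger.warning(
            "CPU backend doesn't support KV cache quantization fallback to auto."
        )
        cache_dtype = "auto"
    return cache_dtype


print(resolve_cpu_kv_cache_dtype("fp8_e4m3"))  # warns, returns "auto"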