[Hardware][Intel] Add CPU inference backend (#3634)
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Yuan Zhou <yuan.zhou@intel.com>
@@ -5,7 +5,7 @@ import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_cpu, is_hip
 
 logger = init_logger(__name__)
 
@@ -17,6 +17,10 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
         return FlashAttentionBackend
+    elif is_cpu():
+        logger.info("Using Torch SDPA backend.")
+        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
+        return TorchSDPABackend
     else:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
@@ -29,6 +33,8 @@ def _can_use_flash_attn(dtype: torch.dtype) -> bool:
         # AMD GPUs.
         logger.info("Cannot use FlashAttention backend for AMD GPUs.")
         return False
+    if is_cpu():
+        return False
     if torch.cuda.get_device_capability()[0] < 8:
         # Volta and Turing NVIDIA GPUs.
         logger.info("Cannot use FlashAttention backend for Volta and Turing "
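
For context, here is a minimal sketch (not part of this diff) of the kind of attention call the new Torch SDPA backend relies on: PyTorch's built-in torch.nn.functional.scaled_dot_product_attention, which runs on CPU without FlashAttention or xFormers. The helper name naive_causal_sdpa and the tensor shapes below are illustrative assumptions, not code from this PR.

import torch
import torch.nn.functional as F


def naive_causal_sdpa(query: torch.Tensor,
                      key: torch.Tensor,
                      value: torch.Tensor) -> torch.Tensor:
    """Causal attention on CPU via PyTorch's scaled_dot_product_attention.

    Expected shapes: [batch, num_heads, seq_len, head_size].
    (Hypothetical helper, shown only to illustrate the SDPA code path.)
    """
    return F.scaled_dot_product_attention(query, key, value, is_causal=True)


if __name__ == "__main__":
    # Toy shapes: batch=1, 8 heads, 16 tokens, head size 64.
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)
    out = naive_causal_sdpa(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 16, 64])

With the selector change above, calling get_attn_backend(...) on a CPU-only build would log "Using Torch SDPA backend." and return TorchSDPABackend, since _can_use_flash_attn now returns False when is_cpu() is true instead of falling through to the CUDA capability check.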