[CPU] Refactor CPU attention backend (#27954)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
@@ -2583,6 +2583,88 @@ def onednn_scaled_mm(
    return output


def cpu_attn_get_scheduler_metadata(
    num_reqs: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    seq_lens: torch.Tensor,
    dtype: torch.dtype,
    query_start_loc: torch.Tensor,
    causal: bool,
    sliding_window_size: int,
    isa: str,
    enable_kv_split: bool,
) -> torch.Tensor:
    scheduler_metadata = torch.ops._C.get_scheduler_metadata(
        num_reqs,
        num_heads,
        num_kv_heads,
        head_dim,
        seq_lens,
        dtype,
        query_start_loc,
        causal,
        sliding_window_size,
        isa,
        enable_kv_split,
    )
    return scheduler_metadata


def cpu_attn_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    isa: str,
) -> None:
    torch.ops._C.cpu_attn_reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        isa,
    )


def cpu_attention_with_kv_cache(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    output: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    causal: bool,
    alibi_slopes: torch.Tensor | None,
    sliding_window: tuple[int, int],
    block_table: torch.Tensor,
    softcap: float,
    scheduler_metadata: torch.Tensor,
    s_aux: torch.Tensor | None,
) -> None:
    torch.ops._C.cpu_attention_with_kv_cache(
        query,
        key_cache,
        value_cache,
        output,
        query_start_loc,
        seq_lens,
        scale,
        causal,
        alibi_slopes,
        sliding_window[0],
        sliding_window[1],
        block_table,
        softcap,
        scheduler_metadata,
        s_aux,
    )


if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):

    @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
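Note: the three wrappers above forward directly to the `_C` custom ops, so the per-step call pattern is: write this step's key/value tensors into the paged cache, build scheduler metadata once per batch, then run attention against the cache. The sketch below is illustrative only; the tensor shapes, cache layout, block size, and the `isa`/`enable_kv_split` values are assumptions rather than part of this diff, the import path is assumed, and it requires a vLLM CPU build that actually exposes these ops under torch.ops._C.

# Minimal sketch; all shapes and the isa="amx" choice are illustrative assumptions.
import torch
from vllm import _custom_ops as ops  # assumed module for the wrappers above

num_reqs, num_tokens = 2, 16
num_heads, num_kv_heads, head_dim = 8, 2, 64
num_blocks, block_size = 32, 128

query = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)
key = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.bfloat16)
value = torch.randn_like(key)
output = torch.empty_like(query)

# Placeholder paged-cache layout; the real layout is defined by the CPU backend.
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.bfloat16)
value_cache = torch.zeros_like(key_cache)

query_start_loc = torch.tensor([0, 8, 16], dtype=torch.int32)  # cumulative new-token offsets
seq_lens = torch.tensor([40, 24], dtype=torch.int32)           # cached + new tokens per request
slot_mapping = torch.arange(num_tokens, dtype=torch.int64)     # cache slots for the new K/V
block_table = torch.zeros(num_reqs, 4, dtype=torch.int32)      # cache blocks per request

# 1) Append this step's K/V into the paged cache.
ops.cpu_attn_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, isa="amx")

# 2) Build per-batch scheduler metadata (isa / enable_kv_split are tuning knobs).
metadata = ops.cpu_attn_get_scheduler_metadata(
    num_reqs, num_heads, num_kv_heads, head_dim, seq_lens, torch.bfloat16,
    query_start_loc, causal=True, sliding_window_size=-1, isa="amx",
    enable_kv_split=False,
)

# 3) Attention over the cached K/V; the result is written into `output`.
ops.cpu_attention_with_kv_cache(
    query, key_cache, value_cache, output, query_start_loc, seq_lens,
    scale=head_dim**-0.5, causal=True, alibi_slopes=None, sliding_window=(-1, -1),
    block_table=block_table, softcap=0.0, scheduler_metadata=metadata, s_aux=None,
)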
@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
    ROCM_AITER_FA = (
        "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
    )
    TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
    TORCH_SDPA = ""  # this tag is only used for ViT
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    FLASHINFER_MLA = (
        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"

@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
        "vllm.v1.attention.backends.rocm_aiter_unified_attn."
        "RocmAiterUnifiedAttentionBackend"
    )
    CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
    # Placeholder for third-party/custom backends - must be registered before use
    CUSTOM = ""
@@ -1726,9 +1726,6 @@ class EngineArgs:
            )
            _raise_unsupported_error(feature_name=name)

        if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
            _raise_unsupported_error(feature_name="sliding window (CPU backend)")

    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ) -> None:
@@ -8,7 +8,6 @@ import platform
import subprocess
import sys
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING

import regex as re

@@ -139,16 +138,15 @@ class CpuPlatform(Platform):
    ) -> str:
        from vllm.attention.backends.registry import AttentionBackendEnum

        if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
        if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on CPU.")
        logger.info("Using Torch SDPA backend.")
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")
        return AttentionBackendEnum.TORCH_SDPA.get_path()
        return AttentionBackendEnum.CPU_ATTN.get_path()

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
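Note: with CpuPlatform now resolving to the CPU_ATTN path automatically, explicit selection should go through the usual backend override. A hedged sketch, assuming the enum name is accepted by the VLLM_ATTENTION_BACKEND environment variable referenced later in this diff; the model name is a placeholder.

import os

# Assumption: the backend enum name is a valid value for this env var, as with
# other AttentionBackendEnum members; set it before the engine is created.
os.environ["VLLM_ATTENTION_BACKEND"] = "CPU_ATTN"

from vllm import LLM  # noqa: E402

llm = LLM(model="facebook/opt-125m")  # placeholder model
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)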
@@ -186,15 +184,13 @@ class CpuPlatform(Platform):

        cache_config = vllm_config.cache_config

        ipex_available = find_spec("intel_extension_for_pytorch") is not None
        if cache_config.block_size is None:
            cache_config.block_size = 128

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128 if ipex_available else 16

        if not ipex_available and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch"
        if cache_config.block_size % 32 != 0:
            logger.warning(
                "CPU backend prefers block_size to be a multiple of 32; "
                "otherwise performance is not optimized."
            )

        scheduler_config = vllm_config.scheduler_config
@@ -207,22 +203,11 @@
                "backend is not compatible with FP8 KV cache."
            )

        if cache_config.cache_dtype == "fp8_e4m3":
            cache_config.cache_dtype = "fp8_e5m2"
        if cache_config.cache_dtype != "auto":
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
                "CPU backend doesn't support KV cache quantization, falling back to auto."
            )

        if (
            cache_config.cache_dtype != "auto"
            and model_config is not None
            and model_config.dtype == torch.half
        ):
            logger.warning(
                "FP8 KV cache on the CPU backend only does not"
                " support fp16 for now, cast to bf16."
            )
            model_config.dtype = torch.bfloat16
            cache_config.cache_dtype = "auto"

        cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()
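Note: the block-size handling above no longer requires intel_extension_for_pytorch; the CPU path now defaults block_size to 128 and only warns when the value is not a multiple of 32. A hedged illustration of overriding it through the engine arguments; the model name and values are placeholders, and it assumes block_size is forwarded to the engine args as on other platforms.

from vllm import LLM

# Pick a multiple of 32 to avoid the new CPU-backend warning.
llm = LLM(model="facebook/opt-125m", block_size=64, dtype="bfloat16")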
@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
File diff suppressed because it is too large
@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int = 1,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
        supports_dcp_with_varlen: bool = False,
    ) -> None:
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import Any

import torch
import torch.nn as nn

@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)

@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner):

        self._postprocess_tensors()

    # Note: Remove the override after new attention backend finished
    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        if len(self.kv_cache_config.kv_cache_groups) > 1:
            raise ValueError(
                "Multiple KVCacheGroups is not"
                "currently supported with CPU model runner."
            )
        super()._may_reorder_batch(scheduler_output)

    def _postprocess_tensors(self) -> None:
        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: