[CPU] Refactor CPU attention backend (#27954)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Li, Jiang
2025-11-12 09:43:06 +08:00
committed by GitHub
parent e1710393c4
commit 7f829be7d3
34 changed files with 4354 additions and 1902 deletions

View File

@@ -2583,6 +2583,88 @@ def onednn_scaled_mm(
return output
def cpu_attn_get_scheduler_metadata(
num_reqs: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
seq_lens: torch.Tensor,
dtype: torch.dtype,
query_start_loc: torch.Tensor,
causal: bool,
sliding_window_size: int,
isa: str,
enable_kv_split: bool,
) -> torch.Tensor:
scheduler_metadata = torch.ops._C.get_scheduler_metadata(
num_reqs,
num_heads,
num_kv_heads,
head_dim,
seq_lens,
dtype,
query_start_loc,
causal,
sliding_window_size,
isa,
enable_kv_split,
)
return scheduler_metadata
def cpu_attn_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
isa: str,
) -> None:
torch.ops._C.cpu_attn_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
isa,
)
def cpu_attention_with_kv_cache(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
output: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
causal: bool,
alibi_slopes: torch.Tensor | None,
sliding_window: tuple[int, int],
block_table: torch.Tensor,
softcap: float,
scheduler_metadata: torch.Tensor,
s_aux: torch.Tensor | None,
) -> None:
torch.ops._C.cpu_attention_with_kv_cache(
query,
key_cache,
value_cache,
output,
query_start_loc,
seq_lens,
scale,
causal,
alibi_slopes,
sliding_window[0],
sliding_window[1],
block_table,
softcap,
scheduler_metadata,
s_aux,
)
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
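For context, a hedged sketch of how the three new wrappers might be chained inside a CPU attention backend's forward pass. The wrapper names and argument orders come from the definitions above; the import location, the "avx512" ISA string, the sliding-window sentinels, and the helper name are illustrative assumptions, not part of this commit.

import torch
from vllm import _custom_ops as ops  # assumed home of the wrappers defined above

def cpu_paged_attention_sketch(
    query: torch.Tensor,       # [num_tokens, num_heads, head_dim]
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_table: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    num_reqs: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    scale: float,
) -> torch.Tensor:
    # 1) Scatter the new K/V tokens into the paged CPU KV cache.
    ops.cpu_attn_reshape_and_cache(
        key, value, key_cache, value_cache, slot_mapping, isa="avx512"
    )
    # 2) Build the per-batch work-splitting plan once.
    scheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
        num_reqs, num_heads, num_kv_heads, head_dim,
        seq_lens, query.dtype, query_start_loc,
        causal=True, sliding_window_size=-1, isa="avx512", enable_kv_split=True,
    )
    # 3) Run attention against the cache, writing into a preallocated output buffer.
    output = torch.empty_like(query)
    ops.cpu_attention_with_kv_cache(
        query, key_cache, value_cache, output,
        query_start_loc, seq_lens, scale,
        causal=True, alibi_slopes=None, sliding_window=(-1, -1),
        block_table=block_table, softcap=0.0,
        scheduler_metadata=scheduler_metadata, s_aux=None,
    )
    return output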

View File

@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""

View File

@@ -1726,9 +1726,6 @@ class EngineArgs:
)
_raise_unsupported_error(feature_name=name)
if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
_raise_unsupported_error(feature_name="sliding window (CPU backend)")
def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
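With this guard removed, sliding-window models should no longer be rejected during CPU argument validation. A hedged usage sketch (the model name is illustrative, not taken from this commit):

from vllm import LLM

# Mistral-7B v0.1 uses sliding-window attention; before this change the CPU
# backend rejected it with an unsupported-feature error at engine-args validation.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1")
outputs = llm.generate("Hello, CPU backend!")
print(outputs[0].outputs[0].text)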

View File

@@ -8,7 +8,6 @@ import platform
import subprocess
import sys
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING
import regex as re
@@ -139,16 +138,15 @@ class CpuPlatform(Platform):
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.")
logger.info("Using Torch SDPA backend.")
if not use_v1:
raise ValueError("CPU backend only supports V1.")
return AttentionBackendEnum.TORCH_SDPA.get_path()
return AttentionBackendEnum.CPU_ATTN.get_path()
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
@@ -186,15 +184,13 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config
ipex_available = find_spec("intel_extension_for_pytorch") is not None
if cache_config.block_size is None:
cache_config.block_size = 128
if cache_config and cache_config.block_size is None:
cache_config.block_size = 128 if ipex_available else 16
if not ipex_available and cache_config.block_size != 16:
raise RuntimeError(
f"--block-size={cache_config.block_size} requires"
" intel_extension_for_pytorch"
if cache_config.block_size % 32 != 0:
logger.warning(
"CPU backend prefers block_size is multiples of 32, "
"otherwise the performance is not optimized."
)
scheduler_config = vllm_config.scheduler_config
@@ -207,22 +203,11 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache."
)
if cache_config.cache_dtype == "fp8_e4m3":
cache_config.cache_dtype = "fp8_e5m2"
if cache_config.cache_dtype != "auto":
logger.warning(
"CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
"CPU backend doesn't support KV cache quantization fallback to auto."
)
if (
cache_config.cache_dtype != "auto"
and model_config is not None
and model_config.dtype == torch.half
):
logger.warning(
"FP8 KV cache on the CPU backend only does not"
" support fp16 for now, cast to bf16."
)
model_config.dtype = torch.bfloat16
cache_config.cache_dtype = "auto"
cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()
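Summarized as standalone code, the new CPU cache-config defaults behave roughly as follows; this is a readability sketch of the diff above, and the authoritative logic lives in CpuPlatform.check_and_update_config.

import logging

logger = logging.getLogger(__name__)

def apply_cpu_cache_defaults_sketch(cache_config) -> None:
    # Default block size is 128 and no longer depends on IPEX being installed.
    if cache_config.block_size is None:
        cache_config.block_size = 128
    # Other block sizes still run, but the kernels are tuned for multiples of 32.
    if cache_config.block_size % 32 != 0:
        logger.warning(
            "block_size %d is not a multiple of 32; performance may not be optimal.",
            cache_config.block_size,
        )
    # KV cache quantization is not supported by the CPU backend; fall back to auto.
    if cache_config.cache_dtype != "auto":
        logger.warning(
            "CPU backend doesn't support KV cache quantization; falling back to auto."
        )
        cache_config.cache_dtype = "auto"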

View File

@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

File diff suppressed because it is too large

View File

@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None:

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import Any
import torch
import torch.nn as nn
@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner):
self._postprocess_tensors()
# Note: Remove the override after new attention backend finished
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
if len(self.kv_cache_config.kv_cache_groups) > 1:
raise ValueError(
"Multiple KVCacheGroups is not"
"currently supported with CPU model runner."
)
super()._may_reorder_batch(scheduler_output)
def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: