[Core] Deprecate xformers (#29262)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2025-11-23 20:18:55 -08:00
committed by GitHub
parent 5253f4276f
commit 0ff70821c9
31 changed files with 77 additions and 963 deletions

View File

@@ -43,7 +43,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_TRITON_MLA = (

View File

@@ -51,31 +51,6 @@ else:
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
def check_xformers_availability():
global USE_XFORMERS_OPS
if USE_XFORMERS_OPS is not None:
return USE_XFORMERS_OPS
if current_platform.is_cuda() and current_platform.has_device_capability(100):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS = False
else:
try:
from importlib.util import find_spec
find_spec("xformers.ops")
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
# the warning only needs to be shown once
if not USE_XFORMERS_OPS:
logger.warning("Xformers is not available, falling back.")
return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
@@ -533,7 +508,6 @@ class MultiHeadAttention(nn.Module):
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
@@ -549,12 +523,6 @@ class MultiHeadAttention(nn.Module):
)
)
if (
self.attn_backend == AttentionBackendEnum.XFORMERS
and not check_xformers_availability()
):
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
@@ -614,12 +582,6 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)

View File

@@ -3,7 +3,7 @@
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`to_list` in xformers attn, or `.item()` in flash attention)
`.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
@@ -19,42 +19,6 @@ import torch.nn.functional as F
from vllm.utils.torch_utils import direct_register_custom_op
def xformers_attn_seqlens_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer
def xformers_attn_seqlens_wrapper_fake(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op(
op_name="xformers_attn_seqlens_wrapper",
op_func=xformers_attn_seqlens_wrapper,
fake_impl=xformers_attn_seqlens_wrapper_fake,
)
def vit_xformers_attn_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,

View File

@@ -36,7 +36,14 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return None if backend_name is None else AttentionBackendEnum[backend_name]
if backend_name is None:
return None
if backend_name == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
"details). Please select a supported attention backend."
)
return AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend

View File

@@ -173,6 +173,12 @@ class MultiModalConfig:
# We need to import the real type here (deferred to avoid circular import).
from vllm.attention.backends.registry import AttentionBackendEnum
if isinstance(value, str) and value.upper() == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
"details). Please select a supported attention backend."
)
if value is None or isinstance(value, AttentionBackendEnum):
return value

View File

@@ -640,7 +640,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Example options:
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
# - "XFORMERS": use XFormers
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA

View File

@@ -306,7 +306,6 @@ class DotsVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -324,7 +323,6 @@ class DotsVisionAttention(nn.Module):
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
seqlens: list[int] | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
@@ -374,16 +372,6 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
else:
raise RuntimeError("Unsupported attention backend")
@@ -545,14 +533,12 @@ class DotsVisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None,
seqlens: list[int] | None = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -663,18 +649,14 @@ class DotsVisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
@@ -694,14 +676,13 @@ class DotsVisionTransformer(nn.Module):
)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
if self.post_trunk_norm is not None:

View File

@@ -214,7 +214,6 @@ class Ernie4_5_VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -259,7 +258,6 @@ class Ernie4_5_VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -311,20 +309,6 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -404,14 +388,12 @@ class Ernie4_5_VisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -562,18 +544,14 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0
@@ -598,8 +576,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
if hidden_states.ndim == 2:
hidden_states = hidden_states.unsqueeze(dim=1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for i, blk in enumerate(self.blocks):
hidden_states = blk(
@@ -607,7 +585,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
final_output = self.ln(hidden_states)

View File

@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
) -> int | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
x = self.embeddings(
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
)
@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# adapter

View File

@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
batch_size = q.shape[0]
if rope_emb is None:
@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()

View File

@@ -38,7 +38,6 @@ from vllm.attention.layer import (
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
seqlens: torch.Tensor | None,
) -> torch.Tensor:
batch_size, _, _ = hidden_states.shape
@@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
seqlens: torch.Tensor | None,
) -> torch.Tensor:
residual = hidden_states
@@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
hidden_states = residual + hidden_states
@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
cu_seqlens = cu_seqlens.to(device=device)
max_seqlen = None
seqlens = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
return hidden_states

View File

@@ -74,6 +74,7 @@ from .vision import (
)
try:
# Note: vLLM does not install xformers by default.
from xformers import ops as xops
if current_platform.is_cuda() and current_platform.has_device_capability(100):

View File

@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
v,
cu_seqlens,
)
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer)
return output
@@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"cu_seqlens": 0,
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
"seqlens": 0,
},
mark_unbacked_dims={"seqlens": 0},
enable_if=should_torch_compile_mm_vit,
)
class Qwen2_5_VisionBlock(nn.Module):
@@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
return max_seqlen
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
@@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
cu_window_seqlens
)
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
@@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full
seqlens_now = seqlens_full
else:
cu_seqlens_now = cu_window_seqlens
max_seqlen_now = max_seqlen_window
seqlens_now = seqlens_window
hidden_states = blk(
hidden_states,
@@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
seqlens=seqlens_now,
)
# For Qwen2.5-VL-3B, float16 will overflow at last block

View File

@@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, 3 * head * head_dim]
x, _ = self.qkv(x)
@@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module):
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module):
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
for blk in self.blocks:
x = blk(
@@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# adapter

View File

@@ -224,7 +224,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -232,7 +231,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -500,14 +498,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@@ -533,7 +528,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -545,7 +540,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
if (
deepstack_visual_indexes is not None

View File

@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)

View File

@@ -277,12 +277,7 @@ class CudaPlatformBase(Platform):
except ImportError:
pass
if cls.has_device_capability(100):
# xFormers doesn't support Blackwell, fall back to SDPA
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
return AttentionBackendEnum.TORCH_SDPA
else:
return AttentionBackendEnum.XFORMERS
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_valid_backends(

View File

@@ -49,7 +49,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

View File

@@ -1,420 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (
AttentionBias,
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
)
XFORMERS_AVAILABLE = True
except ImportError:
XFORMERS_AVAILABLE = False
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [
32,
40,
48,
56,
64,
72,
80,
88,
96,
104,
112,
120,
128,
136,
144,
152,
160,
168,
176,
184,
192,
200,
208,
216,
224,
232,
240,
248,
256,
]
@staticmethod
def get_name() -> str:
return "XFORMERS"
@staticmethod
def get_impl_cls() -> type["XFormersAttentionImpl"]:
return XFormersAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
return XFormersAttentionMetadataBuilder
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class XFormersAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
# Biases for different attention types.
attn_bias: Optional["AttentionBias"] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
_cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
q_start_loc = self.query_start_loc[self.num_decodes :]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[self.num_decodes :]
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = XFormersAttentionMetadata(
num_actual_tokens=self.num_prefill_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc - q_start_loc[0],
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[self.num_decodes :],
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
q_start_loc = self.query_start_loc
q_seqlens = torch.diff(q_start_loc)
decode_kv_seqlens = self.seq_lens[: self.num_decodes]
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersAttentionMetadata(
num_actual_tokens=self.num_decode_tokens,
max_query_len=int(q_seqlens[: self.num_decodes].max().item()),
query_start_loc=q_start_loc[: self.num_decodes + 1],
max_seq_len=int(decode_kv_seqlens.max().item()),
seq_lens=decode_kv_seqlens,
block_table=self.block_table[: self.num_decodes],
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
attn_bias=self.attn_bias,
)
return self._cached_decode_metadata
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]
):
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert XFORMERS_AVAILABLE
self.block_size = kv_cache_spec.block_size
self._num_decodes = 0
self._num_decode_tokens = 0
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> XFormersAttentionMetadata:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc
q_seqlens = torch.diff(q_start_loc)
max_query_len = common_attn_metadata.max_query_len
kv_seqlens = common_attn_metadata.seq_lens
max_seq_len = common_attn_metadata.max_seq_len
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
bias = None
if num_decodes > 0:
# Construct the decoder bias.
decode_q_seqlens = q_seqlens[:num_decodes]
decode_kv_seqlens = kv_seqlens[:num_decodes]
bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=decode_q_seqlens.tolist(),
kv_seqlen=decode_kv_seqlens.tolist(),
page_size=self.block_size,
block_tables=block_table[:num_decodes],
device=block_table.device,
)
return XFormersAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_decodes=num_decodes,
max_query_len=max_query_len,
query_start_loc=q_start_loc,
max_seq_len=max_seq_len,
seq_lens=kv_seqlens,
block_table=block_table,
slot_mapping=slot_mapping,
attn_bias=bias,
)
class XFormersAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if alibi_slopes is not None:
raise NotImplementedError("XFormers does not support alibi slopes yet.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
if logits_soft_cap is None:
# Setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"XFormersAttentionImpl."
)
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: XFormersAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with XFormers.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
# Cache the input KVs.
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
unified_attention(
q=query[num_decode_tokens:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[num_decode_tokens:num_actual_tokens],
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=prefill_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata:
# Query for decode. KV is not needed because it is already cached.
decode_query = query[:num_decode_tokens]
# Reshape query to [1, B_T, G, H, D].
q = decode_query.view(
1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
)
# Reshape the k and v caches to [1, Bkv_T, G, H, D]
cache_k = key_cache.view(
1, -1, self.num_kv_heads, 1, self.head_size
).expand(
1,
-1,
self.num_kv_heads,
self.num_queries_per_kv,
self.head_size,
)
cache_v = value_cache.view(
1, -1, self.num_kv_heads, 1, self.head_size
).expand(
1,
-1,
self.num_kv_heads,
self.num_queries_per_kv,
self.head_size,
)
attn_bias = decode_meta.attn_bias
output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
q,
cache_k,
cache_v,
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
).view(decode_query.shape)
# Reshape the output tensor.
return output