[Core] Deprecate xformers (#29262)
Signed-off-by: Roger Wang <hey@rogerw.io>
@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
     vit_torch_sdpa_wrapper,
-    vit_xformers_attn_wrapper,
 )
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
@@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
                 v,
                 cu_seqlens,
             )
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
 
         output, _ = self.proj(context_layer)
         return output
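(For context: the removed branch ran block-diagonal attention over the per-image patch counts in `seqlens`. A minimal sketch of that pattern, using the public xFormers API directly rather than vLLM's `vit_xformers_attn_wrapper`, whose internals may differ:)

import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# q/k/v: (batch=1, total_tokens, heads, head_dim); two images of 4 and 6 patches
q = k = v = torch.randn(1, 10, 8, 64, device="cuda", dtype=torch.float16)
attn_bias = BlockDiagonalMask.from_seqlens([4, 6])  # keeps each image's attention separate
out = xops.memory_efficient_attention_forward(q, k, v, attn_bias=attn_bias)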
@@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
         "cu_seqlens": 0,
         "rotary_pos_emb_cos": 0,
         "rotary_pos_emb_sin": 0,
-        "seqlens": 0,
     },
-    mark_unbacked_dims={"seqlens": 0},
     enable_if=should_torch_compile_mm_vit,
 )
 class Qwen2_5_VisionBlock(nn.Module):
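(With `seqlens` removed from the block's `forward()`, there is no longer a dynamic or unbacked dimension to declare for it. A sketch of the decorator after this change, assuming `dynamic_arg_dims` maps argument names to the dimension torch.compile should treat as symbolic; `ExampleBlock` is a hypothetical stand-in for Qwen2_5_VisionBlock:)

import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile

@support_torch_compile(
    dynamic_arg_dims={
        "x": 0,           # token dimension varies with image size
        "cu_seqlens": 0,  # number of sequence boundaries varies per batch
    },
)
class ExampleBlock(nn.Module):
    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        return x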
@@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
@@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)
@@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
-        return max_seqlen, seqlens
+        return max_seqlen
 
     @staticmethod
     def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
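(A worked example of the simplified helper: `cu_seqlens` stores cumulative sequence boundaries, so adjacent differences recover per-sequence lengths; with xFormers gone, only the maximum is still needed, for Flash Attention. Values below are illustrative:)

import torch

cu_seqlens = torch.tensor([0, 4, 10, 13])   # three sequences of lengths 4, 6, 3
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 6, 3]); previously returned for xFormers
max_seqlen = seqlens.max()                  # tensor(6); all that Flash Attention still needs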
@@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
         # transformers
         # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
-        max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
-        max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
-            cu_window_seqlens
-        )
+        max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
 
         cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
         cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
@@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
                 max_seqlen_now = max_seqlen_full
-                seqlens_now = seqlens_full
             else:
                 cu_seqlens_now = cu_window_seqlens
                 max_seqlen_now = max_seqlen_window
-                seqlens_now = seqlens_window
 
             hidden_states = blk(
                 hidden_states,
@@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen_now,
-                seqlens=seqlens_now,
             )
 
         # For Qwen2.5-VL-3B, float16 will overflow at last block