[Core] Deprecate xformers (#29262)

Signed-off-by: Roger Wang <hey@rogerw.io>
Roger Wang committed (via GitHub) on 2025-11-23 20:18:55 -08:00
commit 0ff70821c9, parent 5253f4276f
31 changed files with 77 additions and 963 deletions

@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
             context_layer = rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
 
         output, _ = self.proj(context_layer)
         return output
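For reference, the removed xFormers path used BlockDiagonalMask so that each token only attends within its own packed sequence. A minimal sketch of the same masking with plain PyTorch SDPA (hypothetical helper, not vLLM code; assumes q/k/v are packed along the token axis):

import torch
import torch.nn.functional as F

def block_diagonal_sdpa(q, k, v, seqlens):
    # q, k, v: [batch=1, heads, total_tokens, head_dim], with all
    # sequences packed along the token axis; seqlens holds each length.
    total = sum(seqlens)
    allowed = torch.zeros(total, total, dtype=torch.bool, device=q.device)
    start = 0
    for n in seqlens:
        # True marks attendable positions: only the token's own sequence.
        allowed[start:start + n, start:start + n] = True
        start += n
    return F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)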
@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)
@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
-        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+    ) -> int | None:
+        max_seqlen = None
         if (
             self.attn_backend == AttentionBackendEnum.FLASH_ATTN
             or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        return max_seqlen, seqlens
+        return max_seqlen
 
     def forward(
         self,
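A worked example of the helper's arithmetic (values are hypothetical): three packed sequences of lengths 3, 5, and 4 give padded cumulative lengths [0, 3, 8, 12], and only the Flash Attention / ROCm AITER FA backends need the max:

import torch

cu_seqlens = torch.tensor([0, 3, 8, 12], dtype=torch.int32)
per_seq = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([3, 5, 4]), one length per sequence
max_seqlen = per_seq.max().item()           # 5; returned only for FA-style backends
seqlens = per_seq.tolist()                  # [3, 5, 4]; now computed inline in forward()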
@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
         ).cumsum(dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
-        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         x = self.embeddings(
             x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
         )
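Note the design choice visible in this hunk: seqlens is still consumed by self.embeddings, so the per-sequence lengths are now derived inline in forward() rather than returned by compute_attn_mask_seqlen, which shrinks to returning only the max_seqlen that the Flash Attention backends need.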
@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
 
         # adapter