[Core] Deprecate xformers (#29262)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
||||
seqlens: list[int] | None = None, # Only used for xFormers
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||
q_seqlen=seqlens, kv_seqlen=None, device=q.device
|
||||
)
|
||||
|
||||
context_layer = xops.memory_efficient_attention_forward(
|
||||
q, k, v, attn_bias=attn_bias, p=0, scale=None
|
||||
)
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
||||
seqlens: list[int] | None = None, # Only used for xFormers
|
||||
) -> torch.Tensor:
|
||||
x_attn = self.attn(
|
||||
self.norm1(x),
|
||||
@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
x_fused_norm, residual = self.norm2(x, residual=x_attn)
|
||||
x = residual + self.mlp(x_fused_norm)
|
||||
@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
def compute_attn_mask_seqlen(
|
||||
self,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> tuple[int | None, list[int] | None]:
|
||||
max_seqlen, seqlens = None, None
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
) -> int | None:
|
||||
max_seqlen = None
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
return max_seqlen, seqlens
|
||||
return max_seqlen
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
|
||||
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
|
||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
x = self.embeddings(
|
||||
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
|
||||
)
|
||||
@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
|
||||
# adapter
|
||||
|
||||
Reference in New Issue
Block a user