[Core] Deprecate xformers (#29262)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||
seqlens: torch.Tensor, # Only used for xFormers
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x),
|
||||
@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
def compute_attn_mask_seqlen(
|
||||
self,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
return max_seqlen, seqlens
|
||||
return max_seqlen
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
if layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
|
||||
|
||||
Reference in New Issue
Block a user