[Bugfix] Fix incorrect qwen2.5-vl attention mask pre-computation (#15200)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2025-03-21 10:18:04 +08:00
Committed by: GitHub
Parent: 2e0b4cfde0
Commit: 1e508343e1

3 changed files with 37 additions and 4 deletions


@@ -647,15 +647,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
         max_seqlen = None
         seqlens = None
-        if self.attn_backend == _Backend.FLASH_ATTN:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
+            # pre-compute cu_seqlens for window attn
+            if self.attn_backend == _Backend.FLASH_ATTN:
+                max_seqlen = (cu_seqlens_now[1:] -
+                              cu_seqlens_now[:-1]).max().item()
+            elif self.attn_backend == _Backend.XFORMERS:
+                seqlens = (cu_seqlens_now[1:] - cu_seqlens_now[:-1]).tolist()
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
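
For context: before this change, max_seqlen (FLASH_ATTN) and seqlens (XFORMERS) were computed once from the full-attention cu_seqlens before the layer loop, so window-attention blocks received values derived from the wrong sequence boundaries. A minimal sketch with hypothetical boundary values (not taken from the commit) shows how the two tensors can disagree:

    import torch

    # hypothetical boundaries: one 8-token sequence under full attention,
    # split into two 4-token windows under window attention
    cu_seqlens = torch.tensor([0, 8])
    cu_window_seqlens = torch.tensor([0, 4, 8])

    # old behaviour: computed once, from the full-attention boundaries only
    max_seqlen_full = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # -> 8

    # fixed behaviour: recomputed per layer from cu_seqlens_now, which for
    # window-attention blocks is cu_window_seqlens
    max_seqlen_window = (cu_window_seqlens[1:] -
                         cu_window_seqlens[:-1]).max().item()          # -> 4

    assert max_seqlen_full == 8 and max_seqlen_window == 4

With the fix, each block derives max_seqlen/seqlens from the same cu_seqlens_now tensor it attends with, so a window-attention block in this sketch sees 4 rather than 8.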