[Models][Qwen] Replace pad with cat for better performance (#26486)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -680,7 +680,7 @@ class DotsVisionTransformer(nn.Module):
|
||||
dim=0,
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
for blk in self.blocks:
|
||||
|
||||
Reference in New Issue
Block a user