[Models][Qwen] Replace pad with cat for better performance (#26486)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger
2025-10-09 15:51:26 +01:00
committed by GitHub
parent e246ad6f0c
commit 2c1c7dfb35
4 changed files with 6 additions and 5 deletions

View File

@@ -539,7 +539,7 @@ class Qwen3_VisionTransformer(nn.Module):
dim=0,
dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
)
- cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)