[MM][OOT] Support CPU seq_lens for OOT MMEncoderAttention kernels (#36605)
Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -983,13 +983,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
|
||||
).cumsum(axis=0, dtype=np.int32)
|
||||
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||
self.attn_backend, cu_seqlens_np
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend,
|
||||
cu_seqlens_np,
|
||||
self.device,
|
||||
)
|
||||
if sequence_lengths is not None:
|
||||
sequence_lengths = torch.from_numpy(sequence_lengths).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
hidden_states_list = []
|
||||
deepstack_visual_indexes = self.deepstack_visual_indexes
|
||||
|
||||
@@ -550,13 +550,9 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
axis=0, dtype=np.int32
|
||||
)
|
||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||
self.attn_backend, cu_seqlens
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend, cu_seqlens, self.device
|
||||
)
|
||||
if sequence_lengths is not None:
|
||||
sequence_lengths = torch.from_numpy(sequence_lengths).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
max_seqlen = torch.tensor(
|
||||
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
|
||||
dtype=torch.int32,
|
||||
@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
cu_seqlens,
|
||||
self.hidden_size,
|
||||
self.tp_size,
|
||||
self.device,
|
||||
)
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
|
||||
Reference in New Issue
Block a user