[MM][OOT] Support CPU seq_lens for OOT MMEncoderAttention kernels (#36605)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Shanshan Shen
2026-03-12 18:28:23 +08:00
committed by GitHub
parent 57431d8231
commit f0d3658c0f
5 changed files with 52 additions and 40 deletions

View File

@@ -983,13 +983,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
).cumsum(axis=0, dtype=np.int32)
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
-sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-self.attn_backend, cu_seqlens_np
+sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+self.attn_backend,
+cu_seqlens_np,
+self.device,
)
-if sequence_lengths is not None:
-sequence_lengths = torch.from_numpy(sequence_lengths).to(
-self.device, non_blocking=True
-)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes

View File

@@ -550,13 +550,9 @@ class Qwen3_VisionTransformer(nn.Module):
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
-sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-self.attn_backend, cu_seqlens
+sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+self.attn_backend, cu_seqlens, self.device
)
-if sequence_lengths is not None:
-sequence_lengths = torch.from_numpy(sequence_lengths).to(
-self.device, non_blocking=True
-)
max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
dtype=torch.int32,
@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens,
self.hidden_size,
self.tp_size,
+self.device,
)
-cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
hidden_states = hidden_states.unsqueeze(1)
deepstack_feature_lists = []