[Bugfix] Fix missing sequence_lengths in qwen3_omni_moe_thinker (#35741)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Ye (Charlotte) Qi
2026-03-02 13:11:56 -08:00
committed by GitHub
parent cad21918e3
commit fa6a6be519

View File

@@ -648,6 +648,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor | None, # Only used for Flash Attention
sequence_lengths: torch.Tensor | None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -655,6 +656,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
x = x + self.mlp(self.norm2(x))
@@ -975,6 +977,20 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
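# Rebuild cu_seqlens on the host with numpy, directly from grid_thw, so the
# sequence lengths can be derived without a GPU->CPU sync.
# NOTE(review): this only avoids the sync if grid_thw is already a CPU
# tensor (making .cpu() a no-op) — confirm against the caller.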
grid_thw_np = grid_thw.cpu().numpy()
cu_seqlens_np = np.repeat(
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
).cumsum(axis=0, dtype=np.int32)
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
self.attn_backend, cu_seqlens_np
)
if sequence_lengths is not None:
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -985,6 +1001,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
if (
deepstack_visual_indexes is not None