[Bugfix] Fix missing sequence_lengths in qwen3_omni_moe_thinker (#35741)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
commit fa6a6be519
parent cad21918e3
@@ -648,6 +648,7 @@ class Qwen3_VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor | None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor | None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -655,6 +656,7 @@ class Qwen3_VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
         )

         x = x + self.mlp(self.norm2(x))
@@ -975,6 +977,20 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
         max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

+        # Recompute cu_seqlens in numpy from grid_thw to avoid GPU->CPU sync
+        grid_thw_np = grid_thw.cpu().numpy()
+        cu_seqlens_np = np.repeat(
+            grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
+        ).cumsum(axis=0, dtype=np.int32)
+        cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
+        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
+            self.attn_backend, cu_seqlens_np
+        )
+        if sequence_lengths is not None:
+            sequence_lengths = torch.from_numpy(sequence_lengths).to(
+                self.device, non_blocking=True
+            )
+
         hidden_states_list = []
         deepstack_visual_indexes = self.deepstack_visual_indexes

@@ -985,6 +1001,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
+                sequence_lengths=sequence_lengths,
             )
             if (
                 deepstack_visual_indexes is not None
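The core of the fix is the CPU-side recomputation in the second hunk: per-frame sequence lengths are built from grid_thw in numpy so the cumulative boundaries never require a device-to-host sync. A minimal standalone sketch of that arithmetic follows; the grid_thw_np values are made up for illustration, and the final np.diff step is an assumption about how per-sequence lengths relate to the cumulative boundaries (the actual derivation lives in MMEncoderAttention.maybe_compute_sequence_lengths).

    import numpy as np

    # Hypothetical input: two visual inputs with (t, h, w) patch grids.
    # A row (t, h, w) contributes t sequences of h * w patches each.
    grid_thw_np = np.array([[1, 4, 6], [2, 3, 3]], dtype=np.int64)

    # Per-frame lengths: h * w repeated t times -> [24, 9, 9]
    seq_lens = np.repeat(grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0])

    # Cumulative boundaries with a leading zero, as in the patch above.
    cu_seqlens_np = seq_lens.cumsum(axis=0, dtype=np.int32)
    cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
    print(cu_seqlens_np)  # [ 0 24 33 42]

    # Assumption: the per-sequence lengths the FlashInfer CuDNN path needs
    # are recoverable as consecutive differences of those boundaries.
    sequence_lengths = np.diff(cu_seqlens_np)
    print(sequence_lengths)  # [24  9  9]

Because everything above runs on host memory, the only device work left is the final non-blocking torch.from_numpy(...).to(self.device) copy in the patch, which is what makes the recomputation cheaper than reading cu_seqlens back from the GPU.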