From cd86fff38feed579a22768ac9f9464360a0819fe Mon Sep 17 00:00:00 2001 From: JartX Date: Sun, 1 Feb 2026 14:36:25 +0100 Subject: [PATCH] [BUGFIX] Fix hipErrorIllegalState in Qwen3-Omni during startup profiling to allow Omni inference on ROCm (#33077) Signed-off-by: JartX --- .../models/qwen3_omni_moe_thinker.py | 38 +++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index d9a0b9923..4d797528f 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -907,13 +907,37 @@ class Qwen3Omni_VisionTransformer(nn.Module): hidden_states = hidden_states + pos_embeds rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw) - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum( - dim=0, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + # RDNA3 (gfx11) specific bug workaround: torch.repeat_interleave triggers + # kernel crashes. We attempt the operation and catch the RuntimeError + # to switch to a vectorized cumsum + searchsorted approach. + try: + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + except RuntimeError: + logger.warning( + "torch.repeat_interleave not executable, " + "switching to vectorized searchsorted implementation." 
+ ) + repeat_counts = grid_thw[:, 0] + values = grid_thw[:, 1] * grid_thw[:, 2] + repeat_cumsum = repeat_counts.cumsum(0) + total_items = repeat_cumsum[-1].item() + + indices = torch.searchsorted( + repeat_cumsum, + torch.arange(total_items, device=grid_thw.device), + right=True, + ) + cu_seqlens = values[indices].cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) hidden_states = hidden_states.unsqueeze(1) rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)