[BUGFIX] Fix hipErrorIllegalState in Qwen3-Omni during startup profiling to allow Omni inference on ROCm (#33077)
Signed-off-by: JartX <sagformas@epdcenter.es>
This commit is contained in:
@@ -907,13 +907,37 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
|||||||
hidden_states = hidden_states + pos_embeds
|
hidden_states = hidden_states + pos_embeds
|
||||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw)
|
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw)
|
||||||
|
|
||||||
cu_seqlens = torch.repeat_interleave(
|
# RDNA3 (gfx11) specific bug workaround: torch.repeat_interleave triggers
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
# kernel crashes. We attempt the operation and catch the RuntimeError
|
||||||
).cumsum(
|
# to switch to a vectorized cumsum + searchsorted approach.
|
||||||
dim=0,
|
try:
|
||||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
cu_seqlens = torch.repeat_interleave(
|
||||||
)
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
).cumsum(
|
||||||
|
dim=0,
|
||||||
|
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning(
|
||||||
|
"torch.repeat_interleave not executable, "
|
||||||
|
"switching to vectorized searchsorted implementation."
|
||||||
|
)
|
||||||
|
repeat_counts = grid_thw[:, 0]
|
||||||
|
values = grid_thw[:, 1] * grid_thw[:, 2]
|
||||||
|
repeat_cumsum = repeat_counts.cumsum(0)
|
||||||
|
total_items = repeat_cumsum[-1].item()
|
||||||
|
|
||||||
|
indices = torch.searchsorted(
|
||||||
|
repeat_cumsum,
|
||||||
|
torch.arange(total_items, device=grid_thw.device),
|
||||||
|
right=True,
|
||||||
|
)
|
||||||
|
cu_seqlens = values[indices].cumsum(
|
||||||
|
dim=0,
|
||||||
|
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
|
||||||
hidden_states = hidden_states.unsqueeze(1)
|
hidden_states = hidden_states.unsqueeze(1)
|
||||||
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
|
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user