[Bugfix][Qwen][Multimodal] Move Qwen2_5_vl sdpa to custom op and reenable compile (#27764)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2025-11-03 11:12:15 -08:00
committed by GitHub
parent a4398fbb5e
commit 55011aef24
2 changed files with 69 additions and 28 deletions

View File

@@ -46,6 +46,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
@@ -442,23 +443,12 @@ class Qwen2_5_VisionAttention(nn.Module):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = vit_torch_sdpa_wrapper(
q,
k,
v,
cu_seqlens,
)
elif self.attn_backend == _Backend.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
@@ -466,17 +456,15 @@ class Qwen2_5_VisionAttention(nn.Module):
return output
# (FIXME): Enable this after dynamic slicing is fixed
# See https://github.com/vllm-project/vllm/pull/27760
# @support_torch_compile(
# dynamic_arg_dims={
# "x": 0,
# "cu_seqlens": 0,
# "rotary_pos_emb": 0,
# "seqlens": 0,
# },
# mark_unbacked_dims={"seqlens": 0},
# )
@support_torch_compile(
dynamic_arg_dims={
"x": 0,
"cu_seqlens": 0,
"rotary_pos_emb": 0,
"seqlens": 0,
},
mark_unbacked_dims={"seqlens": 0},
)
class Qwen2_5_VisionBlock(nn.Module):
def __init__(
self,