[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
Shanshan Shen
2025-12-15 11:13:32 +08:00
committed by GitHub
parent 84e23d103d
commit 87b4d1557d
24 changed files with 1262 additions and 851 deletions

View File

@@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
@@ -267,10 +263,15 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
@@ -304,13 +305,16 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
@@ -342,18 +346,12 @@ class Qwen2_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
)
def forward(
self,
@@ -394,32 +392,17 @@ class Qwen2_5_VisionAttention(nn.Module):
else:
q, k, v = qkv.unbind(dim=2)
if self.is_flash_attn_backend:
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
# Never remove the next contiguous logic
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
context_layer = vit_torch_sdpa_wrapper(
q,
k,
v,
cu_seqlens,
)
context_layer = einops.rearrange(
context_layer, "b s h d -> s b (h d)", b=batch_size
).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -443,10 +426,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -458,10 +439,8 @@ class Qwen2_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2_5_VisionMLP(
dim,
@@ -469,8 +448,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
@@ -542,10 +521,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -586,9 +570,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config: Qwen2_5_VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -598,7 +581,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.out_hidden_size
# args for get_window_index_thw
@@ -629,19 +611,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
rope_parameters={"partial_rotary_factor": 0.5},
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
@@ -661,10 +641,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -677,8 +655,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
@property
@@ -1200,18 +1178,12 @@ class Qwen2_5_VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
else:
self.visual = None