[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
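The diff below wires Qwen2_5_VisionAttention to a single MMEncoderAttention instance instead of having the module pick between the flash-attention and SDPA wrappers itself. A minimal sketch of that usage, based only on the argument names visible in the hunks; the exact MMEncoderAttention signature, tensor types, and example sizes are assumptions:

import torch

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention

# Sketch: one CustomOp instance replaces the per-backend dispatch inside the
# vision attention module. Argument names follow the diff; the sizes and the
# None multimodal_config are illustrative assumptions.
attn = MMEncoderAttention(num_heads=16, head_size=80, multimodal_config=None)


def vision_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
) -> torch.Tensor:
    # Same call shape as the refactored forward(): q/k/v plus the varlen
    # metadata that the removed vit_flash_attn_wrapper / vit_torch_sdpa_wrapper
    # paths used to take.
    return attn(
        query=q,
        key=k,
        value=v,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )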
@@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
)

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_torch_sdpa_wrapper,
)
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
@@ -267,10 +263,15 @@ class Qwen2_5_VisionMLP(nn.Module):
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
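The same derivation recurs in the attention, patch-merger, and transformer hunks that follow: instead of threading an explicit use_data_parallel flag through every constructor, each vision submodule reads the encoder TP mode off the multimodal config. A minimal sketch of that check, wrapped in a hypothetical helper (encoder_tp_size is not part of the diff, and the TP-size fallback is an assumption since the hunk above is cut off before the else branch):

from vllm.config import MultiModalConfig
from vllm.distributed import parallel_state


def encoder_tp_size(multimodal_config: MultiModalConfig | None) -> int:
    # Recurring pattern in this diff: data-parallel mode for the vision encoder
    # is derived from the multimodal config rather than passed as a flag.
    use_data_parallel = (
        multimodal_config.mm_encoder_tp_mode == "data"
        if multimodal_config is not None
        else False
    )
    # Assumption: under mm-encoder data parallelism each rank runs the full
    # encoder, so the effective TP size is 1; otherwise use the usual TP size.
    return (
        1
        if use_data_parallel
        else parallel_state.get_tensor_model_parallel_world_size()
    )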
@@ -304,13 +305,16 @@ class Qwen2_5_VisionAttention(nn.Module):
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.tp_size = (
            1
            if use_data_parallel
@@ -342,18 +346,12 @@ class Qwen2_5_VisionAttention(nn.Module):
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.attn_backend = attn_backend
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            multimodal_config=multimodal_config,
        )

    def forward(
        self,
@@ -394,32 +392,17 @@ class Qwen2_5_VisionAttention(nn.Module):
        else:
            q, k, v = qkv.unbind(dim=2)

        if self.is_flash_attn_backend:
            context_layer = vit_flash_attn_wrapper(
                q,
                k,
                v,
                cu_seqlens,
                max_seqlen,
                batch_size,
                self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
            )
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform
        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

            # Never remove the next contiguous logic
            # Without it, hallucinations occur with the backend
            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            context_layer = vit_torch_sdpa_wrapper(
                q,
                k,
                v,
                cu_seqlens,
            )
            context_layer = einops.rearrange(
                context_layer, "b s h d -> s b (h d)", b=batch_size
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output
@@ -443,10 +426,8 @@ class Qwen2_5_VisionBlock(nn.Module):
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
@@ -458,10 +439,8 @@ class Qwen2_5_VisionBlock(nn.Module):
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend=attn_backend,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2_5_VisionMLP(
            dim,
@@ -469,8 +448,8 @@ class Qwen2_5_VisionBlock(nn.Module):
            act_fn=act_fn,
            bias=True,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
@@ -542,10 +521,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -586,9 +570,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
@@ -598,7 +581,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.out_hidden_size

        # args for get_window_index_thw
@@ -629,19 +611,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
@@ -661,10 +641,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
                    act_fn=get_act_and_mul_fn(vision_config.hidden_act),
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend=self.attn_backend,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
@@ -677,8 +655,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

    @property
@@ -1200,18 +1178,12 @@ class Qwen2_5_VLForConditionalGeneration(
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            attn_backend_override = (
                multimodal_config.mm_encoder_attn_backend
                if multimodal_config is not None
                else None
            )
            self.visual = Qwen2_5_VisionTransformer(
                vision_config=config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=attn_backend_override,
                multimodal_config=multimodal_config,
            )
        else:
            self.visual = None