[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)
Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import MultiModalConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
@@ -169,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module):
|
||||
bias: bool = False,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
self.linear_fc1 = ColumnParallelLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
@@ -206,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Callable[[int], nn.Module] | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@@ -221,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=attn_backend,
|
||||
)
|
||||
self.mlp = Qwen3_VisionMLP(
|
||||
dim,
|
||||
@@ -231,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
act_fn=act_fn,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -264,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
|
||||
spatial_merge_size: int = 2,
|
||||
use_postshuffle_norm: bool = False,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
|
||||
self.use_postshuffle_norm = use_postshuffle_norm
|
||||
@@ -313,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
vision_config: Qwen3VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = vision_config.hidden_size
|
||||
@@ -326,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
self.spatial_merge_unit = self.spatial_merge_size**2
|
||||
self.temporal_patch_size = vision_config.temporal_patch_size
|
||||
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
|
||||
|
||||
# NOTE: This is used for creating empty tensor for all_gather for
|
||||
@@ -359,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
spatial_merge_size=self.spatial_merge_size,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.merger",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
self.deepstack_merger_list = nn.ModuleList(
|
||||
@@ -372,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
use_postshuffle_norm=True,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
for layer_idx in range(len(self.deepstack_visual_indexes))
|
||||
]
|
||||
)
|
||||
|
||||
attn_backend_override = (
|
||||
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
|
||||
)
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
@@ -402,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
)
|
||||
for layer_idx in range(vision_config.depth)
|
||||
]
|
||||
@@ -1277,18 +1285,12 @@ class Qwen3VLForConditionalGeneration(
|
||||
) and not multimodal_config.get_limit_per_prompt("video"):
|
||||
self.visual = None
|
||||
else:
|
||||
attn_backend_override = (
|
||||
multimodal_config.mm_encoder_attn_backend
|
||||
if multimodal_config is not None
|
||||
else None
|
||||
)
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
self.language_model = Qwen3LLMForCausalLM(
|
||||
|
||||
Reference in New Issue
Block a user