[Misc][qwen2_5_vl][torch.compile] Enable supports_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model (#23207)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
Lucas Kabela authored on 2025-10-28 15:36:43 -07:00; committed by GitHub
parent 4fe5895361
commit 94666612a9
6 changed files with 334 additions and 97 deletions
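
The core change is that `support_torch_compile` can now wrap plain `nn.Module` submodules whose `__init__` does not take `vllm_config`/`prefix`: the patched constructor falls back to `get_current_vllm_config()` and forwards only the keyword arguments the wrapped `__init__` actually declares, and a new `mark_unbacked_dims` option keeps dynamo from specializing a dynamic dim on the 0/1 sizes seen with dummy profiling inputs. A rough sketch of the resulting usage pattern (the class below is made up for illustration; the decorator arguments mirror the qwen2_5_vl changes further down):

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile(
    dynamic_arg_dims={"x": 0},      # dim 0 of `x` varies across calls
    mark_unbacked_dims={"x": 0},    # and must not be specialized for size 0/1
)
class TinyVisionMLP(nn.Module):  # hypothetical module, not part of this commit
    # No vllm_config/prefix parameters: the patched __init__ now falls back to
    # get_current_vllm_config() and forwards only the kwargs declared here.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))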

View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.compilation.counter import compilation_counter
from vllm.config.compilation import CompilationMode


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
    """Test that Qwen2.5-VL vision submodules are compiled.

    This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
    for compilation by checking that num_models_seen increases by at least 3.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        # NOTE: Qwen2.5-VL has 35 models in total - the LLM backend
        # Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
        # (one for each layer) - in the future, we should fix vLLM compilation
        # logic to handle this case and only compile the Vision submodules once
        # and reuse the compiled code for all layers
        # See https://github.com/vllm-project/vllm/issues/27590
        compilation_counter.expect(num_models_seen=35),
        vllm_runner(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            max_model_len=2048,
            gpu_memory_utilization=0.7,
            compilation_config={"mode": CompilationMode.VLLM_COMPILE},
        ) as _,
    ):
        pass
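
For reference, the same compilation path can be exercised outside the test harness with the offline LLM API; this is ordinary vLLM usage rather than part of the diff, and simply mirrors the settings used by the test above:

from vllm import LLM
from vllm.config.compilation import CompilationMode

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    max_model_len=2048,
    gpu_memory_utilization=0.7,
    # VLLM_COMPILE now covers the vision patch embed, blocks, and merger too
    compilation_config={"mode": CompilationMode.VLLM_COMPILE},
)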

View File

@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`to_list` in xformers attn, or `.item()` in flash attention)

Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage)

To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
"""

import einops
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    attn_bias = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
    )
    context_layer = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=attn_bias, p=0, scale=None
    )
    context_layer = einops.rearrange(
        context_layer, "b s h d -> s b (h d)"
    ).contiguous()
    return context_layer


def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="xformers_attn_seqlens_wrapper",
    op_func=xformers_attn_seqlens_wrapper,
    fake_impl=xformers_attn_seqlens_wrapper_fake,
)


def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)


def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen.item(),
        max_seqlen_k=max_seqlen.item(),
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer


def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="flash_attn_maxseqlen_wrapper",
    op_func=flash_attn_maxseqlen_wrapper,
    fake_impl=flash_attn_maxseqlen_wrapper_fake,
)


def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
    )
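
The file above relies on vLLM's direct_register_custom_op helper; the underlying idea is the standard PyTorch custom-op pattern: hide a graph-breaking call such as `.item()` or `.tolist()` inside an op body that dynamo treats as opaque, and provide a fake (meta) implementation so shapes still propagate during tracing. A minimal standalone sketch of that pattern using torch.library.custom_op (PyTorch >= 2.4; the `demo::` op below is made up and is not the vLLM helper):

import torch
from torch.library import custom_op


@custom_op("demo::windowed_sum", mutates_args=())
def windowed_sum(x: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    # .item() is fine here: the op body runs eagerly and is opaque to dynamo.
    n = int(length.item())
    return x[:n].sum()


@windowed_sum.register_fake
def _(x: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing the compiled caller.
    return x.new_empty(())


@torch.compile(fullgraph=True)
def f(x: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    # Calling length.item() directly here would break the graph under
    # fullgraph=True; routing it through the custom op keeps one graph.
    return windowed_sum(x, length) * 2


print(f(torch.arange(8.0), torch.tensor(4)))  # tensor(12.)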

View File

@@ -18,7 +18,12 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    CompilationMode,
+    VllmConfig,
+    get_current_vllm_config,
+    set_current_vllm_config,
+)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -74,6 +79,21 @@ def support_torch_compile(
 ) -> Callable[[_T], _T]: ...


+@overload
+def support_torch_compile(
+    *,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
+@overload
+def support_torch_compile(
+    *,
+    dynamic_arg_dims: dict[str, int | list[int]] | None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
 @overload
 def support_torch_compile(cls: _T) -> _T: ...
@@ -82,6 +102,7 @@ def support_torch_compile(
     cls: _T | None = None,
     *,
     dynamic_arg_dims: dict[str, int | list[int]] | None = None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> Callable[[_T], _T] | _T:
     """
@@ -135,6 +156,11 @@ def support_torch_compile(
     returns a boolean value indicating whether to compile the model or not.
     This is useful if you want to compile the model only when certain
     conditions are met.
+
+    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
+    dim to be decorated with `mark_unbacked`. This is useful if we would like to
+    enforce that dynamo do not specialize on 0/1 values in the case of dummy input
+    such as for vision model compilation
     """

     def cls_decorator_helper(cls: _T) -> _T:
@@ -172,7 +198,9 @@ def support_torch_compile(
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}"
                 )
-        return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
+        return _support_torch_compile(
+            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+        )

     if cls is not None:
         # use `support_torch_compile` as a decorator without arguments
@@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
 def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, int | list[int]],
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> _T:
     """
@@ -230,8 +259,22 @@ def _support_torch_compile(
     setattr(cls, IGNORE_COMPILE_KEY, False)

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
-        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+    def __init__(
+        self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
+    ):
+        if vllm_config is None:
+            vllm_config = get_current_vllm_config()
+        # NOTE: to support multimodal models (such as encoder),
+        # we may not have vllm_config so we may need to patch
+        # it
+        sig = inspect.signature(old_init)
+        if "vllm_config" in sig.parameters:
+            kwargs["vllm_config"] = vllm_config
+        if "prefix" in sig.parameters:
+            kwargs["prefix"] = prefix
+        old_init(self, **kwargs)
+
         self.vllm_config = vllm_config
         enable_compile = enable_if is None or enable_if(vllm_config)
         # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
@@ -344,6 +387,15 @@ def _support_torch_compile(
"Unsupported dynamic dimensions" "Unsupported dynamic dimensions"
f" {dims} for argument {k} with type {type(arg)}." f" {dims} for argument {k} with type {type(arg)}."
) )
if mark_unbacked_dims:
for k, dims in mark_unbacked_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.decorators.mark_unbacked(arg, dims)
# here, it is the starting point of the `torch.compile` process # here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config) start_monitoring_torch_compile(self.vllm_config)
logger.debug("Start compiling function %s", self.original_code_object) logger.debug("Start compiling function %s", self.original_code_object)
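
For context on the new mark_unbacked_dims path, a rough standalone sketch (assumed behavior, not vLLM code) of the dynamo decorator it ultimately calls: mark_dynamic makes a dimension symbolic but still lets dynamo specialize when the example value happens to be 0 or 1, whereas mark_unbacked treats the size as fully unknown, which is what matters when vision submodules are traced with size-1 dummy inputs:

import torch


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


compiled = torch.compile(double)

x = torch.randn(1, 8)  # dummy profiling input whose dim 0 happens to be 1
torch._dynamo.decorators.mark_unbacked(x, 0)  # same call the decorator makes
compiled(x)                   # traced with dim 0 as an unbacked symbol
compiled(torch.randn(17, 8))  # intended to reuse the graph instead of recompiling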

View File

@@ -684,6 +684,8 @@ class CompilationConfig:
         from vllm.compilation.backends import VllmBackend

+        # TODO[@lucaskabela]: See if we can forward prefix
+        # https://github.com/vllm-project/vllm/issues/27045
         return VllmBackend(vllm_config)

     def post_init_cudagraph_sizes(self) -> None:

View File

@@ -45,6 +45,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2_5_vl import (
@@ -759,7 +760,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
         assert grid_thw.ndim == 2

         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+        with set_forward_context(None, self.vllm_config):
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
@@ -779,7 +781,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
         assert grid_thw.ndim == 2

         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+        with set_forward_context(None, self.vllm_config):
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
@@ -839,6 +842,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config
         thinker_config: Qwen2_5OmniThinkerConfig = (
             vllm_config.model_config.hf_config.thinker_config
         )

View File

@@ -31,10 +31,10 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import lru_cache, partial
 from typing import Annotated, Any, Literal, TypeAlias

+import einops
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 from transformers import BatchFeature, PretrainedConfig
 from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
@@ -47,9 +47,15 @@ from vllm.attention.layer import (
     check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
+from vllm.attention.ops.vit_attn_wrappers import (
+    vit_flash_attn_wrapper,
+    vit_xformers_attn_wrapper,
+)
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -392,8 +398,8 @@ class Qwen2_5_VisionAttention(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -402,7 +408,7 @@ class Qwen2_5_VisionAttention(nn.Module):
         q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]

-        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
+        q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
         if rotary_pos_emb is not None:
             # [2 * b, s, heads, head_dim]
             qk_concat = torch.cat([q, k], dim=0)
@@ -410,31 +416,18 @@ class Qwen2_5_VisionAttention(nn.Module):
             q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
-            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-            output = self.flash_attn_varlen_func(
+            context_layer = vit_flash_attn_wrapper(
                 q,
                 k,
                 v,
-                cu_seqlens_q=cu_seqlens,
-                cu_seqlens_k=cu_seqlens,
-                max_seqlen_q=max_seqlen,
-                max_seqlen_k=max_seqlen,
-                dropout_p=0.0,
-                causal=False,
+                cu_seqlens,
+                max_seqlen,
+                batch_size,
+                self.attn_backend == _Backend.ROCM_AITER_FA,
+                self.use_upstream_fa,
             )
-            context_layer = rearrange(
-                output, "(b s) h d -> s b (h d)", b=batch_size
-            ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
-            from vllm.platforms import current_platform
-
-            if current_platform.is_rocm():
-                q = q.contiguous()
-                k = k.contiguous()
-                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
@@ -443,34 +436,31 @@ class Qwen2_5_VisionAttention(nn.Module):
                 k_i = k[:, start_idx:end_idx]
                 v_i = v[:, start_idx:end_idx]
                 q_i, k_i, v_i = (
-                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
+                    einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )
                 output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
                 outputs.append(output_i)
             context_layer = torch.cat(outputs, dim=1)
-            context_layer = rearrange(
+            context_layer = einops.rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
         elif self.attn_backend == _Backend.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
+            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)

         output, _ = self.proj(context_layer)
         return output


+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+        "cu_seqlens": 0,
+        "rotary_pos_emb": 0,
+        "seqlens": 0,
+    },
+    mark_unbacked_dims={"seqlens": 0},
+)
 class Qwen2_5_VisionBlock(nn.Module):
     def __init__(
         self,
@@ -515,8 +505,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
@@ -530,6 +520,11 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x


+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+    }
+)
 class Qwen2_5_VisionPatchEmbed(nn.Module):
     def __init__(
         self,
@@ -556,6 +551,11 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
         return x


+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+    }
+)
 class Qwen2_5_VisionPatchMerger(nn.Module):
     def __init__(
         self,
@@ -665,13 +665,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.spatial_merge_size = vision_config.spatial_merge_size
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.spatial_merge_unit = self.spatial_merge_size**2

-        self.patch_embed = Qwen2_5_VisionPatchEmbed(
-            patch_size=patch_size,
-            temporal_patch_size=temporal_patch_size,
-            in_channels=in_channels,
-            hidden_size=self.hidden_size,
-        )
+        # TODO[@lucaskabela]: Investigate fixing this usage
+        # see https://github.com/vllm-project/vllm/issues/27044
+        # DO NOT MOVE THIS IMPORT
+        from vllm.compilation.backends import set_model_tag
+
+        with set_model_tag("Qwen2_5_VisionPatchEmbed"):
+            self.patch_embed = Qwen2_5_VisionPatchEmbed(
+                patch_size=patch_size,
+                temporal_patch_size=temporal_patch_size,
+                in_channels=in_channels,
+                hidden_size=self.hidden_size,
+            )

         norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
@@ -701,32 +706,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
f"Qwen2.5-VL does not support {self.attn_backend} backend now." f"Qwen2.5-VL does not support {self.attn_backend} backend now."
) )
self.blocks = nn.ModuleList( with set_model_tag("Qwen2_5_VisionBlock"):
[ self.blocks = nn.ModuleList(
Qwen2_5_VisionBlock( [
dim=self.hidden_size, Qwen2_5_VisionBlock(
num_heads=self.num_heads, dim=self.hidden_size,
mlp_hidden_dim=vision_config.intermediate_size, num_heads=self.num_heads,
act_fn=get_act_and_mul_fn(vision_config.hidden_act), mlp_hidden_dim=vision_config.intermediate_size,
norm_layer=norm_layer, act_fn=get_act_and_mul_fn(vision_config.hidden_act),
quant_config=quant_config, norm_layer=norm_layer,
prefix=f"{prefix}.blocks.{layer_idx}", quant_config=quant_config,
use_data_parallel=use_data_parallel, prefix=f"{prefix}.blocks.{layer_idx}",
attn_backend=self.attn_backend, use_data_parallel=use_data_parallel,
use_upstream_fa=use_upstream_fa, attn_backend=self.attn_backend,
) use_upstream_fa=use_upstream_fa,
for layer_idx in range(depth) )
] for layer_idx in range(depth)
) ]
self.merger = Qwen2_5_VisionPatchMerger( )
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size, with set_model_tag("Qwen2_5_VisionPatchMerger"):
norm_layer=norm_layer, self.merger = Qwen2_5_VisionPatchMerger(
spatial_merge_size=self.spatial_merge_size, d_model=vision_config.out_hidden_size,
quant_config=quant_config, context_dim=self.hidden_size,
prefix=f"{prefix}.merger", norm_layer=norm_layer,
use_data_parallel=use_data_parallel, spatial_merge_size=self.spatial_merge_size,
) quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
@@ -827,15 +835,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen, seqlens = (
+            torch.zeros(1, device=cu_seqlens.device),
+            torch.zeros(1, device=cu_seqlens.device),
+        )
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
         ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens

     @staticmethod
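
A small worked example (standalone, with assumed values) of the seqlen bookkeeping this hunk switches to tensors: keeping max_seqlen/seqlens on-device defers the `.item()`/`.tolist()` host syncs into the opaque attention wrappers, so the compiled vision block itself stays free of graph breaks:

import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])    # three images of 3, 4, and 5 patches
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([3, 4, 5]), stays on device
max_seqlen = seqlens.max()                  # tensor(5); .item() happens later,
                                            # inside the custom attention op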
@@ -1233,6 +1244,7 @@ class Qwen2_5_VLForConditionalGeneration(
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self.config = config
+        self.vllm_config = vllm_config
         self.multimodal_config = multimodal_config
         self.video_pruning_rate = multimodal_config.video_pruning_rate
         self.is_multimodal_pruning_enabled = (
@@ -1248,7 +1260,7 @@ class Qwen2_5_VLForConditionalGeneration(
             else None
         )
         self.visual = Qwen2_5_VisionTransformer(
-            config.vision_config,
+            vision_config=config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "visual"),
@@ -1336,13 +1348,13 @@ class Qwen2_5_VLForConditionalGeneration(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"]
-            if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
-                )
-            else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+            with set_forward_context(None, self.vllm_config):
+                if self.use_data_parallel:
+                    return run_dp_sharded_mrope_vision_model(
+                        self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    )
+                else:
+                    image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -1396,12 +1408,18 @@ class Qwen2_5_VLForConditionalGeneration(
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
-            if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
-                )
-            else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+            with set_forward_context(None, self.vllm_config):
+                if self.use_data_parallel:
+                    return run_dp_sharded_mrope_vision_model(
+                        self.visual,
+                        pixel_values_videos,
+                        grid_thw_list,
+                        rope_type="rope_3d",
+                    )
+                else:
+                    video_embeds = self.visual(
+                        pixel_values_videos, grid_thw=grid_thw_list
+                    )

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size