diff --git a/tests/compile/test_multimodal_compile.py b/tests/compile/test_multimodal_compile.py new file mode 100644 index 000000000..6c195dd93 --- /dev/null +++ b/tests/compile/test_multimodal_compile.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.compilation.counter import compilation_counter +from vllm.config.compilation import CompilationMode + + +# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073 +@pytest.mark.forked +def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch): + """Test that Qwen2.5-VL vision submodules are compiled. + + This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged + for compilation by checking that num_models_seen increases by the + expected count (see the NOTE below). + """ + # Disable multiprocessing so that the counter is in the same process + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + with ( + # NOTE: Qwen2.5-VL has 35 models in total - the LLM backend, + # Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks + # (one for each layer) - in the future, we should fix vLLM compilation + # logic to handle this case and only compile the Vision submodules once + # and reuse the compiled code for all layers + # See https://github.com/vllm-project/vllm/issues/27590 + compilation_counter.expect(num_models_seen=35), + vllm_runner( + "Qwen/Qwen2.5-VL-3B-Instruct", + max_model_len=2048, + gpu_memory_utilization=0.7, + compilation_config={"mode": CompilationMode.VLLM_COMPILE}, + ) as _, + ): + pass diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py new file mode 100644 index 000000000..f71f49a1a --- --- /dev/null +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains ops that make ViT attention compatible with torch.compile, +as some of the operations involved are not supported by torch.compile (for +instance, `tolist()` in xformers attn, or `.item()` in flash attention). + +Using these ops and wrapping vision blocks with `torch.compile` can speed up +throughput in vision models by ~5% relative on H100, and improve token +latencies by ~7% (see qwen2_5_vl for example usage). + +To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0) +""" + +import einops +import torch + +from vllm.utils.torch_utils import direct_register_custom_op + + +def xformers_attn_seqlens_wrapper( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor +) -> torch.Tensor: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device + ) + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous() + return context_layer + + +def xformers_attn_seqlens_wrapper_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor +) -> torch.Tensor: + b, s, h, d = q.shape + return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + + +direct_register_custom_op( + op_name="xformers_attn_seqlens_wrapper", + op_func=xformers_attn_seqlens_wrapper, +
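# The fake impl below gives torch.compile shape/dtype-only semantics for + # this op during fake-tensor tracing, so the real kernel (and its + # graph-breaking `tolist()` call) never runs at compile time. +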
fake_impl=xformers_attn_seqlens_wrapper_fake, +) + + +def vit_xformers_attn_wrapper( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor +) -> torch.Tensor: + return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens) + + +def flash_attn_maxseqlen_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + batch_size: int, + is_rocm_aiter: bool, + use_upstream_fa: bool, +) -> torch.Tensor: + if is_rocm_aiter: + from aiter import flash_attn_varlen_func + else: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + output = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen.item(), + max_seqlen_k=max_seqlen.item(), + dropout_p=0.0, + causal=False, + ) + context_layer = einops.rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() + return context_layer + + +def flash_attn_maxseqlen_wrapper_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + batch_size: int, + is_rocm_aiter: bool, + use_upstream_fa: bool, +) -> torch.Tensor: + b, s, h, d = q.shape + return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + + +direct_register_custom_op( + op_name="flash_attn_maxseqlen_wrapper", + op_func=flash_attn_maxseqlen_wrapper, + fake_impl=flash_attn_maxseqlen_wrapper_fake, +) + + +def vit_flash_attn_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + batch_size: int, + is_rocm_aiter: bool, + use_upstream_fa: bool, +) -> torch.Tensor: + return torch.ops.vllm.flash_attn_maxseqlen_wrapper( + q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa + ) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 69fb93601..0946fa691 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -18,7 +18,12 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config +from vllm.config import ( + CompilationMode, + VllmConfig, + get_current_vllm_config, + set_current_vllm_config, +) from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import resolve_obj_by_qualname @@ -74,6 +79,21 @@ def support_torch_compile( ) -> Callable[[_T], _T]: ... +@overload +def support_torch_compile( + *, + mark_unbacked_dims: dict[str, int | list[int]] | None, +) -> Callable[[_T], _T]: ... + + +@overload +def support_torch_compile( + *, + dynamic_arg_dims: dict[str, int | list[int]] | None, + mark_unbacked_dims: dict[str, int | list[int]] | None, +) -> Callable[[_T], _T]: ... + + @overload def support_torch_compile(cls: _T) -> _T: ... 
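For context on how the new kwarg is meant to be used (it is exercised by the Qwen2.5-VL vision modules later in this diff), here is a minimal illustrative sketch; `ToyVisionBlock` is a made-up stand-in, not part of the diff:

@support_torch_compile(
    dynamic_arg_dims={"x": 0, "seqlens": 0},
    mark_unbacked_dims={"seqlens": 0},
)
class ToyVisionBlock(nn.Module):
    # dim 0 of `x` and `seqlens` is dynamic; `seqlens` is additionally
    # marked unbacked so dynamo will not specialize the graph when a
    # dummy input happens to have size 0 or 1 along that dim
    def forward(self, x: torch.Tensor, seqlens: torch.Tensor) -> torch.Tensor:
        ...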
@@ -82,6 +102,7 @@ def support_torch_compile( cls: _T | None = None, *, dynamic_arg_dims: dict[str, int | list[int]] | None = None, + mark_unbacked_dims: dict[str, int | list[int]] | None = None, enable_if: Callable[[VllmConfig], bool] | None = None, ) -> Callable[[_T], _T] | _T: """ @@ -135,6 +156,11 @@ def support_torch_compile( returns a boolean value indicating whether to compile the model or not. This is useful if you want to compile the model only when certain conditions are met. + + `mark_unbacked_dims` is a dictionary that maps argument names to the dynamic + dims that should be decorated with `mark_unbacked`. This is useful to enforce + that dynamo does not specialize on 0/1 values in the case of dummy inputs, + such as for vision model compilation. """ def cls_decorator_helper(cls: _T) -> _T: @@ -172,7 +198,9 @@ def support_torch_compile( raise ValueError( f"Argument {k} not found in the forward method of {cls}" ) - return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if) + return _support_torch_compile( + cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if + ) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None: def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, int | list[int]], + mark_unbacked_dims: dict[str, int | list[int]] | None = None, enable_if: Callable[[VllmConfig], bool] | None = None, ) -> _T: """ @@ -230,8 +259,22 @@ def _support_torch_compile( setattr(cls, IGNORE_COMPILE_KEY, False) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): - old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + def __init__( + self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs + ): + if vllm_config is None: + vllm_config = get_current_vllm_config() + + # NOTE: to support multimodal models (such as vision encoders), we may + # not be given a vllm_config, and old_init may not accept one, so we + # patch it in here and only forward the kwargs old_init accepts + sig = inspect.signature(old_init) + if "vllm_config" in sig.parameters: + kwargs["vllm_config"] = vllm_config + if "prefix" in sig.parameters: + kwargs["prefix"] = prefix + old_init(self, **kwargs) + self.vllm_config = vllm_config enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner @@ -344,6 +387,15 @@ def _support_torch_compile( "Unsupported dynamic dimensions" f" {dims} for argument {k} with type {type(arg)}."
) + if mark_unbacked_dims: + for k, dims in mark_unbacked_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] + torch._dynamo.decorators.mark_unbacked(arg, dims) # here, it is the starting point of the `torch.compile` process start_monitoring_torch_compile(self.vllm_config) logger.debug("Start compiling function %s", self.original_code_object) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c24a94091..f3ed78779 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -684,6 +684,8 @@ class CompilationConfig: from vllm.compilation.backends import VllmBackend + # TODO[@lucaskabela]: See if we can forward prefix + # https://github.com/vllm-project/vllm/issues/27045 return VllmBackend(vllm_config) def post_init_cudagraph_sizes(self) -> None: diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 6338ea93b..677d34dea 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -45,6 +45,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( @@ -759,7 +760,8 @@ class Qwen2_5OmniConditionalGenerationMixin: assert grid_thw.ndim == 2 pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + with set_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size sizes = grid_thw.prod(-1) // merge_size // merge_size @@ -779,7 +781,8 @@ class Qwen2_5OmniConditionalGenerationMixin: assert grid_thw.ndim == 2 pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + with set_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. 
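# Each grid_thw row is (t, h, w), so a video contributes t * h * w patch # embeddings, and the patch merger collapses spatial_merge_size**2 of them # into one output embedding - hence the division by merge_size twice below.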
merge_size = self.visual.spatial_merge_size sizes = grid_thw.prod(-1) // merge_size // merge_size @@ -839,6 +842,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + self.vllm_config = vllm_config thinker_config: Qwen2_5OmniThinkerConfig = ( vllm_config.model_config.hf_config.thinker_config ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index b622021e2..30e3d2dff 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -31,10 +31,10 @@ from collections.abc import Callable, Iterable, Mapping, Sequence from functools import lru_cache, partial from typing import Annotated, Any, Literal, TypeAlias +import einops import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( @@ -47,9 +47,15 @@ from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, ) +from vllm.attention.ops.vit_attn_wrappers import ( + vit_flash_attn_wrapper, + vit_xformers_attn_wrapper, +) +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -392,8 +398,8 @@ class Qwen2_5_VisionAttention(nn.Module): x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers + max_seqlen: torch.Tensor, # Only used for Flash Attention + seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -402,7 +408,7 @@ class Qwen2_5_VisionAttention(nn.Module): q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) + q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) if rotary_pos_emb is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) @@ -410,31 +416,18 @@ class Qwen2_5_VisionAttention(nn.Module): q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = self.flash_attn_varlen_func( + context_layer = vit_flash_attn_wrapper( q, k, v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, + cu_seqlens, + max_seqlen, + batch_size, + self.attn_backend == _Backend.ROCM_AITER_FA, + self.use_upstream_fa, ) - - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
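# cu_seqlens holds cumulative sequence boundaries, so each (start_idx, end_idx) # slice below covers exactly one image/video; running SDPA per slice avoids # materializing a block-diagonal attention mask over the whole concatenated batch.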
- from vllm.platforms import current_platform - - if current_platform.is_rocm(): - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() outputs = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i - 1] @@ -443,34 +436,31 @@ class Qwen2_5_VisionAttention(nn.Module): k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] ) output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") + output_i = einops.rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( + context_layer = einops.rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() elif self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens, kv_seqlen=None, device=q.device - ) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) return output +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + "cu_seqlens": 0, + "rotary_pos_emb": 0, + "seqlens": 0, + }, + mark_unbacked_dims={"seqlens": 0}, +) class Qwen2_5_VisionBlock(nn.Module): def __init__( self, @@ -515,8 +505,8 @@ class Qwen2_5_VisionBlock(nn.Module): x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers + max_seqlen: torch.Tensor, # Only used for Flash Attention + seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -530,6 +520,11 @@ class Qwen2_5_VisionBlock(nn.Module): return x +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + } +) class Qwen2_5_VisionPatchEmbed(nn.Module): def __init__( self, @@ -556,6 +551,11 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): return x +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + } +) class Qwen2_5_VisionPatchMerger(nn.Module): def __init__( self, @@ -665,13 +665,18 @@ class Qwen2_5_VisionTransformer(nn.Module): self.spatial_merge_size = vision_config.spatial_merge_size self.fullatt_block_indexes = vision_config.fullatt_block_indexes self.spatial_merge_unit = self.spatial_merge_size**2 + # TODO[@lucaskabela]: Investigate fixing this usage + # see https://github.com/vllm-project/vllm/issues/27044 + # DO NOT MOVE THIS IMPORT + from vllm.compilation.backends import set_model_tag - self.patch_embed = Qwen2_5_VisionPatchEmbed( - patch_size=patch_size, - temporal_patch_size=temporal_patch_size, - in_channels=in_channels, - hidden_size=self.hidden_size, - ) + with set_model_tag("Qwen2_5_VisionPatchEmbed"): + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads @@ -701,32 +706,35 @@ class Qwen2_5_VisionTransformer(nn.Module): f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
) - self.blocks = nn.ModuleList( - [ - Qwen2_5_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=get_act_and_mul_fn(vision_config.hidden_act), - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, - use_upstream_fa=use_upstream_fa, - ) - for layer_idx in range(depth) - ] - ) - self.merger = Qwen2_5_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - norm_layer=norm_layer, - spatial_merge_size=self.spatial_merge_size, - quant_config=quant_config, - prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, - ) + with set_model_tag("Qwen2_5_VisionBlock"): + self.blocks = nn.ModuleList( + [ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=get_act_and_mul_fn(vision_config.hidden_act), + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) + for layer_idx in range(depth) + ] + ) + + with set_model_tag("Qwen2_5_VisionPatchMerger"): + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) @property def dtype(self) -> torch.dtype: @@ -827,15 +835,18 @@ class Qwen2_5_VisionTransformer(nn.Module): def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[int | None, list[int] | None]: - max_seqlen, seqlens = None, None + ) -> tuple[torch.Tensor, torch.Tensor]: + max_seqlen, seqlens = ( + torch.zeros(1, device=cu_seqlens.device), + torch.zeros(1, device=cu_seqlens.device), + ) if ( self.attn_backend == _Backend.FLASH_ATTN or self.attn_backend == _Backend.ROCM_AITER_FA ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() elif self.attn_backend == _Backend.XFORMERS: - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens @staticmethod @@ -1233,6 +1244,7 @@ class Qwen2_5_VLForConditionalGeneration( self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config + self.vllm_config = vllm_config self.multimodal_config = multimodal_config self.video_pruning_rate = multimodal_config.video_pruning_rate self.is_multimodal_pruning_enabled = ( @@ -1248,7 +1260,7 @@ class Qwen2_5_VLForConditionalGeneration( else None ) self.visual = Qwen2_5_VisionTransformer( - config.vision_config, + vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, prefix=maybe_prefix(prefix, "visual"), @@ -1336,13 +1348,13 @@ class Qwen2_5_VLForConditionalGeneration( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" - ) - else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return 
run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync @@ -1396,12 +1408,18 @@ class Qwen2_5_VLForConditionalGeneration( video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" - ) - else: - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d", + ) + else: + video_embeds = self.visual( + pixel_values_videos, grid_thw=grid_thw_list + ) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size
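Taken together, the pattern above is the heart of the change: `max_seqlen` and `seqlens` stay tensors all the way into the attention wrappers, and the host sync happens only inside the registered custom ops, which torch.compile treats as opaque calls. A minimal sketch with illustrative values (not part of the diff):

import torch

cu_seqlens = torch.tensor([0, 4, 10, 16])    # boundaries for 3 items
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([4, 6, 6]); stays on device
max_seqlen = seqlens.max()                   # 0-dim tensor; no .item() here

# Inside flash_attn_maxseqlen_wrapper the real kernel finally needs ints:
#   flash_attn_varlen_func(..., max_seqlen_q=max_seqlen.item(), ...)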