[Misc][qwen2_5_vl][torch.compile] Enable supports_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model (#23207)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
This commit is contained in:
36
tests/compile/test_multimodal_compile.py
Normal file
36
tests/compile/test_multimodal_compile.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.compilation.counter import compilation_counter
|
||||||
|
from vllm.config.compilation import CompilationMode
|
||||||
|
|
||||||
|
|
||||||
|
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||||
|
@pytest.mark.forked
|
||||||
|
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
|
||||||
|
"""Test that Qwen2.5-VL vision submodules are compiled.
|
||||||
|
|
||||||
|
This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
|
||||||
|
Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
|
||||||
|
for compilation by checking that num_models_seen increases by at least 3.
|
||||||
|
"""
|
||||||
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
|
with (
|
||||||
|
# NOTE: Qwen2.5-VL has 35 models in total - the LLM backend
|
||||||
|
# Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
|
||||||
|
# (one for each layer) - in the future, we should fix vLLM compilation
|
||||||
|
# logic to handle this case and only compile the Vision submodules once
|
||||||
|
# and reuse the compiled code for all layers
|
||||||
|
# See https://github.com/vllm-project/vllm/issues/27590
|
||||||
|
compilation_counter.expect(num_models_seen=35),
|
||||||
|
vllm_runner(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
|
max_model_len=2048,
|
||||||
|
gpu_memory_utilization=0.7,
|
||||||
|
compilation_config={"mode": CompilationMode.VLLM_COMPILE},
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
|
pass
|
||||||
125
vllm/attention/ops/vit_attn_wrappers.py
Normal file
125
vllm/attention/ops/vit_attn_wrappers.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
This file contains ops for ViT attention to be compatible with torch.compile
|
||||||
|
as there are operations here not supported by torch.compile (for instance,
|
||||||
|
`to_list` in xformers attn, or `.item()` in flash attention)
|
||||||
|
|
||||||
|
Using these ops and wrapping vision blocks with `torch.compile` can speed up
|
||||||
|
throughput in vision models by ~5% relative on H100, and improve token
|
||||||
|
latencies by ~7% (see qwen2_5_vl for example usage)
|
||||||
|
|
||||||
|
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_attn_seqlens_wrapper(
|
||||||
|
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from xformers import ops as xops
|
||||||
|
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||||
|
|
||||||
|
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||||
|
q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
|
||||||
|
)
|
||||||
|
context_layer = xops.memory_efficient_attention_forward(
|
||||||
|
q, k, v, attn_bias=attn_bias, p=0, scale=None
|
||||||
|
)
|
||||||
|
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
|
||||||
|
return context_layer
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_attn_seqlens_wrapper_fake(
|
||||||
|
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
b, s, h, d = q.shape
|
||||||
|
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="xformers_attn_seqlens_wrapper",
|
||||||
|
op_func=xformers_attn_seqlens_wrapper,
|
||||||
|
fake_impl=xformers_attn_seqlens_wrapper_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def vit_xformers_attn_wrapper(
|
||||||
|
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_maxseqlen_wrapper(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
max_seqlen: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
is_rocm_aiter: bool,
|
||||||
|
use_upstream_fa: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if is_rocm_aiter:
|
||||||
|
from aiter import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
if use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
output = flash_attn_varlen_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=max_seqlen.item(),
|
||||||
|
max_seqlen_k=max_seqlen.item(),
|
||||||
|
dropout_p=0.0,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
context_layer = einops.rearrange(
|
||||||
|
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||||
|
).contiguous()
|
||||||
|
return context_layer
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_maxseqlen_wrapper_fake(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
max_seqlen: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
is_rocm_aiter: bool,
|
||||||
|
use_upstream_fa: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
b, s, h, d = q.shape
|
||||||
|
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="flash_attn_maxseqlen_wrapper",
|
||||||
|
op_func=flash_attn_maxseqlen_wrapper,
|
||||||
|
fake_impl=flash_attn_maxseqlen_wrapper_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def vit_flash_attn_wrapper(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
max_seqlen: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
is_rocm_aiter: bool,
|
||||||
|
use_upstream_fa: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
||||||
|
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
|
||||||
|
)
|
||||||
@@ -18,7 +18,12 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||||
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
|
from vllm.config import (
|
||||||
|
CompilationMode,
|
||||||
|
VllmConfig,
|
||||||
|
get_current_vllm_config,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
@@ -74,6 +79,21 @@ def support_torch_compile(
|
|||||||
) -> Callable[[_T], _T]: ...
|
) -> Callable[[_T], _T]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def support_torch_compile(
|
||||||
|
*,
|
||||||
|
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||||
|
) -> Callable[[_T], _T]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def support_torch_compile(
|
||||||
|
*,
|
||||||
|
dynamic_arg_dims: dict[str, int | list[int]] | None,
|
||||||
|
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||||
|
) -> Callable[[_T], _T]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def support_torch_compile(cls: _T) -> _T: ...
|
def support_torch_compile(cls: _T) -> _T: ...
|
||||||
|
|
||||||
@@ -82,6 +102,7 @@ def support_torch_compile(
|
|||||||
cls: _T | None = None,
|
cls: _T | None = None,
|
||||||
*,
|
*,
|
||||||
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
|
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
|
||||||
|
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||||
) -> Callable[[_T], _T] | _T:
|
) -> Callable[[_T], _T] | _T:
|
||||||
"""
|
"""
|
||||||
@@ -135,6 +156,11 @@ def support_torch_compile(
|
|||||||
returns a boolean value indicating whether to compile the model or not.
|
returns a boolean value indicating whether to compile the model or not.
|
||||||
This is useful if you want to compile the model only when certain
|
This is useful if you want to compile the model only when certain
|
||||||
conditions are met.
|
conditions are met.
|
||||||
|
|
||||||
|
`mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
|
||||||
|
dim to be decorated with `mark_unbacked`. This is useful if we would like to
|
||||||
|
enforce that dynamo do not specialize on 0/1 values in the case of dummy input
|
||||||
|
such as for vision model compilation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def cls_decorator_helper(cls: _T) -> _T:
|
def cls_decorator_helper(cls: _T) -> _T:
|
||||||
@@ -172,7 +198,9 @@ def support_torch_compile(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Argument {k} not found in the forward method of {cls}"
|
f"Argument {k} not found in the forward method of {cls}"
|
||||||
)
|
)
|
||||||
return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
|
return _support_torch_compile(
|
||||||
|
cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
|
||||||
|
)
|
||||||
|
|
||||||
if cls is not None:
|
if cls is not None:
|
||||||
# use `support_torch_compile` as a decorator without arguments
|
# use `support_torch_compile` as a decorator without arguments
|
||||||
@@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
|
|||||||
def _support_torch_compile(
|
def _support_torch_compile(
|
||||||
cls: _T,
|
cls: _T,
|
||||||
dynamic_arg_dims: dict[str, int | list[int]],
|
dynamic_arg_dims: dict[str, int | list[int]],
|
||||||
|
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||||
) -> _T:
|
) -> _T:
|
||||||
"""
|
"""
|
||||||
@@ -230,8 +259,22 @@ def _support_torch_compile(
|
|||||||
|
|
||||||
setattr(cls, IGNORE_COMPILE_KEY, False)
|
setattr(cls, IGNORE_COMPILE_KEY, False)
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
|
def __init__(
|
||||||
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
|
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
|
||||||
|
):
|
||||||
|
if vllm_config is None:
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
|
||||||
|
# NOTE: to support multimodal models (such as encoder),
|
||||||
|
# we may not have vllm_config so we may need to patch
|
||||||
|
# it
|
||||||
|
sig = inspect.signature(old_init)
|
||||||
|
if "vllm_config" in sig.parameters:
|
||||||
|
kwargs["vllm_config"] = vllm_config
|
||||||
|
if "prefix" in sig.parameters:
|
||||||
|
kwargs["prefix"] = prefix
|
||||||
|
old_init(self, **kwargs)
|
||||||
|
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
enable_compile = enable_if is None or enable_if(vllm_config)
|
enable_compile = enable_if is None or enable_if(vllm_config)
|
||||||
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
|
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
|
||||||
@@ -344,6 +387,15 @@ def _support_torch_compile(
|
|||||||
"Unsupported dynamic dimensions"
|
"Unsupported dynamic dimensions"
|
||||||
f" {dims} for argument {k} with type {type(arg)}."
|
f" {dims} for argument {k} with type {type(arg)}."
|
||||||
)
|
)
|
||||||
|
if mark_unbacked_dims:
|
||||||
|
for k, dims in mark_unbacked_dims.items():
|
||||||
|
arg = bound_args.arguments.get(k)
|
||||||
|
if arg is not None:
|
||||||
|
dims = [dims] if isinstance(dims, int) else dims
|
||||||
|
if isinstance(arg, torch.Tensor):
|
||||||
|
# In case dims is specified with negative indexing
|
||||||
|
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||||
|
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||||
# here, it is the starting point of the `torch.compile` process
|
# here, it is the starting point of the `torch.compile` process
|
||||||
start_monitoring_torch_compile(self.vllm_config)
|
start_monitoring_torch_compile(self.vllm_config)
|
||||||
logger.debug("Start compiling function %s", self.original_code_object)
|
logger.debug("Start compiling function %s", self.original_code_object)
|
||||||
|
|||||||
@@ -684,6 +684,8 @@ class CompilationConfig:
|
|||||||
|
|
||||||
from vllm.compilation.backends import VllmBackend
|
from vllm.compilation.backends import VllmBackend
|
||||||
|
|
||||||
|
# TODO[@lucaskabela]: See if we can forward prefix
|
||||||
|
# https://github.com/vllm-project/vllm/issues/27045
|
||||||
return VllmBackend(vllm_config)
|
return VllmBackend(vllm_config)
|
||||||
|
|
||||||
def post_init_cudagraph_sizes(self) -> None:
|
def post_init_cudagraph_sizes(self) -> None:
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.models.qwen2_5_vl import (
|
from vllm.model_executor.models.qwen2_5_vl import (
|
||||||
@@ -759,7 +760,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
assert grid_thw.ndim == 2
|
assert grid_thw.ndim == 2
|
||||||
|
|
||||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
with set_forward_context(None, self.vllm_config):
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||||
# Split concatenated embeddings for each image item.
|
# Split concatenated embeddings for each image item.
|
||||||
merge_size = self.visual.spatial_merge_size
|
merge_size = self.visual.spatial_merge_size
|
||||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||||
@@ -779,7 +781,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
assert grid_thw.ndim == 2
|
assert grid_thw.ndim == 2
|
||||||
|
|
||||||
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
|
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
with set_forward_context(None, self.vllm_config):
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||||
# Split concatenated embeddings for each video item.
|
# Split concatenated embeddings for each video item.
|
||||||
merge_size = self.visual.spatial_merge_size
|
merge_size = self.visual.spatial_merge_size
|
||||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||||
@@ -839,6 +842,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
|||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.vllm_config = vllm_config
|
||||||
thinker_config: Qwen2_5OmniThinkerConfig = (
|
thinker_config: Qwen2_5OmniThinkerConfig = (
|
||||||
vllm_config.model_config.hf_config.thinker_config
|
vllm_config.model_config.hf_config.thinker_config
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,10 +31,10 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
|
|||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import Annotated, Any, Literal, TypeAlias
|
from typing import Annotated, Any, Literal, TypeAlias
|
||||||
|
|
||||||
|
import einops
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
|
||||||
from transformers import BatchFeature, PretrainedConfig
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
@@ -47,9 +47,15 @@ from vllm.attention.layer import (
|
|||||||
check_upstream_fa_availability,
|
check_upstream_fa_availability,
|
||||||
maybe_get_vit_flash_attn_backend,
|
maybe_get_vit_flash_attn_backend,
|
||||||
)
|
)
|
||||||
|
from vllm.attention.ops.vit_attn_wrappers import (
|
||||||
|
vit_flash_attn_wrapper,
|
||||||
|
vit_xformers_attn_wrapper,
|
||||||
|
)
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@@ -392,8 +398,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor,
|
rotary_pos_emb: torch.Tensor,
|
||||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||||
seqlens: list[int] | None = None, # Only used for xFormers
|
seqlens: torch.Tensor, # Only used for xFormers
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||||
x, _ = self.qkv(x)
|
x, _ = self.qkv(x)
|
||||||
@@ -402,7 +408,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
q, k, v = self.split_qkv(x)
|
q, k, v = self.split_qkv(x)
|
||||||
batch_size = q.shape[1]
|
batch_size = q.shape[1]
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
|
q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
|
||||||
if rotary_pos_emb is not None:
|
if rotary_pos_emb is not None:
|
||||||
# [2 * b, s, heads, head_dim]
|
# [2 * b, s, heads, head_dim]
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
@@ -410,31 +416,18 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
if self.is_flash_attn_backend:
|
if self.is_flash_attn_backend:
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
context_layer = vit_flash_attn_wrapper(
|
||||||
|
|
||||||
output = self.flash_attn_varlen_func(
|
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
cu_seqlens_q=cu_seqlens,
|
cu_seqlens,
|
||||||
cu_seqlens_k=cu_seqlens,
|
max_seqlen,
|
||||||
max_seqlen_q=max_seqlen,
|
batch_size,
|
||||||
max_seqlen_k=max_seqlen,
|
self.attn_backend == _Backend.ROCM_AITER_FA,
|
||||||
dropout_p=0.0,
|
self.use_upstream_fa,
|
||||||
causal=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
context_layer = rearrange(
|
|
||||||
output, "(b s) h d -> s b (h d)", b=batch_size
|
|
||||||
).contiguous()
|
|
||||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||||
# Execute attention entry by entry for speed & less VRAM.
|
# Execute attention entry by entry for speed & less VRAM.
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
|
||||||
q = q.contiguous()
|
|
||||||
k = k.contiguous()
|
|
||||||
v = v.contiguous()
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for i in range(1, len(cu_seqlens)):
|
for i in range(1, len(cu_seqlens)):
|
||||||
start_idx = cu_seqlens[i - 1]
|
start_idx = cu_seqlens[i - 1]
|
||||||
@@ -443,34 +436,31 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
k_i = k[:, start_idx:end_idx]
|
k_i = k[:, start_idx:end_idx]
|
||||||
v_i = v[:, start_idx:end_idx]
|
v_i = v[:, start_idx:end_idx]
|
||||||
q_i, k_i, v_i = (
|
q_i, k_i, v_i = (
|
||||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||||
)
|
)
|
||||||
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
||||||
output_i = rearrange(output_i, "b h s d -> b s h d ")
|
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
|
||||||
outputs.append(output_i)
|
outputs.append(output_i)
|
||||||
context_layer = torch.cat(outputs, dim=1)
|
context_layer = torch.cat(outputs, dim=1)
|
||||||
context_layer = rearrange(
|
context_layer = einops.rearrange(
|
||||||
context_layer, "b s h d -> s b (h d)"
|
context_layer, "b s h d -> s b (h d)"
|
||||||
).contiguous()
|
).contiguous()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
from xformers import ops as xops
|
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
|
||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
|
||||||
|
|
||||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
|
||||||
q_seqlen=seqlens, kv_seqlen=None, device=q.device
|
|
||||||
)
|
|
||||||
|
|
||||||
context_layer = xops.memory_efficient_attention_forward(
|
|
||||||
q, k, v, attn_bias=attn_bias, p=0, scale=None
|
|
||||||
)
|
|
||||||
context_layer = rearrange(
|
|
||||||
context_layer, "b s h d -> s b (h d)"
|
|
||||||
).contiguous()
|
|
||||||
|
|
||||||
output, _ = self.proj(context_layer)
|
output, _ = self.proj(context_layer)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims={
|
||||||
|
"x": 0,
|
||||||
|
"cu_seqlens": 0,
|
||||||
|
"rotary_pos_emb": 0,
|
||||||
|
"seqlens": 0,
|
||||||
|
},
|
||||||
|
mark_unbacked_dims={"seqlens": 0},
|
||||||
|
)
|
||||||
class Qwen2_5_VisionBlock(nn.Module):
|
class Qwen2_5_VisionBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -515,8 +505,8 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor,
|
rotary_pos_emb: torch.Tensor,
|
||||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||||
seqlens: list[int] | None = None, # Only used for xFormers
|
seqlens: torch.Tensor, # Only used for xFormers
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
x_attn = self.attn(
|
x_attn = self.attn(
|
||||||
self.norm1(x),
|
self.norm1(x),
|
||||||
@@ -530,6 +520,11 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims={
|
||||||
|
"x": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
class Qwen2_5_VisionPatchEmbed(nn.Module):
|
class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -556,6 +551,11 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims={
|
||||||
|
"x": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
class Qwen2_5_VisionPatchMerger(nn.Module):
|
class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -665,13 +665,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
self.spatial_merge_size = vision_config.spatial_merge_size
|
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||||
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
|
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
|
||||||
self.spatial_merge_unit = self.spatial_merge_size**2
|
self.spatial_merge_unit = self.spatial_merge_size**2
|
||||||
|
# TODO[@lucaskabela]: Investigate fixing this usage
|
||||||
|
# see https://github.com/vllm-project/vllm/issues/27044
|
||||||
|
# DO NOT MOVE THIS IMPORT
|
||||||
|
from vllm.compilation.backends import set_model_tag
|
||||||
|
|
||||||
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
with set_model_tag("Qwen2_5_VisionPatchEmbed"):
|
||||||
patch_size=patch_size,
|
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
||||||
temporal_patch_size=temporal_patch_size,
|
patch_size=patch_size,
|
||||||
in_channels=in_channels,
|
temporal_patch_size=temporal_patch_size,
|
||||||
hidden_size=self.hidden_size,
|
in_channels=in_channels,
|
||||||
)
|
hidden_size=self.hidden_size,
|
||||||
|
)
|
||||||
|
|
||||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||||
head_dim = self.hidden_size // self.num_heads
|
head_dim = self.hidden_size // self.num_heads
|
||||||
@@ -701,32 +706,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(
|
with set_model_tag("Qwen2_5_VisionBlock"):
|
||||||
[
|
self.blocks = nn.ModuleList(
|
||||||
Qwen2_5_VisionBlock(
|
[
|
||||||
dim=self.hidden_size,
|
Qwen2_5_VisionBlock(
|
||||||
num_heads=self.num_heads,
|
dim=self.hidden_size,
|
||||||
mlp_hidden_dim=vision_config.intermediate_size,
|
num_heads=self.num_heads,
|
||||||
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
|
mlp_hidden_dim=vision_config.intermediate_size,
|
||||||
norm_layer=norm_layer,
|
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
|
||||||
quant_config=quant_config,
|
norm_layer=norm_layer,
|
||||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
quant_config=quant_config,
|
||||||
use_data_parallel=use_data_parallel,
|
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||||
attn_backend=self.attn_backend,
|
use_data_parallel=use_data_parallel,
|
||||||
use_upstream_fa=use_upstream_fa,
|
attn_backend=self.attn_backend,
|
||||||
)
|
use_upstream_fa=use_upstream_fa,
|
||||||
for layer_idx in range(depth)
|
)
|
||||||
]
|
for layer_idx in range(depth)
|
||||||
)
|
]
|
||||||
self.merger = Qwen2_5_VisionPatchMerger(
|
)
|
||||||
d_model=vision_config.out_hidden_size,
|
|
||||||
context_dim=self.hidden_size,
|
with set_model_tag("Qwen2_5_VisionPatchMerger"):
|
||||||
norm_layer=norm_layer,
|
self.merger = Qwen2_5_VisionPatchMerger(
|
||||||
spatial_merge_size=self.spatial_merge_size,
|
d_model=vision_config.out_hidden_size,
|
||||||
quant_config=quant_config,
|
context_dim=self.hidden_size,
|
||||||
prefix=f"{prefix}.merger",
|
norm_layer=norm_layer,
|
||||||
use_data_parallel=use_data_parallel,
|
spatial_merge_size=self.spatial_merge_size,
|
||||||
)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.merger",
|
||||||
|
use_data_parallel=use_data_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
@@ -827,15 +835,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
def compute_attn_mask_seqlen(
|
def compute_attn_mask_seqlen(
|
||||||
self,
|
self,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
) -> tuple[int | None, list[int] | None]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
max_seqlen, seqlens = None, None
|
max_seqlen, seqlens = (
|
||||||
|
torch.zeros(1, device=cu_seqlens.device),
|
||||||
|
torch.zeros(1, device=cu_seqlens.device),
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
self.attn_backend == _Backend.FLASH_ATTN
|
self.attn_backend == _Backend.FLASH_ATTN
|
||||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||||
):
|
):
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
return max_seqlen, seqlens
|
return max_seqlen, seqlens
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1233,6 +1244,7 @@ class Qwen2_5_VLForConditionalGeneration(
|
|||||||
|
|
||||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.vllm_config = vllm_config
|
||||||
self.multimodal_config = multimodal_config
|
self.multimodal_config = multimodal_config
|
||||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||||
self.is_multimodal_pruning_enabled = (
|
self.is_multimodal_pruning_enabled = (
|
||||||
@@ -1248,7 +1260,7 @@ class Qwen2_5_VLForConditionalGeneration(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.visual = Qwen2_5_VisionTransformer(
|
self.visual = Qwen2_5_VisionTransformer(
|
||||||
config.vision_config,
|
vision_config=config.vision_config,
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
quant_config=self.quant_config,
|
quant_config=self.quant_config,
|
||||||
prefix=maybe_prefix(prefix, "visual"),
|
prefix=maybe_prefix(prefix, "visual"),
|
||||||
@@ -1336,13 +1348,13 @@ class Qwen2_5_VLForConditionalGeneration(
|
|||||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||||
else:
|
else:
|
||||||
pixel_values = image_input["pixel_values"]
|
pixel_values = image_input["pixel_values"]
|
||||||
|
with set_forward_context(None, self.vllm_config):
|
||||||
if self.use_data_parallel:
|
if self.use_data_parallel:
|
||||||
return run_dp_sharded_mrope_vision_model(
|
return run_dp_sharded_mrope_vision_model(
|
||||||
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||||
|
|
||||||
# Split concatenated embeddings for each image item.
|
# Split concatenated embeddings for each image item.
|
||||||
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
|
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
|
||||||
@@ -1396,12 +1408,18 @@ class Qwen2_5_VLForConditionalGeneration(
|
|||||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||||
else:
|
else:
|
||||||
pixel_values_videos = video_input["pixel_values_videos"]
|
pixel_values_videos = video_input["pixel_values_videos"]
|
||||||
if self.use_data_parallel:
|
with set_forward_context(None, self.vllm_config):
|
||||||
return run_dp_sharded_mrope_vision_model(
|
if self.use_data_parallel:
|
||||||
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
|
return run_dp_sharded_mrope_vision_model(
|
||||||
)
|
self.visual,
|
||||||
else:
|
pixel_values_videos,
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
|
grid_thw_list,
|
||||||
|
rope_type="rope_3d",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
video_embeds = self.visual(
|
||||||
|
pixel_values_videos, grid_thw=grid_thw_list
|
||||||
|
)
|
||||||
|
|
||||||
# Split concatenated embeddings for each video item.
|
# Split concatenated embeddings for each video item.
|
||||||
merge_size = self.visual.spatial_merge_size
|
merge_size = self.visual.spatial_merge_size
|
||||||
|
|||||||
Reference in New Issue
Block a user