[Misc][qwen2_5_vl][torch.compile] Enable support_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model (#23207)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
This commit is contained in:
@@ -31,10 +31,10 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import lru_cache, partial
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
@@ -47,9 +47,15 @@ from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
vit_xformers_attn_wrapper,
|
||||
)
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -392,8 +398,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
||||
seqlens: list[int] | None = None, # Only used for xFormers
|
||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||
seqlens: torch.Tensor, # Only used for xFormers
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
@@ -402,7 +408,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
q, k, v = self.split_qkv(x)
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
|
||||
q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
|
||||
if rotary_pos_emb is not None:
|
||||
# [2 * b, s, heads, head_dim]
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
@@ -410,31 +416,18 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
|
||||
output = self.flash_attn_varlen_func(
|
||||
context_layer = vit_flash_attn_wrapper(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
batch_size,
|
||||
self.attn_backend == _Backend.ROCM_AITER_FA,
|
||||
self.use_upstream_fa,
|
||||
)
|
||||
|
||||
context_layer = rearrange(
|
||||
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
@@ -443,34 +436,31 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
k_i = k[:, start_idx:end_idx]
|
||||
v_i = v[:, start_idx:end_idx]
|
||||
q_i, k_i, v_i = (
|
||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||
)
|
||||
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
||||
output_i = rearrange(output_i, "b h s d -> b s h d ")
|
||||
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
|
||||
outputs.append(output_i)
|
||||
context_layer = torch.cat(outputs, dim=1)
|
||||
context_layer = rearrange(
|
||||
context_layer = einops.rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||
q_seqlen=seqlens, kv_seqlen=None, device=q.device
|
||||
)
|
||||
|
||||
context_layer = xops.memory_efficient_attention_forward(
|
||||
q, k, v, attn_bias=attn_bias, p=0, scale=None
|
||||
)
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"x": 0,
|
||||
"cu_seqlens": 0,
|
||||
"rotary_pos_emb": 0,
|
||||
"seqlens": 0,
|
||||
},
|
||||
mark_unbacked_dims={"seqlens": 0},
|
||||
)
|
||||
class Qwen2_5_VisionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -515,8 +505,8 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
||||
seqlens: list[int] | None = None, # Only used for xFormers
|
||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||
seqlens: torch.Tensor, # Only used for xFormers
|
||||
) -> torch.Tensor:
|
||||
x_attn = self.attn(
|
||||
self.norm1(x),
|
||||
@@ -530,6 +520,11 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"x": 0,
|
||||
}
|
||||
)
|
||||
class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -556,6 +551,11 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"x": 0,
|
||||
}
|
||||
)
|
||||
class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -665,13 +665,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
|
||||
self.spatial_merge_unit = self.spatial_merge_size**2
|
||||
# TODO[@lucaskabela]: Investigate fixing this usage
|
||||
# see https://github.com/vllm-project/vllm/issues/27044
|
||||
# DO NOT MOVE THIS IMPORT
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
in_channels=in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
with set_model_tag("Qwen2_5_VisionPatchEmbed"):
|
||||
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
in_channels=in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
@@ -701,32 +706,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen2_5_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
)
|
||||
for layer_idx in range(depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen2_5_VisionPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
context_dim=self.hidden_size,
|
||||
norm_layer=norm_layer,
|
||||
spatial_merge_size=self.spatial_merge_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.merger",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
with set_model_tag("Qwen2_5_VisionBlock"):
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen2_5_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
)
|
||||
for layer_idx in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
with set_model_tag("Qwen2_5_VisionPatchMerger"):
|
||||
self.merger = Qwen2_5_VisionPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
context_dim=self.hidden_size,
|
||||
norm_layer=norm_layer,
|
||||
spatial_merge_size=self.spatial_merge_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.merger",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@@ -827,15 +835,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
def compute_attn_mask_seqlen(
|
||||
self,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> tuple[int | None, list[int] | None]:
|
||||
max_seqlen, seqlens = None, None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
max_seqlen, seqlens = (
|
||||
torch.zeros(1, device=cu_seqlens.device),
|
||||
torch.zeros(1, device=cu_seqlens.device),
|
||||
)
|
||||
if (
|
||||
self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
return max_seqlen, seqlens
|
||||
|
||||
@staticmethod
|
||||
@@ -1233,6 +1244,7 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.config = config
|
||||
self.vllm_config = vllm_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
@@ -1248,7 +1260,7 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
else None
|
||||
)
|
||||
self.visual = Qwen2_5_VisionTransformer(
|
||||
config.vision_config,
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
@@ -1336,13 +1348,13 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"]
|
||||
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
|
||||
@@ -1396,12 +1408,18 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"]
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(
|
||||
pixel_values_videos, grid_thw=grid_thw_list
|
||||
)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
|
||||
Reference in New Issue
Block a user