Flashinfer cuDNN backend for Qwen3 VL ViT attention (#34580)
Signed-off-by: Max Hu <maxhu@nvidia.com>
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Shang Wang <shangw@nvidia.com>
@@ -51,9 +51,12 @@ from transformers.video_utils import VideoMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
-from vllm.distributed import get_pp_group
+from vllm.distributed import get_pp_group, parallel_state
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.attention.mm_encoder_attention import (
+    MMEncoderAttention,
+)
 from vllm.model_executor.layers.conv import Conv3dLayer
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -92,7 +95,6 @@ from vllm.multimodal.processing import (
 from vllm.sequence import IntermediateTensors
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.math_utils import round_up
-from vllm.v1.attention.backends.registry import AttentionBackendEnum

 from .interfaces import (
     MultiModalEmbeddings,
@@ -244,6 +246,7 @@ class Qwen3_VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -251,6 +254,7 @@ class Qwen3_VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
         )

         x = x + self.mlp(self.norm2(x))
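The two extra arguments in the hunks above carry the same varlen batching information in the two forms the backends expect: per the diff's own comments, the Flash Attention family consumes the prefix-sum `cu_seqlens` plus a scalar `max_seqlen`, while the FlashInfer cuDNN path consumes explicit per-sequence lengths. A minimal sketch of the relationship, using illustrative values not taken from the diff:

```python
import torch

# Illustrative values: three images contributing 16, 32, and 64 patches.
cu_seqlens = torch.tensor([0, 16, 48, 112], dtype=torch.int32)

# Flash Attention-style kernels take cu_seqlens plus the scalar maximum.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # tensor(64)

# The FlashInfer cuDNN path takes the explicit per-sequence lengths.
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]    # tensor([16, 32, 64])
```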
@@ -332,6 +336,13 @@ class Qwen3_VisionTransformer(nn.Module):
         )
         self.num_grid_per_side = int(self.num_position_embeddings**0.5)

+        use_data_parallel = is_vit_use_data_parallel()
+        self.tp_size = (
+            1
+            if use_data_parallel
+            else parallel_state.get_tensor_model_parallel_world_size()
+        )
+
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
         self.out_hidden_size = vision_config.out_hidden_size * (
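`tp_size` records the effective tensor-parallel degree of the ViT: 1 when the encoder runs data-parallel (each rank holds the whole ViT), otherwise the TP world size. It is later handed to `MMEncoderAttention.maybe_recompute_cu_seqlens` in `forward`. A hedged sketch of the usual head-sharding arithmetic this implies; the helper below is hypothetical and not part of this change:

```python
# Hypothetical helper illustrating standard tensor-parallel head sharding;
# not part of this diff.
def heads_per_rank(num_heads: int, tp_size: int) -> int:
    assert num_heads % tp_size == 0, "heads must divide evenly across TP ranks"
    return num_heads // tp_size

heads_per_rank(16, tp_size=1)  # -> 16: DP ViT, each rank computes all heads
heads_per_rank(16, tp_size=4)  # -> 4:  TP ViT, each rank computes a shard
```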
@@ -513,19 +524,6 @@ class Qwen3_VisionTransformer(nn.Module):

         return torch.cat(outputs, dim=0)

-    def compute_attn_mask_seqlen(
-        self,
-        cu_seqlens: torch.Tensor,
-    ) -> torch.Tensor:
-        max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        if self.attn_backend in (
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-            AttentionBackendEnum.TRITON_ATTN,
-        ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        return max_seqlen
-
     def forward(
         self,
         x: torch.Tensor,
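The per-model `compute_attn_mask_seqlen` helper is deleted here in favor of the shared `MMEncoderAttention` utilities used in `forward` below. For reference, a standalone restatement of the removed method; its logic is taken directly from the deleted lines, but treating it as exactly what `MMEncoderAttention.compute_max_seqlen` does is an assumption, since that implementation is not shown in this diff:

```python
import torch

from vllm.v1.attention.backends.registry import AttentionBackendEnum


def compute_attn_mask_seqlen(
    attn_backend: AttentionBackendEnum, cu_seqlens: torch.Tensor
) -> torch.Tensor:
    # Only the FA-family backends need the true max sequence length;
    # every other backend receives a zero scalar.
    max_seqlen = torch.zeros([], device=cu_seqlens.device)
    if attn_backend in (
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.TRITON_ATTN,
    ):
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    return max_seqlen
```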
@@ -549,11 +547,26 @@ class Qwen3_VisionTransformer(nn.Module):
             axis=0, dtype=np.int32
         )
         cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
-        cu_seqlens = torch.from_numpy(cu_seqlens)
-
+
+        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
+            self.attn_backend, cu_seqlens
+        )
+        if sequence_lengths is not None:
+            sequence_lengths = torch.from_numpy(sequence_lengths).to(
+                self.device, non_blocking=True
+            )
+        max_seqlen = torch.tensor(
+            MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
+            self.attn_backend,
+            cu_seqlens,
+            self.hidden_size,
+            self.tp_size,
+        )
+        cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
         hidden_states = hidden_states.unsqueeze(1)
-        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
-        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
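Note the ordering in the hunk above: `sequence_lengths` and `max_seqlen` are derived from the numpy `cu_seqlens` before `maybe_recompute_cu_seqlens` gets a chance to rewrite it, and only the final array is copied to the device. The helper implementations are not part of this diff; a plausible sketch of the `maybe_compute_sequence_lengths` contract, inferred only from how `forward` consumes its result (it returns `None` for backends that do not need explicit lengths), would be:

```python
import numpy as np

# Plausible contract for MMEncoderAttention.maybe_compute_sequence_lengths,
# inferred from its call site above. The real implementation and the exact
# backend check are not shown in this diff; CUDNN_BACKENDS is a hypothetical
# stand-in for that check.
CUDNN_BACKENDS = {"FLASHINFER_CUDNN"}  # hypothetical name

def maybe_compute_sequence_lengths(backend_name: str, cu_seqlens: np.ndarray):
    if backend_name not in CUDNN_BACKENDS:
        return None  # forward() then skips the device copy entirely
    # Per-sequence lengths are the first difference of the prefix sums.
    return np.diff(cu_seqlens)

# Example: cu_seqlens [0, 16, 48, 112] -> lengths [16, 32, 64]
print(maybe_compute_sequence_lengths(
    "FLASHINFER_CUDNN", np.array([0, 16, 48, 112], dtype=np.int32)
))
```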
@@ -563,6 +576,7 @@ class Qwen3_VisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
+                sequence_lengths=sequence_lengths,
             )
             if layer_num in self.deepstack_visual_indexes:
                 deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)