Flashinfer cuDNN backend for Qwen3 VL ViT attention (#34580)

Signed-off-by: Max Hu <maxhu@nvidia.com>
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Shang Wang <shangw@nvidia.com>
Max Hu
2026-02-27 20:20:23 +08:00
committed by GitHub
parent b66a74649e
commit 9c3fe9936b
6 changed files with 405 additions and 21 deletions

@@ -51,9 +51,12 @@ from transformers.video_utils import VideoMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.distributed import get_pp_group, parallel_state
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
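
The MMEncoderAttention import is the core of this change: the ViT stops inspecting AttentionBackendEnum itself (that import is removed in the next hunk) and instead asks the shared multimodal-encoder attention layer for backend-aware sequence metadata. For orientation, here is a stub summary of the three class-level helpers this file calls later in the diff; the signatures and return types are inferred from the call sites only, the decorators are assumed from the class-level call style, and the real bodies live in mm_encoder_attention.py, outside this diff:

```python
# Illustrative stubs only: signatures inferred from the call sites in this file,
# not the actual implementation in mm_encoder_attention.py.
class MMEncoderAttention:
    @staticmethod
    def maybe_compute_sequence_lengths(attn_backend, cu_seqlens):
        """NumPy array of per-item lengths, or None if the backend does not need them."""
        ...

    @staticmethod
    def compute_max_seqlen(attn_backend, cu_seqlens):
        """Host-side maximum sequence length; the caller wraps it in an int32 tensor."""
        ...

    @staticmethod
    def maybe_recompute_cu_seqlens(attn_backend, cu_seqlens, hidden_size, tp_size):
        """NumPy cu_seqlens, possibly rewritten for the selected backend."""
        ...
```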
@@ -92,7 +95,6 @@ from vllm.multimodal.processing import (
from vllm.sequence import IntermediateTensors
from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interfaces import (
MultiModalEmbeddings,
@@ -244,6 +246,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -251,6 +254,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
x = x + self.mlp(self.norm2(x))
@@ -332,6 +336,13 @@ class Qwen3_VisionTransformer(nn.Module):
)
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
else parallel_state.get_tensor_model_parallel_world_size()
)
# NOTE: This is used for creating empty tensor for all_gather for
# DP ViT. Here out_hidden_size is enlarged due to deepstack
self.out_hidden_size = vision_config.out_hidden_size * (
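
The new self.tp_size attribute exists so the forward pass can hand MMEncoderAttention.maybe_recompute_cu_seqlens the effective tensor-parallel degree of the vision tower. A standalone restatement of the selection logic, using a hypothetical helper name and assuming that data-parallel ViT replicates the full tower on every rank:

```python
def effective_vit_tp_size(use_data_parallel: bool, tp_world_size: int) -> int:
    # Hypothetical helper mirroring the __init__ logic above: with DP ViT each
    # rank runs the unsharded vision tower, so downstream seqlen bookkeeping
    # should behave as if tensor parallelism were disabled.
    return 1 if use_data_parallel else tp_world_size

assert effective_vit_tp_size(True, 4) == 1
assert effective_vit_tp_size(False, 4) == 4
```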
@@ -513,19 +524,6 @@ class Qwen3_VisionTransformer(nn.Module):
return torch.cat(outputs, dim=0)
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
if self.attn_backend in (
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
self,
x: torch.Tensor,
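
The deleted compute_attn_mask_seqlen helper produced max_seqlen from the torch cu_seqlens tensor, and only for the FLASH_ATTN, ROCM_AITER_FA and TRITON_ATTN backends; the next hunk replaces it with MMEncoderAttention.compute_max_seqlen, invoked while cu_seqlens is still a NumPy array and wrapped in an int32 device tensor by the caller. A minimal host-side sketch of the same quantity, assuming only the usual [0, n0, n0+n1, ...] prefix-sum layout (the backend gating of the real helper is omitted here):

```python
import numpy as np

def max_seqlen_from_cu_seqlens(cu_seqlens: np.ndarray) -> int:
    """Host-side equivalent of the removed torch expression
    (cu_seqlens[1:] - cu_seqlens[:-1]).max()."""
    if cu_seqlens.shape[0] < 2:
        return 0
    return int(np.diff(cu_seqlens).max())

# Example: three visual items of 16, 64 and 4 patches.
assert max_seqlen_from_cu_seqlens(np.array([0, 16, 80, 84], dtype=np.int32)) == 64
```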
@@ -549,11 +547,26 @@ class Qwen3_VisionTransformer(nn.Module):
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
self.attn_backend, cu_seqlens
)
if sequence_lengths is not None:
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
)
max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
dtype=torch.int32,
device=self.device,
)
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
self.attn_backend,
cu_seqlens,
self.hidden_size,
self.tp_size,
)
cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
hidden_states = hidden_states.unsqueeze(1)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
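
With this hunk the forward pass derives all sequence metadata on the host from the NumPy cu_seqlens and only then copies it to the device: optional per-item sequence_lengths for the FlashInfer cuDNN path, an int32 max_seqlen tensor, and cu_seqlens itself, which may first be rewritten by maybe_recompute_cu_seqlens using hidden_size and tp_size. A small worked example of the host-side pieces whose behavior is visible from this diff, assuming per-item lengths are the first differences of the prefix sums; the cu_seqlens rewrite is skipped because its internals are not shown here:

```python
import numpy as np
import torch

# Three visual items of 16, 64 and 4 patches (illustrative values).
cu_seqlens = np.array([0, 16, 80, 84], dtype=np.int32)

# Per-item lengths for backends that want them (e.g. the FlashInfer cuDNN path).
sequence_lengths = np.diff(cu_seqlens).astype(np.int32)  # [16 64  4]

# Same wrapping as in the hunk above: host-side max, then an int32 tensor.
max_seqlen = torch.tensor(int(sequence_lengths.max()), dtype=torch.int32)

cu_seqlens_t = torch.from_numpy(cu_seqlens)  # moved to self.device in the model

print(sequence_lengths, max_seqlen, cu_seqlens_t)
```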
@@ -563,6 +576,7 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)