Flashinfer cuDNN backend for Qwen3 VL ViT attention (#34580)

Signed-off-by: Max Hu <maxhu@nvidia.com>
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Shang Wang <shangw@nvidia.com>
This commit is contained in:
Max Hu
2026-02-27 20:20:23 +08:00
committed by GitHub
parent b66a74649e
commit 9c3fe9936b
6 changed files with 405 additions and 21 deletions

View File

@@ -2,21 +2,93 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_flashinfer_wrapper,
vit_torch_sdpa_wrapper,
vit_triton_attn_wrapper,
)
logger = init_logger(__name__)
# Batch buckets for cuDNN graph caching.
# Graphs use batch size and max sequence length as cache key.
# This avoids creating a new graph for each unique set of
# batch size and max sequence length at runtime.
# From the cuDNN team's performance measurements, there
# is no significant kernel performance difference between padding
# to a smaller batch size/seq length and padding to larger
# ones. The bucketing here is solely used to avoid memory
# operation overhead, which won't be needed if we have CUDA
# graph support in the future.
# TODO: Remove buckets after issue #34763
# (cuda graph support) is addressed.
FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64]
FLASHINFER_MAX_SEQLEN_BUCKETS = [
1 * 1024,
2 * 1024,
4 * 1024,
8 * 1024,
16 * 1024,
32 * 1024,
64 * 1024,
128 * 1024,
]
# Workspace buffer for FlashInfer CuDNN backend
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024
_flashinfer_workspace_buffer: torch.Tensor | None = None
def _get_flashinfer_workspace_buffer() -> torch.Tensor:
global _flashinfer_workspace_buffer
if _flashinfer_workspace_buffer is None:
_flashinfer_workspace_buffer = torch.zeros(
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES,
dtype=torch.uint8,
device="cuda",
)
return _flashinfer_workspace_buffer
def add_padding_to_seqlens(
seq: np.ndarray,
batch_size: int,
padding_value: int,
) -> np.ndarray:
batch_size_padded = next(
(b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size),
round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]),
)
if batch_size_padded == batch_size:
return seq
return np.concatenate(
[
seq,
np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype),
]
)
def bucket_flashinfer_max_seqlen(
real_max_seqlen: int,
) -> int:
if real_max_seqlen <= 0:
return FLASHINFER_MAX_SEQLEN_BUCKETS[0]
return next(
(s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen),
round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]),
)
# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
@@ -24,6 +96,67 @@ class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
# --8<-- [end:mm_encoder_attn]
@classmethod
def compute_max_seqlen(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
) -> int:
max_seqlen = 0
if (
attn_backend
in (
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLASHINFER,
)
and len(cu_seqlens) >= 2
):
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
if attn_backend == AttentionBackendEnum.FLASHINFER:
max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen)
return max_seqlen
@classmethod
def maybe_compute_sequence_lengths(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
) -> np.ndarray | None:
if attn_backend != AttentionBackendEnum.FLASHINFER:
return None
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
sequence_lengths = add_padding_to_seqlens(
sequence_lengths, len(sequence_lengths), 0
)
return sequence_lengths
@classmethod
def maybe_recompute_cu_seqlens(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
hidden_size: int,
tp_size: int,
) -> np.ndarray:
if attn_backend != AttentionBackendEnum.FLASHINFER:
return cu_seqlens
batch_size = len(cu_seqlens) - 1
scale = hidden_size // tp_size
cu_seqlens = cu_seqlens * scale
cu_seqlens_qko = cu_seqlens
cu_seqlens_v = cu_seqlens * 3
cu_seqlens_qko = add_padding_to_seqlens(
cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
)
cu_seqlens_v = add_padding_to_seqlens(
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
)
return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
def __init__(
self,
@@ -46,10 +179,9 @@ class MMEncoderAttention(CustomOp):
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.scale = 1.0 / (head_size**0.5) if scale is None else scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
@@ -75,6 +207,9 @@ class MMEncoderAttention(CustomOp):
get_flash_attn_version() if self.is_flash_attn_backend else None
)
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
_get_flashinfer_workspace_buffer()
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
@@ -201,6 +336,27 @@ class MMEncoderAttention(CustomOp):
output = output.reshape(bsz, q_len, -1)
return output
def _forward_flashinfer(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return vit_flashinfer_wrapper(
q=query,
k=key,
v=value,
scale=self.scale,
workspace_buffer=_get_flashinfer_workspace_buffer(),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
def forward_native(
self,
query: torch.Tensor,
@@ -208,6 +364,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
@@ -218,11 +376,17 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
return self._forward_flashinfer(
query, key, value, cu_seqlens, max_seqlen, sequence_lengths
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
@@ -238,6 +402,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
@@ -248,6 +414,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

View File

@@ -357,6 +357,7 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -398,6 +399,7 @@ class Qwen2_5_VisionAttention(nn.Module):
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
context_layer = einops.rearrange(
@@ -463,6 +465,7 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=None,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)

View File

@@ -51,9 +51,12 @@ from transformers.video_utils import VideoMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.distributed import get_pp_group, parallel_state
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -92,7 +95,6 @@ from vllm.multimodal.processing import (
from vllm.sequence import IntermediateTensors
from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interfaces import (
MultiModalEmbeddings,
@@ -244,6 +246,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -251,6 +254,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
x = x + self.mlp(self.norm2(x))
@@ -332,6 +336,13 @@ class Qwen3_VisionTransformer(nn.Module):
)
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
else parallel_state.get_tensor_model_parallel_world_size()
)
# NOTE: This is used for creating empty tensor for all_gather for
# DP ViT. Here out_hidden_size is enlarged due to deepstack
self.out_hidden_size = vision_config.out_hidden_size * (
@@ -513,19 +524,6 @@ class Qwen3_VisionTransformer(nn.Module):
return torch.cat(outputs, dim=0)
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
if self.attn_backend in (
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
self,
x: torch.Tensor,
@@ -549,11 +547,26 @@ class Qwen3_VisionTransformer(nn.Module):
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
self.attn_backend, cu_seqlens
)
if sequence_lengths is not None:
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
)
max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
dtype=torch.int32,
device=self.device,
)
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
self.attn_backend,
cu_seqlens,
self.hidden_size,
self.tp_size,
)
cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
hidden_states = hidden_states.unsqueeze(1)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
@@ -563,6 +576,7 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)