diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index d76c57f9e..bc99ed576 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -9,9 +9,12 @@ Test:
 import itertools
 from unittest.mock import patch

+import numpy as np
 import pytest
 import torch

+from vllm.config import get_current_vllm_config
+from vllm.config.multimodal import MultiModalConfig
 from vllm.model_executor.layers.attention import MMEncoderAttention
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
@@ -224,3 +227,110 @@ def test_mha_attn_varlen_forward(
         ref_output.append(output_i)
     ref_output = torch.cat(ref_output, dim=1)
     torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
+
+
+@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.bfloat16, torch.half],
+)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_mha_attn_varlen_forward_flashinfer(
+    default_vllm_config,
+    var_seq_len: list[int],
+    dtype: torch.dtype,
+    device: str,
+):
+    """Test MMEncoderAttention varlen forward with FLASHINFER backend (head_size=72).
+
+    Exercises the path that uses --mm-encoder-attn-backend=FLASHINFER with
+    recomputed cu_seqlens, max_seqlen, and sequence_lengths as in the qwen3_vl
+    vision encoder.
+    """
+    pytest.importorskip("flashinfer")
+
+    num_heads = 16
+    head_size = 72
+    set_random_seed(0)
+    torch.set_default_device(device)
+    torch.set_default_dtype(dtype)
+
+    # Override vllm config so get_vit_attn_backend returns FLASHINFER (simulates
+    # --mm-encoder-attn-backend=FLASHINFER).
+    vllm_config = get_current_vllm_config()
+    old_model_config = getattr(vllm_config, "model_config", None)
+    minimal_model_config = type(
+        "MinimalModelConfig",
+        (),
+        {
+            "multimodal_config": MultiModalConfig(
+                mm_encoder_attn_backend=AttentionBackendEnum.FLASHINFER
+            ),
+        },
+    )()
+    vllm_config.model_config = minimal_model_config
+    try:
+        total_len = sum(var_seq_len)
+        # Stride of second dim = 3 * num_heads * head_size (same as qwen2_5_vl
+        # after qkv rearrange and unbind: qkv shape (b, s, 3, head, head_dim)).
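+        # The unbound q, k, v below therefore mimic the model path: strided
+        # views into one fused buffer, which is why maybe_recompute_cu_seqlens
+        # emits separate qko and v offset arrays for the cuDNN call.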
+        qkv = torch.randn(1, total_len, 3, num_heads, head_size)
+        q, k, v = qkv.unbind(dim=2)
+
+        cu_seqlens_np = np.array(
+            [0] + list(itertools.accumulate(var_seq_len)), dtype=np.int32
+        )
+        hidden_size = num_heads * head_size
+        tp_size = 1
+
+        sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
+            AttentionBackendEnum.FLASHINFER, cu_seqlens_np
+        )
+        sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
+            device, dtype=torch.int32, non_blocking=True
+        )
+
+        max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
+            AttentionBackendEnum.FLASHINFER, cu_seqlens_np
+        )
+        max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
+
+        cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
+            AttentionBackendEnum.FLASHINFER,
+            cu_seqlens_np,
+            hidden_size,
+            tp_size,
+        )
+        cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
+            device, dtype=torch.int32, non_blocking=True
+        )
+
+        scale = 1.0 / head_size**0.5
+        attn = MMEncoderAttention(
+            num_heads,
+            head_size,
+            scale=scale,
+            num_kv_heads=num_heads,
+        )
+        assert attn.attn_backend == AttentionBackendEnum.FLASHINFER
+
+        output = attn(
+            q,
+            k,
+            v,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
+        )
+
+        ref_output = []
+        for q_i, k_i, v_i in zip(
+            torch.split(q, var_seq_len, dim=1),
+            torch.split(k, var_seq_len, dim=1),
+            torch.split(v, var_seq_len, dim=1),
+        ):
+            output_i = ref_attention(q_i, k_i, v_i, scale=scale)
+            ref_output.append(output_i)
+        ref_output = torch.cat(ref_output, dim=1)
+        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
+    finally:
+        vllm_config.model_config = old_model_config
diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py
index e59806abb..d89366bbd 100644
--- a/vllm/model_executor/layers/attention/mm_encoder_attention.py
+++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -2,21 +2,93 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import numpy as np
 import torch

 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.models.vision import get_vit_attn_backend
+from vllm.utils.math_utils import round_up
 from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
+    vit_flashinfer_wrapper,
     vit_torch_sdpa_wrapper,
     vit_triton_attn_wrapper,
 )

 logger = init_logger(__name__)

+# Batch buckets for cuDNN graph caching.
+# Graphs use batch size and max sequence length as cache key.
+# This avoids creating a new graph for each unique set of
+# batch size and max sequence length at runtime.
+# From the cuDNN team's performance measurements, there
+# is no significant kernel performance difference between padding
+# to a smaller batch size/seq length and padding to larger
+# ones. The bucketing here is solely used to avoid memory
+# operation overhead, which won't be needed if we have CUDA
+# graph support in the future.
+# TODO: Remove buckets after issue #34763
+# (cuda graph support) is addressed.
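+# For example, with the buckets below a batch of 20 sequences is padded up to
+# the 32-entry bucket, and a real max seqlen of 3000 tokens is bucketed to
+# 4 * 1024, so nearby shapes at runtime reuse the same cached cuDNN graph.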
+FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64]
+FLASHINFER_MAX_SEQLEN_BUCKETS = [
+    1 * 1024,
+    2 * 1024,
+    4 * 1024,
+    8 * 1024,
+    16 * 1024,
+    32 * 1024,
+    64 * 1024,
+    128 * 1024,
+]
+
+# Workspace buffer for FlashInfer CuDNN backend
+FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024
+_flashinfer_workspace_buffer: torch.Tensor | None = None
+
+
+def _get_flashinfer_workspace_buffer() -> torch.Tensor:
+    global _flashinfer_workspace_buffer
+    if _flashinfer_workspace_buffer is None:
+        _flashinfer_workspace_buffer = torch.zeros(
+            FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES,
+            dtype=torch.uint8,
+            device="cuda",
+        )
+    return _flashinfer_workspace_buffer
+
+
+def add_padding_to_seqlens(
+    seq: np.ndarray,
+    batch_size: int,
+    padding_value: int,
+) -> np.ndarray:
+    batch_size_padded = next(
+        (b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size),
+        round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]),
+    )
+    if batch_size_padded == batch_size:
+        return seq
+    return np.concatenate(
+        [
+            seq,
+            np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype),
+        ]
+    )
+
+
+def bucket_flashinfer_max_seqlen(
+    real_max_seqlen: int,
+) -> int:
+    if real_max_seqlen <= 0:
+        return FLASHINFER_MAX_SEQLEN_BUCKETS[0]
+    return next(
+        (s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen),
+        round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]),
+    )
+

 # --8<-- [start:mm_encoder_attn]
 @CustomOp.register("mm_encoder_attn")
@@ -24,6 +96,67 @@ class MMEncoderAttention(CustomOp):
     """Multi-headed attention without any cache, used for multimodal encoder."""

     # --8<-- [end:mm_encoder_attn]
+    @classmethod
+    def compute_max_seqlen(
+        cls,
+        attn_backend: AttentionBackendEnum,
+        cu_seqlens: np.ndarray,
+    ) -> int:
+        max_seqlen = 0
+        if (
+            attn_backend
+            in (
+                AttentionBackendEnum.FLASH_ATTN,
+                AttentionBackendEnum.ROCM_AITER_FA,
+                AttentionBackendEnum.TRITON_ATTN,
+                AttentionBackendEnum.FLASHINFER,
+            )
+            and len(cu_seqlens) >= 2
+        ):
+            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
+        if attn_backend == AttentionBackendEnum.FLASHINFER:
+            max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen)
+        return max_seqlen
+
+    @classmethod
+    def maybe_compute_sequence_lengths(
+        cls,
+        attn_backend: AttentionBackendEnum,
+        cu_seqlens: np.ndarray,
+    ) -> np.ndarray | None:
+        if attn_backend != AttentionBackendEnum.FLASHINFER:
+            return None
+        sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+        sequence_lengths = add_padding_to_seqlens(
+            sequence_lengths, len(sequence_lengths), 0
+        )
+        return sequence_lengths
+
+    @classmethod
+    def maybe_recompute_cu_seqlens(
+        cls,
+        attn_backend: AttentionBackendEnum,
+        cu_seqlens: np.ndarray,
+        hidden_size: int,
+        tp_size: int,
+    ) -> np.ndarray:
+        if attn_backend != AttentionBackendEnum.FLASHINFER:
+            return cu_seqlens
+
+        batch_size = len(cu_seqlens) - 1
+        scale = hidden_size // tp_size
+        cu_seqlens = cu_seqlens * scale
+
+        cu_seqlens_qko = cu_seqlens
+        cu_seqlens_v = cu_seqlens * 3
+
+        cu_seqlens_qko = add_padding_to_seqlens(
+            cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
+        )
+        cu_seqlens_v = add_padding_to_seqlens(
+            cu_seqlens_v, batch_size, cu_seqlens_v[-1]
+        )
+        return np.concatenate([cu_seqlens_qko, cu_seqlens_v])

     def __init__(
         self,
@@ -46,10 +179,9 @@
         self.num_heads = num_heads
         self.head_size = head_size
-        self.scale = scale
+        self.scale = 1.0 / (head_size**0.5) if scale is None else scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.layer_name = prefix
-
         assert self.num_heads % self.num_kv_heads == 0, (
             f"num_heads ({self.num_heads}) is not "
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         )
@@ -75,6 +207,9 @@
             get_flash_attn_version() if self.is_flash_attn_backend else None
         )

+        if self.attn_backend == AttentionBackendEnum.FLASHINFER:
+            _get_flashinfer_workspace_buffer()
+
         logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

     @classmethod
@@ -201,6 +336,27 @@
             output = output.reshape(bsz, q_len, -1)
         return output

+    def _forward_flashinfer(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens: torch.Tensor | None = None,
+        max_seqlen: torch.Tensor | None = None,
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
+    ) -> torch.Tensor:
+        return vit_flashinfer_wrapper(
+            q=query,
+            k=key,
+            v=value,
+            scale=self.scale,
+            workspace_buffer=_get_flashinfer_workspace_buffer(),
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
+        )
+
     def forward_native(
         self,
         query: torch.Tensor,
@@ -208,6 +364,8 @@
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         return self._forward_sdpa(query, key, value, cu_seqlens)

@@ -218,11 +376,17 @@
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         if self.is_flash_attn_backend:
             return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
         elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
             return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
+        elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
+            return self._forward_flashinfer(
+                query, key, value, cu_seqlens, max_seqlen, sequence_lengths
+            )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             return self._forward_sdpa(query, key, value, cu_seqlens)
         else:
@@ -238,6 +402,8 @@
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         return self._forward_sdpa(query, key, value, cu_seqlens)

@@ -248,6 +414,8 @@
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
             return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 9e5f1175a..3eeefbb3f 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -357,6 +357,7 @@ class Qwen2_5_VisionAttention(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -398,6 +399,7 @@ class Qwen2_5_VisionAttention(nn.Module):
             value=v,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
         )

         context_layer = einops.rearrange(
@@ -463,6 +465,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
+            sequence_lengths=None,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 304553ed3..e5bdbd802 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -51,9 +51,12 @@ from transformers.video_utils import VideoMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
-from vllm.distributed import get_pp_group
+from vllm.distributed import get_pp_group, parallel_state
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.attention.mm_encoder_attention import (
+    MMEncoderAttention,
+)
 from vllm.model_executor.layers.conv import Conv3dLayer
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -92,7 +95,6 @@ from vllm.multimodal.processing import (
 from vllm.sequence import IntermediateTensors
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.math_utils import round_up
-from vllm.v1.attention.backends.registry import AttentionBackendEnum

 from .interfaces import (
     MultiModalEmbeddings,
@@ -244,6 +246,7 @@ class Qwen3_VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -251,6 +254,7 @@
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
         )

         x = x + self.mlp(self.norm2(x))
@@ -332,6 +336,13 @@ class Qwen3_VisionTransformer(nn.Module):
         )
         self.num_grid_per_side = int(self.num_position_embeddings**0.5)

+        use_data_parallel = is_vit_use_data_parallel()
+        self.tp_size = (
+            1
+            if use_data_parallel
+            else parallel_state.get_tensor_model_parallel_world_size()
+        )
+
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
         self.out_hidden_size = vision_config.out_hidden_size * (
@@ -513,19 +524,6 @@

         return torch.cat(outputs, dim=0)

-    def compute_attn_mask_seqlen(
-        self,
-        cu_seqlens: torch.Tensor,
-    ) -> torch.Tensor:
-        max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        if self.attn_backend in (
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-            AttentionBackendEnum.TRITON_ATTN,
-        ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        return max_seqlen
-
     def forward(
         self,
         x: torch.Tensor,
@@ -549,11 +547,26 @@
             axis=0, dtype=np.int32
         )
         cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
-        cu_seqlens = torch.from_numpy(cu_seqlens)
-
+        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
+            self.attn_backend, cu_seqlens
+        )
+        if sequence_lengths is not None:
+            sequence_lengths = torch.from_numpy(sequence_lengths).to(
+                self.device, non_blocking=True
+            )
+        max_seqlen = torch.tensor(
+            MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
+            self.attn_backend,
+            cu_seqlens,
+            self.hidden_size,
+            self.tp_size,
+        )
+        cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
         hidden_states = hidden_states.unsqueeze(1)
-        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
-        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
@@ -563,6 +576,7 @@
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
+                sequence_lengths=sequence_lengths,
             )
             if layer_num in self.deepstack_visual_indexes:
                 deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index ddd4df418..d3312fe15 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -414,6 +414,7 @@ class CudaPlatformBase(Platform):
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TRITON_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
+            AttentionBackendEnum.FLASHINFER,
         ]

     @classmethod
diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py
index f5c748fbc..6ffe110ad 100644
--- a/vllm/v1/attention/ops/vit_attn_wrappers.py
+++ b/vllm/v1/attention/ops/vit_attn_wrappers.py
@@ -268,3 +268,91 @@ def vit_torch_sdpa_wrapper(
     return torch.ops.vllm.torch_sdpa_wrapper(
         q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
     )
+
+
+def flashinfer_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
+
+    is_reshaped = q.dim() == 4
+
+    if is_reshaped:
+        reshape_batch_size = q.shape[0]
+        q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+    # cuDNN <= 9.10.2.21 requires q, k to be contiguous
+    # this comes with no cost for ViTs with RoPE because
+    # RoPE has already made q and k contiguous.
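+    # v can stay a strided view into the fused qkv buffer: its per-sequence
+    # offsets are carried in the second (3x-scaled) half of cu_seqlens built
+    # by maybe_recompute_cu_seqlens, so only q and k need the copy.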
+    q, k = q.contiguous(), k.contiguous()
+
+    assert len(cu_seqlens) % 2 == 0, "len(cu_seqlens) must be divisible by 2"
+    cu_seqlength = len(cu_seqlens) // 2
+    batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
+    batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1)
+    sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
+    max_seqlen = max_seqlen.item()
+
+    output, _ = cudnn_batch_prefill_with_kv_cache(
+        q,
+        k,
+        v,
+        scale,
+        workspace_buffer,
+        max_token_per_sequence=max_seqlen,
+        max_sequence_kv=max_seqlen,
+        actual_seq_lens_q=sequence_lengths,
+        actual_seq_lens_kv=sequence_lengths,
+        causal=False,
+        return_lse=False,
+        batch_offsets_q=batch_offsets_qko,
+        batch_offsets_k=batch_offsets_qko,
+        batch_offsets_v=batch_offsets_v,
+        batch_offsets_o=batch_offsets_qko,
+    )
+
+    if is_reshaped:
+        output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)
+
+    return output
+
+
+def vit_flashinfer_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.empty_like(q)
+
+
+direct_register_custom_op(
+    op_name="flashinfer_wrapper",
+    op_func=flashinfer_wrapper,
+    fake_impl=vit_flashinfer_wrapper_fake,
+)
+
+
+def vit_flashinfer_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.ops.vllm.flashinfer_wrapper(
+        q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths
+    )