Flashinfer cuDNN backend for Qwen3 VL ViT attention (#34580)
Signed-off-by: Max Hu <maxhu@nvidia.com> Signed-off-by: Max Hu <hyoung2991@gmail.com> Co-authored-by: Max Hu <maxhu@nvidia.com> Co-authored-by: Shang Wang <shangw@nvidia.com>
This commit is contained in:
@@ -9,9 +9,12 @@ Test:
|
|||||||
import itertools
|
import itertools
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
from vllm.config.multimodal import MultiModalConfig
|
||||||
from vllm.model_executor.layers.attention import MMEncoderAttention
|
from vllm.model_executor.layers.attention import MMEncoderAttention
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.cpu import CpuPlatform
|
from vllm.platforms.cpu import CpuPlatform
|
||||||
@@ -224,3 +227,110 @@ def test_mha_attn_varlen_forward(
|
|||||||
ref_output.append(output_i)
|
ref_output.append(output_i)
|
||||||
ref_output = torch.cat(ref_output, dim=1)
|
ref_output = torch.cat(ref_output, dim=1)
|
||||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dtype",
|
||||||
|
[torch.bfloat16, torch.half],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
def test_mha_attn_varlen_forward_flashinfer(
|
||||||
|
default_vllm_config,
|
||||||
|
var_seq_len: list[int],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
"""Test MMEncoderAttention varlen forward with FLASHINFER backend (head_size=72).
|
||||||
|
|
||||||
|
Exercises the path that uses --mm-encoder-attn-backend=FLASHINFER with
|
||||||
|
recomputed cu_seqlens, max_seqlen, and sequence_lengths as in qwen3_vl
|
||||||
|
vision encoder.
|
||||||
|
"""
|
||||||
|
pytest.importorskip("flashinfer")
|
||||||
|
|
||||||
|
num_heads = 16
|
||||||
|
head_size = 72
|
||||||
|
set_random_seed(0)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
|
# Override vllm config so get_vit_attn_backend returns FLASHINFER (simulates
|
||||||
|
# --mm-encoder-attn-backend=FLASHINFER).
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
old_model_config = getattr(vllm_config, "model_config", None)
|
||||||
|
minimal_model_config = type(
|
||||||
|
"MinimalModelConfig",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"multimodal_config": MultiModalConfig(
|
||||||
|
mm_encoder_attn_backend=AttentionBackendEnum.FLASHINFER
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
vllm_config.model_config = minimal_model_config
|
||||||
|
try:
|
||||||
|
total_len = sum(var_seq_len)
|
||||||
|
# Stride of second dim = 3 * num_heads * head_size (same as qwen2_5_vl
|
||||||
|
# after qkv rearrange and unbind: qkv shape (b, s, 3, head, head_dim)).
|
||||||
|
qkv = torch.randn(1, total_len, 3, num_heads, head_size)
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
|
||||||
|
cu_seqlens_np = np.array(
|
||||||
|
[0] + list(itertools.accumulate(var_seq_len)), dtype=np.int32
|
||||||
|
)
|
||||||
|
hidden_size = num_heads * head_size
|
||||||
|
tp_size = 1
|
||||||
|
|
||||||
|
sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||||
|
AttentionBackendEnum.FLASHINFER, cu_seqlens_np
|
||||||
|
)
|
||||||
|
sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
|
||||||
|
device, dtype=torch.int32, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
|
||||||
|
AttentionBackendEnum.FLASHINFER, cu_seqlens_np
|
||||||
|
)
|
||||||
|
max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||||
|
AttentionBackendEnum.FLASHINFER,
|
||||||
|
cu_seqlens_np,
|
||||||
|
hidden_size,
|
||||||
|
tp_size,
|
||||||
|
)
|
||||||
|
cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
|
||||||
|
device, dtype=torch.int32, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = 1.0 / head_size**0.5
|
||||||
|
attn = MMEncoderAttention(
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
scale=scale,
|
||||||
|
num_kv_heads=num_heads,
|
||||||
|
)
|
||||||
|
assert attn.attn_backend == AttentionBackendEnum.FLASHINFER
|
||||||
|
|
||||||
|
output = attn(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
sequence_lengths=sequence_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_output = []
|
||||||
|
for q_i, k_i, v_i in zip(
|
||||||
|
torch.split(q, var_seq_len, dim=1),
|
||||||
|
torch.split(k, var_seq_len, dim=1),
|
||||||
|
torch.split(v, var_seq_len, dim=1),
|
||||||
|
):
|
||||||
|
output_i = ref_attention(q_i, k_i, v_i, scale=scale)
|
||||||
|
ref_output.append(output_i)
|
||||||
|
ref_output = torch.cat(ref_output, dim=1)
|
||||||
|
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
||||||
|
finally:
|
||||||
|
vllm_config.model_config = old_model_config
|
||||||
|
|||||||
@@ -2,21 +2,93 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
|
from vllm.utils.math_utils import round_up
|
||||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.v1.attention.ops.vit_attn_wrappers import (
|
from vllm.v1.attention.ops.vit_attn_wrappers import (
|
||||||
vit_flash_attn_wrapper,
|
vit_flash_attn_wrapper,
|
||||||
|
vit_flashinfer_wrapper,
|
||||||
vit_torch_sdpa_wrapper,
|
vit_torch_sdpa_wrapper,
|
||||||
vit_triton_attn_wrapper,
|
vit_triton_attn_wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Batch buckets for cuDNN graph caching.
|
||||||
|
# Graphs use batch size and max sequence length as cache key.
|
||||||
|
# This avoids creating a new graph for each unique set of
|
||||||
|
# batch size and max sequence length at runtime.
|
||||||
|
# From the cuDNN team's performance measurements, there
|
||||||
|
# is no significant kernel performance difference between padding
|
||||||
|
# to a smaller batch size/seq length and padding to larger
|
||||||
|
# ones. The bucketing here is solely used to avoid memory
|
||||||
|
# operation overhead, which won't be needed if we have CUDA
|
||||||
|
# graph support in the future.
|
||||||
|
# TODO: Remove buckets after issue #34763
|
||||||
|
# (cuda graph support) is addressed.
|
||||||
|
FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64]
|
||||||
|
FLASHINFER_MAX_SEQLEN_BUCKETS = [
|
||||||
|
1 * 1024,
|
||||||
|
2 * 1024,
|
||||||
|
4 * 1024,
|
||||||
|
8 * 1024,
|
||||||
|
16 * 1024,
|
||||||
|
32 * 1024,
|
||||||
|
64 * 1024,
|
||||||
|
128 * 1024,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Workspace buffer for FlashInfer CuDNN backend
|
||||||
|
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024
|
||||||
|
_flashinfer_workspace_buffer: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_flashinfer_workspace_buffer() -> torch.Tensor:
|
||||||
|
global _flashinfer_workspace_buffer
|
||||||
|
if _flashinfer_workspace_buffer is None:
|
||||||
|
_flashinfer_workspace_buffer = torch.zeros(
|
||||||
|
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
return _flashinfer_workspace_buffer
|
||||||
|
|
||||||
|
|
||||||
|
def add_padding_to_seqlens(
|
||||||
|
seq: np.ndarray,
|
||||||
|
batch_size: int,
|
||||||
|
padding_value: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
batch_size_padded = next(
|
||||||
|
(b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size),
|
||||||
|
round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]),
|
||||||
|
)
|
||||||
|
if batch_size_padded == batch_size:
|
||||||
|
return seq
|
||||||
|
return np.concatenate(
|
||||||
|
[
|
||||||
|
seq,
|
||||||
|
np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bucket_flashinfer_max_seqlen(
|
||||||
|
real_max_seqlen: int,
|
||||||
|
) -> int:
|
||||||
|
if real_max_seqlen <= 0:
|
||||||
|
return FLASHINFER_MAX_SEQLEN_BUCKETS[0]
|
||||||
|
return next(
|
||||||
|
(s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen),
|
||||||
|
round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# --8<-- [start:mm_encoder_attn]
|
# --8<-- [start:mm_encoder_attn]
|
||||||
@CustomOp.register("mm_encoder_attn")
|
@CustomOp.register("mm_encoder_attn")
|
||||||
@@ -24,6 +96,67 @@ class MMEncoderAttention(CustomOp):
|
|||||||
"""Multi-headed attention without any cache, used for multimodal encoder."""
|
"""Multi-headed attention without any cache, used for multimodal encoder."""
|
||||||
|
|
||||||
# --8<-- [end:mm_encoder_attn]
|
# --8<-- [end:mm_encoder_attn]
|
||||||
|
@classmethod
|
||||||
|
def compute_max_seqlen(
|
||||||
|
cls,
|
||||||
|
attn_backend: AttentionBackendEnum,
|
||||||
|
cu_seqlens: np.ndarray,
|
||||||
|
) -> int:
|
||||||
|
max_seqlen = 0
|
||||||
|
if (
|
||||||
|
attn_backend
|
||||||
|
in (
|
||||||
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
|
AttentionBackendEnum.TRITON_ATTN,
|
||||||
|
AttentionBackendEnum.FLASHINFER,
|
||||||
|
)
|
||||||
|
and len(cu_seqlens) >= 2
|
||||||
|
):
|
||||||
|
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
|
||||||
|
if attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||||
|
max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen)
|
||||||
|
return max_seqlen
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def maybe_compute_sequence_lengths(
|
||||||
|
cls,
|
||||||
|
attn_backend: AttentionBackendEnum,
|
||||||
|
cu_seqlens: np.ndarray,
|
||||||
|
) -> np.ndarray | None:
|
||||||
|
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||||
|
return None
|
||||||
|
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
|
sequence_lengths = add_padding_to_seqlens(
|
||||||
|
sequence_lengths, len(sequence_lengths), 0
|
||||||
|
)
|
||||||
|
return sequence_lengths
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def maybe_recompute_cu_seqlens(
|
||||||
|
cls,
|
||||||
|
attn_backend: AttentionBackendEnum,
|
||||||
|
cu_seqlens: np.ndarray,
|
||||||
|
hidden_size: int,
|
||||||
|
tp_size: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||||
|
return cu_seqlens
|
||||||
|
|
||||||
|
batch_size = len(cu_seqlens) - 1
|
||||||
|
scale = hidden_size // tp_size
|
||||||
|
cu_seqlens = cu_seqlens * scale
|
||||||
|
|
||||||
|
cu_seqlens_qko = cu_seqlens
|
||||||
|
cu_seqlens_v = cu_seqlens * 3
|
||||||
|
|
||||||
|
cu_seqlens_qko = add_padding_to_seqlens(
|
||||||
|
cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
|
||||||
|
)
|
||||||
|
cu_seqlens_v = add_padding_to_seqlens(
|
||||||
|
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
|
||||||
|
)
|
||||||
|
return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -46,10 +179,9 @@ class MMEncoderAttention(CustomOp):
|
|||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = scale
|
self.scale = 1.0 / (head_size**0.5) if scale is None else scale
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
self.layer_name = prefix
|
self.layer_name = prefix
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0, (
|
assert self.num_heads % self.num_kv_heads == 0, (
|
||||||
f"num_heads ({self.num_heads}) is not "
|
f"num_heads ({self.num_heads}) is not "
|
||||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||||
@@ -75,6 +207,9 @@ class MMEncoderAttention(CustomOp):
|
|||||||
get_flash_attn_version() if self.is_flash_attn_backend else None
|
get_flash_attn_version() if self.is_flash_attn_backend else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||||
|
_get_flashinfer_workspace_buffer()
|
||||||
|
|
||||||
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -201,6 +336,27 @@ class MMEncoderAttention(CustomOp):
|
|||||||
output = output.reshape(bsz, q_len, -1)
|
output = output.reshape(bsz, q_len, -1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _forward_flashinfer(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
max_seqlen: torch.Tensor | None = None,
|
||||||
|
sequence_lengths: torch.Tensor
|
||||||
|
| None = None, # Only used for FlashInfer CuDNN backend
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return vit_flashinfer_wrapper(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
scale=self.scale,
|
||||||
|
workspace_buffer=_get_flashinfer_workspace_buffer(),
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
sequence_lengths=sequence_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -208,6 +364,8 @@ class MMEncoderAttention(CustomOp):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||||
|
sequence_lengths: torch.Tensor
|
||||||
|
| None = None, # Only used for FlashInfer CuDNN backend
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||||
|
|
||||||
@@ -218,11 +376,17 @@ class MMEncoderAttention(CustomOp):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||||
|
sequence_lengths: torch.Tensor
|
||||||
|
| None = None, # Only used for FlashInfer CuDNN backend
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.is_flash_attn_backend:
|
if self.is_flash_attn_backend:
|
||||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||||
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
|
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||||
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
|
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
|
||||||
|
elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||||
|
return self._forward_flashinfer(
|
||||||
|
query, key, value, cu_seqlens, max_seqlen, sequence_lengths
|
||||||
|
)
|
||||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||||
else:
|
else:
|
||||||
@@ -238,6 +402,8 @@ class MMEncoderAttention(CustomOp):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||||
|
sequence_lengths: torch.Tensor
|
||||||
|
| None = None, # Only used for FlashInfer CuDNN backend
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||||
|
|
||||||
@@ -248,6 +414,8 @@ class MMEncoderAttention(CustomOp):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||||
|
sequence_lengths: torch.Tensor
|
||||||
|
| None = None, # Only used for FlashInfer CuDNN backend
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||||
|
|||||||
@@ -357,6 +357,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
rotary_pos_emb_cos: torch.Tensor,
|
rotary_pos_emb_cos: torch.Tensor,
|
||||||
rotary_pos_emb_sin: torch.Tensor,
|
rotary_pos_emb_sin: torch.Tensor,
|
||||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||||
|
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||||
x, _ = self.qkv(x)
|
x, _ = self.qkv(x)
|
||||||
@@ -398,6 +399,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
value=v,
|
value=v,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
sequence_lengths=sequence_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
context_layer = einops.rearrange(
|
context_layer = einops.rearrange(
|
||||||
@@ -463,6 +465,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
sequence_lengths=None,
|
||||||
)
|
)
|
||||||
x_fused_norm, residual = self.norm2(x, residual=x_attn)
|
x_fused_norm, residual = self.norm2(x, residual=x_attn)
|
||||||
x = residual + self.mlp(x_fused_norm)
|
x = residual + self.mlp(x_fused_norm)
|
||||||
|
|||||||
@@ -51,9 +51,12 @@ from transformers.video_utils import VideoMetadata
|
|||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||||
from vllm.distributed import get_pp_group
|
from vllm.distributed import get_pp_group, parallel_state
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||||
|
from vllm.model_executor.layers.attention.mm_encoder_attention import (
|
||||||
|
MMEncoderAttention,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.conv import Conv3dLayer
|
from vllm.model_executor.layers.conv import Conv3dLayer
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
@@ -92,7 +95,6 @@ from vllm.multimodal.processing import (
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.collection_utils import is_list_of
|
from vllm.utils.collection_utils import is_list_of
|
||||||
from vllm.utils.math_utils import round_up
|
from vllm.utils.math_utils import round_up
|
||||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
@@ -244,6 +246,7 @@ class Qwen3_VisionBlock(nn.Module):
|
|||||||
rotary_pos_emb_cos: torch.Tensor,
|
rotary_pos_emb_cos: torch.Tensor,
|
||||||
rotary_pos_emb_sin: torch.Tensor,
|
rotary_pos_emb_sin: torch.Tensor,
|
||||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||||
|
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
x = x + self.attn(
|
x = x + self.attn(
|
||||||
self.norm1(x),
|
self.norm1(x),
|
||||||
@@ -251,6 +254,7 @@ class Qwen3_VisionBlock(nn.Module):
|
|||||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
sequence_lengths=sequence_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
x = x + self.mlp(self.norm2(x))
|
x = x + self.mlp(self.norm2(x))
|
||||||
@@ -332,6 +336,13 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
|
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
|
||||||
|
|
||||||
|
use_data_parallel = is_vit_use_data_parallel()
|
||||||
|
self.tp_size = (
|
||||||
|
1
|
||||||
|
if use_data_parallel
|
||||||
|
else parallel_state.get_tensor_model_parallel_world_size()
|
||||||
|
)
|
||||||
|
|
||||||
# NOTE: This is used for creating empty tensor for all_gather for
|
# NOTE: This is used for creating empty tensor for all_gather for
|
||||||
# DP ViT. Here out_hidden_size is enlarged due to deepstack
|
# DP ViT. Here out_hidden_size is enlarged due to deepstack
|
||||||
self.out_hidden_size = vision_config.out_hidden_size * (
|
self.out_hidden_size = vision_config.out_hidden_size * (
|
||||||
@@ -513,19 +524,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
return torch.cat(outputs, dim=0)
|
return torch.cat(outputs, dim=0)
|
||||||
|
|
||||||
def compute_attn_mask_seqlen(
|
|
||||||
self,
|
|
||||||
cu_seqlens: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
|
||||||
if self.attn_backend in (
|
|
||||||
AttentionBackendEnum.FLASH_ATTN,
|
|
||||||
AttentionBackendEnum.ROCM_AITER_FA,
|
|
||||||
AttentionBackendEnum.TRITON_ATTN,
|
|
||||||
):
|
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
|
||||||
return max_seqlen
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -549,11 +547,26 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
axis=0, dtype=np.int32
|
axis=0, dtype=np.int32
|
||||||
)
|
)
|
||||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||||
cu_seqlens = torch.from_numpy(cu_seqlens)
|
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||||
|
self.attn_backend, cu_seqlens
|
||||||
|
)
|
||||||
|
if sequence_lengths is not None:
|
||||||
|
sequence_lengths = torch.from_numpy(sequence_lengths).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
max_seqlen = torch.tensor(
|
||||||
|
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||||
|
self.attn_backend,
|
||||||
|
cu_seqlens,
|
||||||
|
self.hidden_size,
|
||||||
|
self.tp_size,
|
||||||
|
)
|
||||||
|
cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
|
||||||
hidden_states = hidden_states.unsqueeze(1)
|
hidden_states = hidden_states.unsqueeze(1)
|
||||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
|
||||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
|
||||||
|
|
||||||
deepstack_feature_lists = []
|
deepstack_feature_lists = []
|
||||||
for layer_num, blk in enumerate(self.blocks):
|
for layer_num, blk in enumerate(self.blocks):
|
||||||
@@ -563,6 +576,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
sequence_lengths=sequence_lengths,
|
||||||
)
|
)
|
||||||
if layer_num in self.deepstack_visual_indexes:
|
if layer_num in self.deepstack_visual_indexes:
|
||||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
|
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
|
||||||
|
|||||||
@@ -414,6 +414,7 @@ class CudaPlatformBase(Platform):
|
|||||||
AttentionBackendEnum.FLASH_ATTN,
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
AttentionBackendEnum.TRITON_ATTN,
|
AttentionBackendEnum.TRITON_ATTN,
|
||||||
AttentionBackendEnum.TORCH_SDPA,
|
AttentionBackendEnum.TORCH_SDPA,
|
||||||
|
AttentionBackendEnum.FLASHINFER,
|
||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -268,3 +268,91 @@ def vit_torch_sdpa_wrapper(
|
|||||||
return torch.ops.vllm.torch_sdpa_wrapper(
|
return torch.ops.vllm.torch_sdpa_wrapper(
|
||||||
q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
|
q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_wrapper(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
workspace_buffer: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
max_seqlen: torch.Tensor | None = None,
|
||||||
|
sequence_lengths: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
|
||||||
|
|
||||||
|
is_reshaped = q.dim() == 4
|
||||||
|
|
||||||
|
if is_reshaped:
|
||||||
|
reshape_batch_size = q.shape[0]
|
||||||
|
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
# cuDNN <= 9.10.2.21 requires q, k to be contiguous
|
||||||
|
# this comes with no cost for ViTs with RoPE because
|
||||||
|
# RoPE has already made q and k contiguous.
|
||||||
|
q, k = q.contiguous(), k.contiguous()
|
||||||
|
|
||||||
|
assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
|
||||||
|
cu_seqlength = len(cu_seqlens) // 2
|
||||||
|
batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
|
||||||
|
batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1)
|
||||||
|
sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
|
||||||
|
max_seqlen = max_seqlen.item()
|
||||||
|
|
||||||
|
output, _ = cudnn_batch_prefill_with_kv_cache(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
scale,
|
||||||
|
workspace_buffer,
|
||||||
|
max_token_per_sequence=max_seqlen,
|
||||||
|
max_sequence_kv=max_seqlen,
|
||||||
|
actual_seq_lens_q=sequence_lengths,
|
||||||
|
actual_seq_lens_kv=sequence_lengths,
|
||||||
|
causal=False,
|
||||||
|
return_lse=False,
|
||||||
|
batch_offsets_q=batch_offsets_qko,
|
||||||
|
batch_offsets_k=batch_offsets_qko,
|
||||||
|
batch_offsets_v=batch_offsets_v,
|
||||||
|
batch_offsets_o=batch_offsets_qko,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_reshaped:
|
||||||
|
output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def vit_flashinfer_wrapper_fake(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
workspace_buffer: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
max_seqlen: torch.Tensor | None = None,
|
||||||
|
sequence_lengths: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty_like(q)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="flashinfer_wrapper",
|
||||||
|
op_func=flashinfer_wrapper,
|
||||||
|
fake_impl=vit_flashinfer_wrapper_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def vit_flashinfer_wrapper(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
workspace_buffer: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
max_seqlen: torch.Tensor | None = None,
|
||||||
|
sequence_lengths: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.vllm.flashinfer_wrapper(
|
||||||
|
q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user