Flashinfer cuDNN backend for Qwen3 VL ViT attention (#34580)

Signed-off-by: Max Hu <maxhu@nvidia.com>
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Shang Wang <shangw@nvidia.com>
This commit is contained in:
Max Hu
2026-02-27 20:20:23 +08:00
committed by GitHub
parent b66a74649e
commit 9c3fe9936b
6 changed files with 405 additions and 21 deletions

View File

@@ -9,9 +9,12 @@ Test:
import itertools
from unittest.mock import patch
import numpy as np
import pytest
import torch
from vllm.config import get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
@@ -224,3 +227,110 @@ def test_mha_attn_varlen_forward(
ref_output.append(output_i)
ref_output = torch.cat(ref_output, dim=1)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward_flashinfer(
    default_vllm_config,
    var_seq_len: list[int],
    dtype: torch.dtype,
    device: str,
):
    """MMEncoderAttention varlen forward via the FLASHINFER backend (head_size=72).

    Mirrors the qwen3_vl vision-encoder path that is taken with
    --mm-encoder-attn-backend=FLASHINFER: cu_seqlens, max_seqlen, and
    sequence_lengths are recomputed through the MMEncoderAttention helpers
    before the forward call.
    """
    pytest.importorskip("flashinfer")

    num_heads = 16
    head_size = 72
    set_random_seed(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # Patch the current vllm config so the ViT attention backend resolves to
    # FLASHINFER (simulates passing --mm-encoder-attn-backend=FLASHINFER).
    vllm_config = get_current_vllm_config()
    saved_model_config = getattr(vllm_config, "model_config", None)
    stub_model_config = type(
        "MinimalModelConfig",
        (),
        {
            "multimodal_config": MultiModalConfig(
                mm_encoder_attn_backend=AttentionBackendEnum.FLASHINFER
            ),
        },
    )()
    vllm_config.model_config = stub_model_config
    try:
        total_tokens = sum(var_seq_len)
        # qkv layout matches qwen2_5_vl after rearrange and unbind:
        # (b, s, 3, head, head_dim), so the second dim has stride
        # 3 * num_heads * head_size.
        packed_qkv = torch.randn(1, total_tokens, 3, num_heads, head_size)
        q, k, v = packed_qkv.unbind(dim=2)

        boundaries = np.array(
            [0] + list(itertools.accumulate(var_seq_len)), dtype=np.int32
        )
        hidden_size = num_heads * head_size
        tp_size = 1

        # sequence_lengths and max_seqlen are derived from the ORIGINAL
        # boundaries, before maybe_recompute_cu_seqlens rewrites them below.
        seq_lens_host = MMEncoderAttention.maybe_compute_sequence_lengths(
            AttentionBackendEnum.FLASHINFER, boundaries
        )
        sequence_lengths = torch.from_numpy(seq_lens_host).to(
            device, dtype=torch.int32, non_blocking=True
        )
        max_seqlen = torch.tensor(
            MMEncoderAttention.compute_max_seqlen(
                AttentionBackendEnum.FLASHINFER, boundaries
            ),
            device=device,
            dtype=torch.int32,
        )
        boundaries = MMEncoderAttention.maybe_recompute_cu_seqlens(
            AttentionBackendEnum.FLASHINFER,
            boundaries,
            hidden_size,
            tp_size,
        )
        cu_seqlens = torch.from_numpy(boundaries).to(
            device, dtype=torch.int32, non_blocking=True
        )

        scale = 1.0 / (head_size**0.5)
        attn = MMEncoderAttention(
            num_heads,
            head_size,
            scale=scale,
            num_kv_heads=num_heads,
        )
        assert attn.attn_backend == AttentionBackendEnum.FLASHINFER

        output = attn(
            q,
            k,
            v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            sequence_lengths=sequence_lengths,
        )

        # Reference: run ref_attention on each sequence independently, then
        # concatenate back along the token dimension.
        ref_output = torch.cat(
            [
                ref_attention(q_i, k_i, v_i, scale=scale)
                for q_i, k_i, v_i in zip(
                    torch.split(q, var_seq_len, dim=1),
                    torch.split(k, var_seq_len, dim=1),
                    torch.split(v, var_seq_len, dim=1),
                )
            ],
            dim=1,
        )
        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
    finally:
        vllm_config.model_config = saved_model_config