Flashinfer cuDNN backend for Qwen3 VL ViT attention (#34580)
Signed-off-by: Max Hu <maxhu@nvidia.com>
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Shang Wang <shangw@nvidia.com>
@@ -9,9 +9,12 @@ Test:
import itertools
from unittest.mock import patch

import numpy as np
import pytest
import torch

from vllm.config import get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
@@ -224,3 +227,110 @@ def test_mha_attn_varlen_forward(
        ref_output.append(output_i)
    ref_output = torch.cat(ref_output, dim=1)
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize(
    "dtype",
    [torch.bfloat16, torch.half],
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward_flashinfer(
    default_vllm_config,
    var_seq_len: list[int],
    dtype: torch.dtype,
    device: str,
):
    """Test MMEncoderAttention varlen forward with FLASHINFER backend (head_size=72).

    Exercises the path that uses --mm-encoder-attn-backend=FLASHINFER with
    recomputed cu_seqlens, max_seqlen, and sequence_lengths, as in the qwen3_vl
    vision encoder.
    """
    pytest.importorskip("flashinfer")

    num_heads = 16
    head_size = 72
    set_random_seed(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # Override vllm config so get_vit_attn_backend returns FLASHINFER (simulates
    # --mm-encoder-attn-backend=FLASHINFER).
    vllm_config = get_current_vllm_config()
    old_model_config = getattr(vllm_config, "model_config", None)
    minimal_model_config = type(
        "MinimalModelConfig",
        (),
        {
            "multimodal_config": MultiModalConfig(
                mm_encoder_attn_backend=AttentionBackendEnum.FLASHINFER
            ),
        },
    )()
    vllm_config.model_config = minimal_model_config
    try:
        total_len = sum(var_seq_len)
        # Stride of second dim = 3 * num_heads * head_size (same as qwen2_5_vl
        # after qkv rearrange and unbind: qkv shape (b, s, 3, head, head_dim)).
        qkv = torch.randn(1, total_len, 3, num_heads, head_size)
        q, k, v = qkv.unbind(dim=2)
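        # With num_heads=16 and head_size=72, q, k, and v each come out of the
        # unbind with shape (1, total_len, 16, 72); the second-dim stride noted
        # above is 3 * 16 * 72 = 3456 elements.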

        cu_seqlens_np = np.array(
            [0] + list(itertools.accumulate(var_seq_len)), dtype=np.int32
        )
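        # cu_seqlens_np is the cumulative-sum prefix of the per-sequence
        # lengths. Illustrative example: var_seq_len = [3, 5, 4] gives
        # cu_seqlens_np = [0, 3, 8, 12], where entry i is the start offset of
        # sequence i in the packed batch and the last entry is total_len.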
        hidden_size = num_heads * head_size
        tp_size = 1

        sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
            AttentionBackendEnum.FLASHINFER, cu_seqlens_np
        )
        sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
            device, dtype=torch.int32, non_blocking=True
        )

        max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
            AttentionBackendEnum.FLASHINFER, cu_seqlens_np
        )
        max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
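        # max_seqlen is derived from cu_seqlens_np by the helper above
        # (presumably the largest per-sequence length) and is handed to the
        # attention layer below as a 0-dim int32 tensor on the target device.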

        cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
            AttentionBackendEnum.FLASHINFER,
            cu_seqlens_np,
            hidden_size,
            tp_size,
        )
        cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
            device, dtype=torch.int32, non_blocking=True
        )

        scale = 1.0 / head_size**0.5
        attn = MMEncoderAttention(
            num_heads,
            head_size,
            scale=scale,
            num_kv_heads=num_heads,
        )
        assert attn.attn_backend == AttentionBackendEnum.FLASHINFER

        output = attn(
            q,
            k,
            v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            sequence_lengths=sequence_lengths,
        )

        ref_output = []
        for q_i, k_i, v_i in zip(
            torch.split(q, var_seq_len, dim=1),
            torch.split(k, var_seq_len, dim=1),
            torch.split(v, var_seq_len, dim=1),
        ):
            output_i = ref_attention(q_i, k_i, v_i, scale=scale)
            ref_output.append(output_i)
        ref_output = torch.cat(ref_output, dim=1)
        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
    finally:
        vllm_config.model_config = old_model_config
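To run just this new test locally (illustrative invocation; the test file path is not shown in these hunks), something like the following should work on a CUDA machine with flashinfer installed:

pytest -k test_mha_attn_varlen_forward_flashinfer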