[MM Encoder] Add Triton ViT attention backend (#32183)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-02-15 22:32:47 +08:00
committed by GitHub
parent 19fab44152
commit 71cd89264f
14 changed files with 178 additions and 51 deletions

View File

@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_random_seed
from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend
@@ -71,6 +71,15 @@ def test_mha_attn_platform(default_vllm_config, device: str):
attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention
with (
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
set_default_torch_dtype(torch.float32),
):
attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
def ref_attention(
query: torch.Tensor,
@@ -153,7 +162,12 @@ def test_mha_attn_forward(
v,
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output)
tol_kwargs = (
dict(rtol=1e-3, atol=1e-3)
if attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
else {}
)
torch.testing.assert_close(output, ref_output, **tol_kwargs)
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)

View File

@@ -12,6 +12,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
vit_triton_attn_wrapper,
)
logger = init_logger(__name__)
@@ -165,6 +166,41 @@ class MMEncoderAttention(CustomOp):
output = output.reshape(bsz, q_len, -1)
return output
def _forward_triton(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
output = vit_triton_attn_wrapper(
q=query,
k=key,
v=value,
batch_size=bsz,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
return output
def forward_native(
self,
query: torch.Tensor,
@@ -185,6 +221,8 @@ class MMEncoderAttention(CustomOp):
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:

View File

@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

View File

@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

View File

@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module):
cu_seqlens: torch.Tensor,
) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

View File

@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module):
head_size=head_dim,
dtype=torch.get_default_dtype(),
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
)
self.layers = nn.ModuleList(
[
SiglipEncoderLayer(
@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

View File

@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
self.blocks = nn.ModuleList(
[
@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

View File

@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

View File

@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

View File

@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now."
)
self.blocks = nn.ModuleList(
[
Qwen3_VisionBlock(
@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
if self.attn_backend in (
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

View File

@@ -108,7 +108,7 @@ def get_vit_attn_backend(
multimodal_config: MultiModalConfig | None = (
model_config.multimodal_config if model_config is not None else None
)
except AssertionError:
except (AssertionError, AttributeError):
multimodal_config = None
attn_backend_override = (
@@ -134,7 +134,7 @@ def is_vit_use_data_parallel():
multimodal_config: MultiModalConfig | None = (
model_config.multimodal_config if model_config is not None else None
)
except AssertionError:
except (AssertionError, AttributeError):
multimodal_config = None
mm_encoder_tp_mode = (

View File

@@ -411,8 +411,9 @@ class CudaPlatformBase(Platform):
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.TORCH_SDPA,
]
@classmethod
@@ -430,14 +431,25 @@ class CudaPlatformBase(Platform):
logger.info_once(f"Using backend {backend} for vit attention")
return backend
# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
cc = cls.get_device_capability()
for vit_attn_backend in cls.get_supported_vit_attn_backends():
if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
continue
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
backend_class = vit_attn_backend.get_class()
is_backend_supported = backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
) and backend_class.supports_dtype(dtype)
if cc is not None:
is_backend_supported = (
is_backend_supported
and backend_class.supports_compute_capability(cc)
)
if is_backend_supported:
logger.info_once(
f"Using backend {vit_attn_backend} for vit attention"
)
return vit_attn_backend
except ImportError:
pass

View File

@@ -384,6 +384,7 @@ class RocmPlatform(Platform):
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.TORCH_SDPA,
]

View File

@@ -110,6 +110,83 @@ def vit_flash_attn_wrapper(
)
def triton_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
q_len = q.size(1)
if cu_seqlens is None:
cu_seqlens = torch.arange(
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
)
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = torch.empty_like(q)
context_attention_fwd(
q,
k,
v,
output,
b_start_loc=cu_seqlens[:-1],
b_seq_len=cu_seqlens[1:] - cu_seqlens[:-1],
max_input_len=max_seqlen,
is_causal=False,
sliding_window_q=None,
sliding_window_k=None,
softmax_scale=scale,
)
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer
def triton_attn_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q)
direct_register_custom_op(
op_name="triton_attn_wrapper",
op_func=triton_attn_wrapper,
fake_impl=triton_attn_wrapper_fake,
)
def vit_triton_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.triton_attn_wrapper(
q,
k,
v,
batch_size,
scale,
cu_seqlens,
max_seqlen,
)
def apply_sdpa(
q: torch.Tensor,
k: torch.Tensor,