[MM Encoder] Add Triton ViT attention backend (#32183)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.selector import _cached_get_attn_backend
|
||||
|
||||
@@ -71,6 +71,15 @@ def test_mha_attn_platform(default_vllm_config, device: str):
|
||||
attn = MMEncoderAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
with (
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
set_default_torch_dtype(torch.float32),
|
||||
):
|
||||
attn = MMEncoderAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
|
||||
|
||||
|
||||
def ref_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -153,7 +162,12 @@ def test_mha_attn_forward(
|
||||
v,
|
||||
scale=scale,
|
||||
).reshape(batch_size, seq_len, num_heads * head_size)
|
||||
torch.testing.assert_close(output, ref_output)
|
||||
tol_kwargs = (
|
||||
dict(rtol=1e-3, atol=1e-3)
|
||||
if attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
|
||||
else {}
|
||||
)
|
||||
torch.testing.assert_close(output, ref_output, **tol_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
|
||||
|
||||
@@ -12,6 +12,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
vit_torch_sdpa_wrapper,
|
||||
vit_triton_attn_wrapper,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -165,6 +166,41 @@ class MMEncoderAttention(CustomOp):
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
return output
|
||||
|
||||
def _forward_triton(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
"""Input shape:
|
||||
(batch_size x seq_len x hidden_size) or
|
||||
(batch_size x seq_len x num_heads x head_size)
|
||||
"""
|
||||
assert (cu_seqlens is not None and max_seqlen is not None) or (
|
||||
cu_seqlens is None and max_seqlen is None
|
||||
), "cu_seqlens and max_seqlen should be both set or both None."
|
||||
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
is_reshaped = query.dim() != 4
|
||||
|
||||
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
output = vit_triton_attn_wrapper(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
batch_size=bsz,
|
||||
scale=self.scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
if is_reshaped:
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
return output
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -185,6 +221,8 @@ class MMEncoderAttention(CustomOp):
|
||||
) -> torch.Tensor:
|
||||
if self.is_flash_attn_backend:
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
else:
|
||||
|
||||
@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module):
|
||||
|
||||
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
|
||||
max_seqlen = None
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
|
||||
@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
||||
|
||||
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
|
||||
max_seqlen = None
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
|
||||
@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
max_seqlen = None
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
|
||||
@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module):
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
SiglipEncoderLayer(
|
||||
@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
|
||||
|
||||
@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen3-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen3_VisionBlock(
|
||||
@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
if self.attn_backend in (
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -108,7 +108,7 @@ def get_vit_attn_backend(
|
||||
multimodal_config: MultiModalConfig | None = (
|
||||
model_config.multimodal_config if model_config is not None else None
|
||||
)
|
||||
except AssertionError:
|
||||
except (AssertionError, AttributeError):
|
||||
multimodal_config = None
|
||||
|
||||
attn_backend_override = (
|
||||
@@ -134,7 +134,7 @@ def is_vit_use_data_parallel():
|
||||
multimodal_config: MultiModalConfig | None = (
|
||||
model_config.multimodal_config if model_config is not None else None
|
||||
)
|
||||
except AssertionError:
|
||||
except (AssertionError, AttributeError):
|
||||
multimodal_config = None
|
||||
|
||||
mm_encoder_tp_mode = (
|
||||
|
||||
@@ -411,8 +411,9 @@ class CudaPlatformBase(Platform):
|
||||
@classmethod
|
||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
||||
return [
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@@ -430,14 +431,25 @@ class CudaPlatformBase(Platform):
|
||||
logger.info_once(f"Using backend {backend} for vit attention")
|
||||
return backend
|
||||
|
||||
# Try FlashAttention first
|
||||
if (cc := cls.get_device_capability()) and cc.major >= 8:
|
||||
cc = cls.get_device_capability()
|
||||
for vit_attn_backend in cls.get_supported_vit_attn_backends():
|
||||
if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
continue
|
||||
try:
|
||||
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
||||
if backend_class.supports_head_size(
|
||||
backend_class = vit_attn_backend.get_class()
|
||||
is_backend_supported = backend_class.supports_head_size(
|
||||
head_size
|
||||
) and backend_class.supports_dtype(dtype):
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
) and backend_class.supports_dtype(dtype)
|
||||
if cc is not None:
|
||||
is_backend_supported = (
|
||||
is_backend_supported
|
||||
and backend_class.supports_compute_capability(cc)
|
||||
)
|
||||
if is_backend_supported:
|
||||
logger.info_once(
|
||||
f"Using backend {vit_attn_backend} for vit attention"
|
||||
)
|
||||
return vit_attn_backend
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -384,6 +384,7 @@ class RocmPlatform(Platform):
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
]
|
||||
|
||||
|
||||
@@ -110,6 +110,83 @@ def vit_flash_attn_wrapper(
|
||||
)
|
||||
|
||||
|
||||
def triton_attn_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
|
||||
q_len = q.size(1)
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = torch.arange(
|
||||
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
|
||||
)
|
||||
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
|
||||
|
||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
output = torch.empty_like(q)
|
||||
context_attention_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
b_start_loc=cu_seqlens[:-1],
|
||||
b_seq_len=cu_seqlens[1:] - cu_seqlens[:-1],
|
||||
max_input_len=max_seqlen,
|
||||
is_causal=False,
|
||||
sliding_window_q=None,
|
||||
sliding_window_k=None,
|
||||
softmax_scale=scale,
|
||||
)
|
||||
|
||||
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
|
||||
return context_layer
|
||||
|
||||
|
||||
def triton_attn_wrapper_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="triton_attn_wrapper",
|
||||
op_func=triton_attn_wrapper,
|
||||
fake_impl=triton_attn_wrapper_fake,
|
||||
)
|
||||
|
||||
|
||||
def vit_triton_attn_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.triton_attn_wrapper(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
batch_size,
|
||||
scale,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
|
||||
|
||||
def apply_sdpa(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user