diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index 25fb5c926..d76c57f9e 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.rocm import RocmPlatform
-from vllm.utils.torch_utils import set_random_seed
+from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.attention.selector import _cached_get_attn_backend

@@ -71,6 +71,15 @@ def test_mha_attn_platform(default_vllm_config, device: str):
         attn = MMEncoderAttention(16, 72, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
+    # Test CUDA with head_size=72 and a torch.float32 default dtype
+    # - FlashAttention does not support fp32, so it should fall back to TRITON_ATTN
+    with (
+        patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
+        set_default_torch_dtype(torch.float32),
+    ):
+        attn = MMEncoderAttention(16, 72, scale=1)
+        assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
+


 def ref_attention(
     query: torch.Tensor,
@@ -153,7 +162,12 @@ def test_mha_attn_forward(
         v,
         scale=scale,
     ).reshape(batch_size, seq_len, num_heads * head_size)
-    torch.testing.assert_close(output, ref_output)
+    tol_kwargs = (
+        dict(rtol=1e-3, atol=1e-3)
+        if attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
+        else {}
+    )
+    torch.testing.assert_close(output, ref_output, **tol_kwargs)


 @pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py
index f26d89f40..1e9c714ea 100644
--- a/vllm/model_executor/layers/attention/mm_encoder_attention.py
+++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -12,6 +12,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
     vit_torch_sdpa_wrapper,
+    vit_triton_attn_wrapper,
 )

 logger = init_logger(__name__)
@@ -165,6 +166,41 @@ class MMEncoderAttention(CustomOp):
             output = output.reshape(bsz, q_len, -1)
         return output

+    def _forward_triton(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens: torch.Tensor | None = None,
+        max_seqlen: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Input shape:
+        (batch_size x seq_len x hidden_size) or
+        (batch_size x seq_len x num_heads x head_size)
+        """
+        assert (cu_seqlens is not None and max_seqlen is not None) or (
+            cu_seqlens is None and max_seqlen is None
+        ), "cu_seqlens and max_seqlen should be both set or both None."
+
+        bsz, q_len = query.size()[:2]
+        kv_len = key.size(1)
+        is_reshaped = query.dim() != 4
+
+        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
+
+        output = vit_triton_attn_wrapper(
+            q=query,
+            k=key,
+            v=value,
+            batch_size=bsz,
+            scale=self.scale,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        if is_reshaped:
+            output = output.reshape(bsz, q_len, -1)
+        return output
+
     def forward_native(
         self,
         query: torch.Tensor,
@@ -185,6 +221,8 @@ class MMEncoderAttention(CustomOp):
     ) -> torch.Tensor:
         if self.is_flash_attn_backend:
             return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
+        elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
+            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             return self._forward_sdpa(query, key, value, cu_seqlens)
         else:
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 0d2fefb73..4d8acb082 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module):

     def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
         max_seqlen = None
-        if (
-            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen

diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 37e95b261..ab1386e08 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module):

     def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
         max_seqlen = None
-        if (
-            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen

diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 23f27db3c..a85d5e6f9 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> torch.Tensor | None:
         max_seqlen = None
-        if (
-            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen

diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index 021f24e11..2bbe7e850 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module):
             head_size=head_dim,
             dtype=torch.get_default_dtype(),
         )
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }:
-            raise RuntimeError(
-                f"PaddleOCR-VL does not support {self.attn_backend} backend now."
-            )
+
         self.layers = nn.ModuleList(
             [
                 SiglipEncoderLayer(
@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module):
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index c2c52fa66..9e5f1175a 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
         )

-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
-
         with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
             self.blocks = nn.ModuleList(
                 [
@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 1c568bdff..c530493b1 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module):
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 50fbb8be1..2943a319f 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index c18fc77f7..abb38a648 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
         )

-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }:
-            raise RuntimeError(
-                f"Qwen3-VL does not support {self.attn_backend} backend now."
-            )
         self.blocks = nn.ModuleList(
             [
                 Qwen3_VisionBlock(
@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        if (
-            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
+        if self.attn_backend in (
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index a2b78753a..8882754b3 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -108,7 +108,7 @@ def get_vit_attn_backend(
         multimodal_config: MultiModalConfig | None = (
             model_config.multimodal_config if model_config is not None else None
         )
-    except AssertionError:
+    except (AssertionError, AttributeError):
         multimodal_config = None

     attn_backend_override = (
@@ -134,7 +134,7 @@ def is_vit_use_data_parallel():
         multimodal_config: MultiModalConfig | None = (
             model_config.multimodal_config if model_config is not None else None
         )
-    except AssertionError:
+    except (AssertionError, AttributeError):
         multimodal_config = None

     mm_encoder_tp_mode = (
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index b7efe24dc..c2fcde4ab 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -411,8 +411,9 @@ class CudaPlatformBase(Platform):
     @classmethod
     def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
         return [
-            AttentionBackendEnum.TORCH_SDPA,
             AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TRITON_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
         ]

     @classmethod
@@ -430,14 +431,25 @@ class CudaPlatformBase(Platform):
             logger.info_once(f"Using backend {backend} for vit attention")
             return backend

-        # Try FlashAttention first
-        if (cc := cls.get_device_capability()) and cc.major >= 8:
+        cc = cls.get_device_capability()
+        for vit_attn_backend in cls.get_supported_vit_attn_backends():
+            if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
+                continue
             try:
-                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-                if backend_class.supports_head_size(
+                backend_class = vit_attn_backend.get_class()
+                is_backend_supported = backend_class.supports_head_size(
                     head_size
-                ) and backend_class.supports_dtype(dtype):
-                    return AttentionBackendEnum.FLASH_ATTN
+                ) and backend_class.supports_dtype(dtype)
+                if cc is not None:
+                    is_backend_supported = (
+                        is_backend_supported
+                        and backend_class.supports_compute_capability(cc)
+                    )
+                if is_backend_supported:
+                    logger.info_once(
+                        f"Using backend {vit_attn_backend} for vit attention"
+                    )
+                    return vit_attn_backend
             except ImportError:
                 pass

diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 808d21400..2fedd7c67 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -384,6 +384,7 @@ class RocmPlatform(Platform):
         return [
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
         ]

diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py
index 32fcb3511..f5c748fbc 100644
--- a/vllm/v1/attention/ops/vit_attn_wrappers.py
+++ b/vllm/v1/attention/ops/vit_attn_wrappers.py
@@ -110,6 +110,83 @@ def vit_flash_attn_wrapper(
     )


+def triton_attn_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    batch_size: int,
+    scale: float | None = None,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+) -> torch.Tensor:
+    from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
+
+    q_len = q.size(1)
+    if cu_seqlens is None:
+        cu_seqlens = torch.arange(
+            0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
+        )
+    max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
+
+    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+    output = torch.empty_like(q)
+    context_attention_fwd(
+        q,
+        k,
+        v,
+        output,
+        b_start_loc=cu_seqlens[:-1],
+        b_seq_len=cu_seqlens[1:] - cu_seqlens[:-1],
+        max_input_len=max_seqlen,
+        is_causal=False,
+        sliding_window_q=None,
+        sliding_window_k=None,
+        softmax_scale=scale,
+    )
+
+    context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
+    return context_layer
+
+
+def triton_attn_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    batch_size: int,
+    scale: float | None = None,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.empty_like(q)
+
+
+direct_register_custom_op(
+    op_name="triton_attn_wrapper",
+    op_func=triton_attn_wrapper,
+    fake_impl=triton_attn_wrapper_fake,
+)
+
+
+def vit_triton_attn_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    batch_size: int,
+    scale: float | None = None,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.ops.vllm.triton_attn_wrapper(
+        q,
+        k,
+        v,
+        batch_size,
+        scale,
+        cu_seqlens,
+        max_seqlen,
+    )
+
+
 def apply_sdpa(
     q: torch.Tensor,
     k: torch.Tensor,