[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: TJian
Date: 2025-10-02 22:34:53 -07:00 (committed by GitHub)
Parent: 27edd2aeb4
Commit: 9c5ee91b2a
9 changed files with 154 additions and 141 deletions
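
This change adds a maybe_get_vit_flash_attn_backend() helper to vllm/attention/layer.py and routes every ViT attention module touched below (DotsVisionAttention, Ernie4_5_VisionAttention, Glm4vVisionAttention, Qwen2VisionAttention, Qwen2_5_VisionAttention, Qwen3_VisionTransformer, Siglip2Attention) through it, so the ROCm AITER / upstream flash_attn dispatch lives in one place instead of per-model import blocks. A minimal sketch of the intended call pattern follows; the fixed backend value is only there to keep the sketch self-contained, everything else is taken from the diff below.

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend

# In the real modules this comes from get_vit_attn_backend(head_size, dtype);
# a fixed value keeps the sketch self-contained.
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = False

# The helper may upgrade the backend to FLASH_ATTN when upstream flash_attn
# is available, forces upstream flash_attn for FLASH_ATTN on ROCm, and
# returns the matching varlen kernel (None for non-flash backends).
attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
    attn_backend, use_upstream_fa)

is_flash_attn_backend = attn_backend in {
    _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}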

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 import torch
 import torch.nn as nn
@@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype):
     ) and current_platform.has_device_capability(80):
         from transformers.utils import is_flash_attn_2_available
         return is_flash_attn_2_available()
+    if current_platform.is_rocm():
+        from importlib.util import find_spec
+        return find_spec("flash_attn") is not None
     return False
 
 
+def maybe_get_vit_flash_attn_backend(
+        attn_backend: _Backend,
+        use_upstream_fa: bool) -> tuple[_Backend, Callable]:
+    if attn_backend != _Backend.FLASH_ATTN and \
+        attn_backend != _Backend.ROCM_AITER_FA and \
+        check_upstream_fa_availability(torch.get_default_dtype()):
+        attn_backend = _Backend.FLASH_ATTN
+        use_upstream_fa = True
+
+    if current_platform.is_rocm() and \
+        attn_backend == _Backend.FLASH_ATTN:
+        use_upstream_fa = True
+
+    if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}):
+        if attn_backend == _Backend.ROCM_AITER_FA:
+            from aiter import flash_attn_varlen_func
+        else:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+    else:
+        flash_attn_varlen_func = None
+
+    return attn_backend, flash_attn_varlen_func
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
@@ -410,13 +440,9 @@ class MultiHeadAttention(nn.Module):
         # to upstream flash attention if available.
         # If vllm native fa is selected, we use it directly.
         use_upstream_fa = False
-        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-                dtype):
-            backend = _Backend.FLASH_ATTN
-            use_upstream_fa = True
-
-        if current_platform.is_rocm() or current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on rocm/xpu
+
+        if current_platform.is_xpu():
+            # currently, only torch_sdpa is supported on xpu
             self.attn_backend = _Backend.TORCH_SDPA
         else:
@@ -428,17 +454,25 @@ class MultiHeadAttention(nn.Module):
                 _Backend.FLASH_ATTN,
             } else _Backend.TORCH_SDPA
 
+        self.attn_backend, self._flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa,
+            )
+
         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-                self._flash_attn_varlen_func = flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
-                self._flash_attn_varlen_func = flash_attn_varlen_func
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
+
+        # this condition is just to make sure that the
+        # use_upstream_fa in the log is correct
+        if current_platform.is_rocm() \
+            and self.attn_backend == _Backend.FLASH_ATTN:
+            use_upstream_fa = True
+
         logger.info_once(
             f"MultiHeadAttention attn_backend: {self.attn_backend}, "
@@ -466,7 +500,7 @@ class MultiHeadAttention(nn.Module):
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                         step=q_len,
                                         dtype=torch.int32,
@@ -507,14 +541,6 @@ class MultiHeadAttention(nn.Module):
             from torch_xla.experimental.custom_kernel import flash_attention
             out = flash_attention(query, key, value, sm_scale=self.scale)
             out = out.transpose(1, 2)
-        elif self.attn_backend == _Backend.ROCM_AITER_FA:
-            from aiter import flash_attn_varlen_func
-
-            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
-            out = flash_attn_varlen_func(query,
-                                         key,
-                                         value,
-                                         softmax_scale=self.scale)
         else:
             # ViT attention hasn't supported this backend yet
             raise NotImplementedError(
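
For the flash-attention path above, MultiHeadAttention packs the batch into varlen form: equal-length sequences are described by cumulative-length tensors rather than an attention mask. A small sketch of that cu_seqlens construction follows; the shapes and the key/value variant are illustrative, only the cu_seqlens_q arange appears verbatim in the hunk.

import torch

bsz, q_len, kv_len = 2, 16, 16
device = "cuda"

# Prefix sums of per-sequence lengths: [0, q_len, 2*q_len, ..., bsz*q_len].
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                            step=q_len,
                            dtype=torch.int32,
                            device=device)
# The key/value side is assumed to be built the same way from kv_len.
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                            step=kv_len,
                            dtype=torch.int32,
                            device=device)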

View File

@@ -10,7 +10,8 @@ from torch.nn import LayerNorm
 from transformers.models.qwen2_vl import Qwen2VLProcessor
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import utils as dist_utils
 from vllm.distributed.parallel_state import (
@@ -267,10 +268,12 @@ class DotsVisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             self.hidden_size_per_attention_head, torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -306,25 +309,18 @@ class DotsVisionAttention(nn.Module):
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
             k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
             v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
-            output = flash_attn_varlen_func(q_,
+            output = self.flash_attn_varlen_func(q_,
                                             k_,
                                             v_,
                                             cu_seqlens_q=cu_seqlens,
                                             cu_seqlens_k=cu_seqlens,
                                             max_seqlen_q=max_seqlen,
                                             max_seqlen_k=max_seqlen,
                                             dropout_p=0.0,
                                             causal=False)
             context_layer = output.view(bs, -1,
                                         self.num_attention_heads_per_partition,
                                         self.hidden_size_per_attention_head)
@@ -611,7 +607,8 @@ class DotsVisionTransformer(nn.Module):
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@@ -35,7 +35,8 @@ from einops import rearrange, repeat
 from transformers import BatchFeature
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -176,14 +177,18 @@ class Ernie4_5_VisionAttention(nn.Module):
             dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
+
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
+                _Backend.FLASH_ATTN,
+                _Backend.TORCH_SDPA,
+                _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"Ernie45-VL does not support {self.attn_backend} backend now."
@@ -239,27 +244,18 @@ class Ernie4_5_VisionAttention(nn.Module):
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
-            # from vllm_flash_attn.flash_attn_interface import (
-            #     flash_attn_varlen_func)
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-            output = flash_attn_varlen_func(q,
+            output = self.flash_attn_varlen_func(q,
                                             k,
                                             v,
                                             cu_seqlens_q=cu_seqlens,
                                             cu_seqlens_k=cu_seqlens,
                                             max_seqlen_q=max_seqlen,
                                             max_seqlen_k=max_seqlen,
                                             dropout_p=0.0,
                                             causal=False)
             context_layer = rearrange(output,
                                       "(b s) h d -> s b (h d)",
@@ -516,7 +512,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@@ -47,7 +47,8 @@ from transformers.models.glm4v.video_processing_glm4v import (
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)
@@ -263,19 +264,26 @@ class Glm4vVisionAttention(nn.Module):
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,
                 _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"GLM-4V does not support {self.attn_backend} backend now.")
+
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
 
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
         seq_len, bs, _ = qkv.shape
@@ -316,17 +324,11 @@ class Glm4vVisionAttention(nn.Module):
         qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
-            # from vllm_flash_attn.flash_attn_interface import (
-            #     flash_attn_varlen_func)
-            if self.use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
+        if self.is_flash_attn_backend:
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-            output = flash_attn_varlen_func(
+            output = self.flash_attn_varlen_func(
                 q,
                 k,
                 v,
@@ -774,7 +776,8 @@ class Glm4vVisionTransformer(nn.Module):
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens

View File

@@ -39,7 +39,8 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -302,6 +303,11 @@ class Qwen2_5_VisionAttention(nn.Module):
             disable_tp=use_data_parallel)
         self.attn_backend = attn_backend
         self.use_upstream_fa = use_upstream_fa
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -354,25 +360,18 @@ class Qwen2_5_VisionAttention(nn.Module):
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-            output = flash_attn_varlen_func(q,
+            output = self.flash_attn_varlen_func(q,
                                             k,
                                             v,
                                             cu_seqlens_q=cu_seqlens,
                                             cu_seqlens_k=cu_seqlens,
                                             max_seqlen_q=max_seqlen,
                                             max_seqlen_k=max_seqlen,
                                             dropout_p=0.0,
                                             causal=False)
             context_layer = rearrange(output,
                                       "(b s) h d -> s b (h d)",
@@ -618,6 +617,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
         if self.attn_backend != _Backend.FLASH_ATTN and \
+            self.attn_backend != _Backend.ROCM_AITER_FA and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN

View File

@@ -42,7 +42,8 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
     Qwen2VLVideoProcessor)
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
@@ -319,11 +320,12 @@ class Qwen2VisionAttention(nn.Module):
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
@@ -331,6 +333,7 @@ class Qwen2VisionAttention(nn.Module):
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
+
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -383,25 +386,18 @@ class Qwen2VisionAttention(nn.Module):
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-            output = flash_attn_varlen_func(q,
+            output = self.flash_attn_varlen_func(q,
                                             k,
                                             v,
                                             cu_seqlens_q=cu_seqlens,
                                             cu_seqlens_k=cu_seqlens,
                                             max_seqlen_q=max_seqlen,
                                             max_seqlen_k=max_seqlen,
                                             dropout_p=0.0,
                                             causal=False)
             context_layer = rearrange(output,
                                       "(b s) h d -> s b (h d)",

View File

@@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module):
             head_size=head_dim, dtype=torch.get_default_dtype())
         use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
+            self.attn_backend != _Backend.ROCM_AITER_FA and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
@@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@@ -14,7 +14,7 @@ from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -240,11 +240,12 @@ class Siglip2Attention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.head_dim, dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
@@ -286,14 +287,7 @@ class Siglip2Attention(nn.Module):
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
-            attn_output = flash_attn_varlen_func(
+            attn_output = self.flash_attn_varlen_func(
                 queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                 max_seqlen).reshape(seq_length, -1)
         elif self.attn_backend == _Backend.TORCH_SDPA:

View File

@@ -189,8 +189,6 @@ class RocmPlatform(Platform):
         from vllm.attention.backends.registry import _Backend
         if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                 and on_gfx9()):
-            # Note: AITER FA is only supported for Qwen-VL models.
-            # TODO: Add support for other VL models in their model class.
             return _Backend.ROCM_AITER_FA
         if on_gfx9():
             return _Backend.FLASH_ATTN
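
With the Qwen-VL-only caveat removed, ROCm's ViT backend choice (as far as this hunk shows it) is driven only by the AITER env flags and the GPU architecture. A restatement in plain Python follows; the fallback beyond these branches is left out because it is not part of the diff, and the function name here is illustrative.

from vllm.attention.backends.registry import _Backend


def rocm_vit_backend_sketch(use_aiter: bool, use_aiter_mha: bool,
                            is_gfx9: bool) -> _Backend:
    # Mirrors only the branches visible in the hunk above.
    if use_aiter and use_aiter_mha and is_gfx9:
        # AITER flash attention, now usable by all VL models.
        return _Backend.ROCM_AITER_FA
    if is_gfx9:
        # The new dispatcher then forces upstream flash_attn for this backend.
        return _Backend.FLASH_ATTN
    raise NotImplementedError("remaining fallbacks are not shown in this diff")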