[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: TJian
Date: 2025-10-02 22:34:53 -07:00
Committed by: GitHub
Parent: 27edd2aeb4
Commit: 9c5ee91b2a
9 changed files with 154 additions and 141 deletions


@@ -14,7 +14,7 @@ from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -240,11 +240,12 @@ class Siglip2Attention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.head_dim, dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
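
The body of maybe_get_vit_flash_attn_backend is not shown in this diff, but the branches deleted here and in the forward pass below imply its shape. A minimal sketch, assuming the helper simply centralizes those branches; the real implementation lives in vllm/attention/layer.py and, per the commit title, presumably also keeps the upstream-FA upgrade from clobbering the ROCm AITER backend:

import torch
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability

def maybe_get_vit_flash_attn_backend(attn_backend, use_upstream_fa):
    # Upgrade to FLASH_ATTN only when not already on a flash backend
    # (notably ROCM_AITER_FA) and the upstream package is importable.
    if attn_backend not in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA) \
            and check_upstream_fa_availability(torch.get_default_dtype()):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True

    # Resolve the varlen kernel that matches the chosen backend.
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == _Backend.FLASH_ATTN:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    else:
        # Non-flash backends (e.g. TORCH_SDPA) take a different path.
        flash_attn_varlen_func = None
    return attn_backend, flash_attn_varlen_func
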
@@ -286,14 +287,7 @@ class Siglip2Attention(nn.Module):
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
-            attn_output = flash_attn_varlen_func(
+            attn_output = self.flash_attn_varlen_func(
                 queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                 max_seqlen).reshape(seq_length, -1)
         elif self.attn_backend == _Backend.TORCH_SDPA:
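
The net effect: backend probing and the kernel import now happen once at construction, so the forward pass calls the stored function with no per-call imports or branching. All three kernels (aiter, upstream flash_attn, vllm_flash_attn) share the flash-attention varlen calling convention; a shape-annotated sketch of that call, with shapes inferred from the varlen API rather than stated in this diff:

# queries/keys/values: (total_tokens, num_heads, head_dim), packed across
# the batch with no padding.
# cu_seqlens: (batch_size + 1,) int32 cumulative sequence lengths; passed
# twice (q and kv offsets) since this is self-attention.
attn_output = self.flash_attn_varlen_func(
    queries, keys, values,
    cu_seqlens, cu_seqlens,
    max_seqlen, max_seqlen,
).reshape(seq_length, -1)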