[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Attention layer."""
|
"""Attention layer."""
|
||||||
from typing import List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype):
|
|||||||
) and current_platform.has_device_capability(80):
|
) and current_platform.has_device_capability(80):
|
||||||
from transformers.utils import is_flash_attn_2_available
|
from transformers.utils import is_flash_attn_2_available
|
||||||
return is_flash_attn_2_available()
|
return is_flash_attn_2_available()
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
from importlib.util import find_spec
|
||||||
|
return find_spec("flash_attn") is not None
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_get_vit_flash_attn_backend(
        attn_backend: _Backend,
        use_upstream_fa: bool) -> tuple[_Backend, Optional[Callable]]:
    """Resolve the ViT attention backend and its varlen flash-attn kernel.

    Args:
        attn_backend: Backend initially selected for the vision encoder.
        use_upstream_fa: Whether the caller already wants the upstream
            ``flash_attn`` package rather than ``vllm.vllm_flash_attn``.

    Returns:
        A ``(attn_backend, flash_attn_varlen_func)`` tuple. The backend may
        have been switched to ``_Backend.FLASH_ATTN``. The function is
        ``None`` when the resolved backend is neither ``FLASH_ATTN`` nor
        ``ROCM_AITER_FA``.

    Note:
        ``use_upstream_fa`` may be flipped to ``True`` inside this function,
        but the updated flag is not part of the return value; callers that
        keep their own copy of the flag will not see that change.
    """
    flash_backends = {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}

    # If no flash-attention style backend was chosen but the upstream
    # flash_attn package is usable, switch to it.
    if attn_backend not in flash_backends and \
            check_upstream_fa_availability(torch.get_default_dtype()):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True

    # On ROCm, FLASH_ATTN is always imported from the upstream flash_attn
    # package (the vllm.vllm_flash_attn import below is never taken).
    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        use_upstream_fa = True

    # Imports are done lazily so that only the package backing the resolved
    # backend has to be installed.
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == _Backend.FLASH_ATTN:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    else:
        # No varlen flash-attention kernel for this backend.
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module, AttentionLayerBase):
|
class Attention(nn.Module, AttentionLayerBase):
|
||||||
"""Attention layer.
|
"""Attention layer.
|
||||||
|
|
||||||
@@ -410,13 +440,9 @@ class MultiHeadAttention(nn.Module):
|
|||||||
# to upstream flash attention if available.
|
# to upstream flash attention if available.
|
||||||
# If vllm native fa is selected, we use it directly.
|
# If vllm native fa is selected, we use it directly.
|
||||||
use_upstream_fa = False
|
use_upstream_fa = False
|
||||||
if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
|
||||||
dtype):
|
|
||||||
backend = _Backend.FLASH_ATTN
|
|
||||||
use_upstream_fa = True
|
|
||||||
|
|
||||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
if current_platform.is_xpu():
|
||||||
# currently, only torch_sdpa is supported on rocm/xpu
|
# currently, only torch_sdpa is supported on xpu
|
||||||
self.attn_backend = _Backend.TORCH_SDPA
|
self.attn_backend = _Backend.TORCH_SDPA
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@@ -428,17 +454,25 @@ class MultiHeadAttention(nn.Module):
|
|||||||
_Backend.FLASH_ATTN,
|
_Backend.FLASH_ATTN,
|
||||||
} else _Backend.TORCH_SDPA
|
} else _Backend.TORCH_SDPA
|
||||||
|
|
||||||
|
self.attn_backend, self._flash_attn_varlen_func \
|
||||||
|
= maybe_get_vit_flash_attn_backend(
|
||||||
|
self.attn_backend,
|
||||||
|
use_upstream_fa,
|
||||||
|
)
|
||||||
|
|
||||||
if (self.attn_backend == _Backend.XFORMERS
|
if (self.attn_backend == _Backend.XFORMERS
|
||||||
and not check_xformers_availability()):
|
and not check_xformers_availability()):
|
||||||
self.attn_backend = _Backend.TORCH_SDPA
|
self.attn_backend = _Backend.TORCH_SDPA
|
||||||
|
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
if use_upstream_fa:
|
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||||
from flash_attn import flash_attn_varlen_func
|
}
|
||||||
self._flash_attn_varlen_func = flash_attn_varlen_func
|
|
||||||
else:
|
# this condition is just to make sure that the
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
# use_upstream_fa in the log is correct
|
||||||
self._flash_attn_varlen_func = flash_attn_varlen_func
|
if current_platform.is_rocm() \
|
||||||
|
and self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
|
use_upstream_fa = True
|
||||||
|
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
|
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
|
||||||
@@ -466,7 +500,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||||
|
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if self.is_flash_attn_backend:
|
||||||
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
|
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
|
||||||
step=q_len,
|
step=q_len,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
@@ -507,14 +541,6 @@ class MultiHeadAttention(nn.Module):
|
|||||||
from torch_xla.experimental.custom_kernel import flash_attention
|
from torch_xla.experimental.custom_kernel import flash_attention
|
||||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||||
out = out.transpose(1, 2)
|
out = out.transpose(1, 2)
|
||||||
elif self.attn_backend == _Backend.ROCM_AITER_FA:
|
|
||||||
from aiter import flash_attn_varlen_func
|
|
||||||
|
|
||||||
# ROCm Flash Attention expects (batch, seq, heads, head_dim)
|
|
||||||
out = flash_attn_varlen_func(query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
softmax_scale=self.scale)
|
|
||||||
else:
|
else:
|
||||||
# ViT attention hasn't supported this backend yet
|
# ViT attention hasn't supported this backend yet
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from torch.nn import LayerNorm
|
|||||||
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||||
|
maybe_get_vit_flash_attn_backend)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
@@ -267,10 +268,12 @@ class DotsVisionAttention(nn.Module):
|
|||||||
self.attn_backend = get_vit_attn_backend(
|
self.attn_backend = get_vit_attn_backend(
|
||||||
self.hidden_size_per_attention_head, torch.get_default_dtype())
|
self.hidden_size_per_attention_head, torch.get_default_dtype())
|
||||||
self.use_upstream_fa = False
|
self.use_upstream_fa = False
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
|
||||||
check_upstream_fa_availability(torch.get_default_dtype()):
|
self.attn_backend, self.flash_attn_varlen_func \
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
= maybe_get_vit_flash_attn_backend(
|
||||||
self.use_upstream_fa = True
|
self.attn_backend,
|
||||||
|
self.use_upstream_fa,
|
||||||
|
)
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||||
_Backend.ROCM_AITER_FA
|
_Backend.ROCM_AITER_FA
|
||||||
@@ -306,25 +309,18 @@ class DotsVisionAttention(nn.Module):
|
|||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
if self.is_flash_attn_backend:
|
if self.is_flash_attn_backend:
|
||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
|
||||||
from aiter import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
|
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
|
||||||
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
|
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
|
||||||
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
|
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
|
||||||
output = flash_attn_varlen_func(q_,
|
output = self.flash_attn_varlen_func(q_,
|
||||||
k_,
|
k_,
|
||||||
v_,
|
v_,
|
||||||
cu_seqlens_q=cu_seqlens,
|
cu_seqlens_q=cu_seqlens,
|
||||||
cu_seqlens_k=cu_seqlens,
|
cu_seqlens_k=cu_seqlens,
|
||||||
max_seqlen_q=max_seqlen,
|
max_seqlen_q=max_seqlen,
|
||||||
max_seqlen_k=max_seqlen,
|
max_seqlen_k=max_seqlen,
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
causal=False)
|
causal=False)
|
||||||
context_layer = output.view(bs, -1,
|
context_layer = output.view(bs, -1,
|
||||||
self.num_attention_heads_per_partition,
|
self.num_attention_heads_per_partition,
|
||||||
self.hidden_size_per_attention_head)
|
self.hidden_size_per_attention_head)
|
||||||
@@ -611,7 +607,8 @@ class DotsVisionTransformer(nn.Module):
|
|||||||
self, cu_seqlens: torch.Tensor
|
self, cu_seqlens: torch.Tensor
|
||||||
) -> tuple[Optional[int], Optional[list[int]]]:
|
) -> tuple[Optional[int], Optional[list[int]]]:
|
||||||
max_seqlen, seqlens = None, None
|
max_seqlen, seqlens = None, None
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if (self.attn_backend == _Backend.FLASH_ATTN
|
||||||
|
or self.attn_backend == _Backend.ROCM_AITER_FA):
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ from einops import rearrange, repeat
|
|||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||||
|
maybe_get_vit_flash_attn_backend)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@@ -176,14 +177,18 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
dtype=torch.get_default_dtype())
|
dtype=torch.get_default_dtype())
|
||||||
|
|
||||||
self.use_upstream_fa = False
|
self.use_upstream_fa = False
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
|
||||||
check_upstream_fa_availability(torch.get_default_dtype()):
|
self.attn_backend, self.flash_attn_varlen_func \
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
= maybe_get_vit_flash_attn_backend(
|
||||||
self.use_upstream_fa = True
|
self.attn_backend,
|
||||||
|
self.use_upstream_fa,
|
||||||
|
)
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.ROCM_AITER_FA
|
_Backend.TORCH_SDPA,
|
||||||
|
_Backend.XFORMERS,
|
||||||
|
_Backend.ROCM_AITER_FA,
|
||||||
}:
|
}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Ernie45-VL does not support {self.attn_backend} backend now."
|
f"Ernie45-VL does not support {self.attn_backend} backend now."
|
||||||
@@ -239,27 +244,18 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
if self.is_flash_attn_backend:
|
if self.is_flash_attn_backend:
|
||||||
# from vllm_flash_attn.flash_attn_interface import (
|
|
||||||
# flash_attn_varlen_func)
|
|
||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
|
||||||
from aiter import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
output = flash_attn_varlen_func(q,
|
output = self.flash_attn_varlen_func(q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
cu_seqlens_q=cu_seqlens,
|
cu_seqlens_q=cu_seqlens,
|
||||||
cu_seqlens_k=cu_seqlens,
|
cu_seqlens_k=cu_seqlens,
|
||||||
max_seqlen_q=max_seqlen,
|
max_seqlen_q=max_seqlen,
|
||||||
max_seqlen_k=max_seqlen,
|
max_seqlen_k=max_seqlen,
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
causal=False)
|
causal=False)
|
||||||
|
|
||||||
context_layer = rearrange(output,
|
context_layer = rearrange(output,
|
||||||
"(b s) h d -> s b (h d)",
|
"(b s) h d -> s b (h d)",
|
||||||
@@ -516,7 +512,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
|||||||
self, cu_seqlens: torch.Tensor
|
self, cu_seqlens: torch.Tensor
|
||||||
) -> tuple[Optional[int], Optional[list[int]]]:
|
) -> tuple[Optional[int], Optional[list[int]]]:
|
||||||
max_seqlen, seqlens = None, None
|
max_seqlen, seqlens = None, None
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if (self.attn_backend == _Backend.FLASH_ATTN
|
||||||
|
or self.attn_backend == _Backend.ROCM_AITER_FA):
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
|||||||
@@ -47,7 +47,8 @@ from transformers.models.glm4v.video_processing_glm4v import (
|
|||||||
from transformers.video_utils import VideoMetadata
|
from transformers.video_utils import VideoMetadata
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||||
|
maybe_get_vit_flash_attn_backend)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||||
parallel_state)
|
parallel_state)
|
||||||
@@ -263,19 +264,26 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
head_size=self.hidden_size_per_attention_head,
|
head_size=self.hidden_size_per_attention_head,
|
||||||
dtype=torch.get_default_dtype())
|
dtype=torch.get_default_dtype())
|
||||||
self.use_upstream_fa = False
|
self.use_upstream_fa = False
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
|
||||||
check_upstream_fa_availability(torch.get_default_dtype()):
|
self.attn_backend, self.flash_attn_varlen_func \
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
= maybe_get_vit_flash_attn_backend(
|
||||||
self.use_upstream_fa = True
|
self.attn_backend,
|
||||||
|
self.use_upstream_fa,
|
||||||
|
)
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.TORCH_SDPA,
|
_Backend.TORCH_SDPA,
|
||||||
_Backend.XFORMERS,
|
_Backend.XFORMERS,
|
||||||
|
_Backend.ROCM_AITER_FA,
|
||||||
}:
|
}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"GLM-4V does not support {self.attn_backend} backend now.")
|
f"GLM-4V does not support {self.attn_backend} backend now.")
|
||||||
|
|
||||||
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
|
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||||
|
}
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -316,17 +324,11 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if self.is_flash_attn_backend:
|
||||||
# from vllm_flash_attn.flash_attn_interface import (
|
|
||||||
# flash_attn_varlen_func)
|
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
output = flash_attn_varlen_func(
|
output = self.flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
@@ -774,7 +776,8 @@ class Glm4vVisionTransformer(nn.Module):
|
|||||||
) -> tuple[Optional[int], Optional[list[int]]]:
|
) -> tuple[Optional[int], Optional[list[int]]]:
|
||||||
max_seqlen, seqlens = None, None
|
max_seqlen, seqlens = None, None
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if (self.attn_backend == _Backend.FLASH_ATTN
|
||||||
|
or self.attn_backend == _Backend.ROCM_AITER_FA):
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
return max_seqlen, seqlens
|
return max_seqlen, seqlens
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
|||||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||||
|
maybe_get_vit_flash_attn_backend)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@@ -302,6 +303,11 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
disable_tp=use_data_parallel)
|
disable_tp=use_data_parallel)
|
||||||
self.attn_backend = attn_backend
|
self.attn_backend = attn_backend
|
||||||
self.use_upstream_fa = use_upstream_fa
|
self.use_upstream_fa = use_upstream_fa
|
||||||
|
self.attn_backend, self.flash_attn_varlen_func \
|
||||||
|
= maybe_get_vit_flash_attn_backend(
|
||||||
|
self.attn_backend,
|
||||||
|
self.use_upstream_fa,
|
||||||
|
)
|
||||||
self.is_flash_attn_backend = self.attn_backend in {
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||||
}
|
}
|
||||||
@@ -354,25 +360,18 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
if self.is_flash_attn_backend:
|
if self.is_flash_attn_backend:
|
||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
|
||||||
from aiter import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
output = flash_attn_varlen_func(q,
|
output = self.flash_attn_varlen_func(q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
cu_seqlens_q=cu_seqlens,
|
cu_seqlens_q=cu_seqlens,
|
||||||
cu_seqlens_k=cu_seqlens,
|
cu_seqlens_k=cu_seqlens,
|
||||||
max_seqlen_q=max_seqlen,
|
max_seqlen_q=max_seqlen,
|
||||||
max_seqlen_k=max_seqlen,
|
max_seqlen_k=max_seqlen,
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
causal=False)
|
causal=False)
|
||||||
|
|
||||||
context_layer = rearrange(output,
|
context_layer = rearrange(output,
|
||||||
"(b s) h d -> s b (h d)",
|
"(b s) h d -> s b (h d)",
|
||||||
@@ -618,6 +617,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
self.attn_backend = get_vit_attn_backend(
|
self.attn_backend = get_vit_attn_backend(
|
||||||
head_size=head_dim, dtype=torch.get_default_dtype())
|
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
self.attn_backend != _Backend.ROCM_AITER_FA and \
|
||||||
check_upstream_fa_availability(
|
check_upstream_fa_availability(
|
||||||
torch.get_default_dtype()):
|
torch.get_default_dtype()):
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
|||||||
@@ -42,7 +42,8 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
|
|||||||
Qwen2VLVideoProcessor)
|
Qwen2VLVideoProcessor)
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||||
|
maybe_get_vit_flash_attn_backend)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@@ -319,11 +320,12 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
head_size=self.hidden_size_per_attention_head,
|
head_size=self.hidden_size_per_attention_head,
|
||||||
dtype=torch.get_default_dtype())
|
dtype=torch.get_default_dtype())
|
||||||
self.use_upstream_fa = False
|
self.use_upstream_fa = False
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
|
||||||
check_upstream_fa_availability(
|
self.attn_backend, self.flash_attn_varlen_func \
|
||||||
torch.get_default_dtype()):
|
= maybe_get_vit_flash_attn_backend(
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
self.attn_backend,
|
||||||
self.use_upstream_fa = True
|
self.use_upstream_fa,
|
||||||
|
)
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||||
@@ -331,6 +333,7 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
}:
|
}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Qwen2-VL does not support {self.attn_backend} backend now.")
|
f"Qwen2-VL does not support {self.attn_backend} backend now.")
|
||||||
|
|
||||||
self.is_flash_attn_backend = self.attn_backend in {
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||||
}
|
}
|
||||||
@@ -383,25 +386,18 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
if self.is_flash_attn_backend:
|
if self.is_flash_attn_backend:
|
||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
|
||||||
from aiter import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
output = flash_attn_varlen_func(q,
|
output = self.flash_attn_varlen_func(q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
cu_seqlens_q=cu_seqlens,
|
cu_seqlens_q=cu_seqlens,
|
||||||
cu_seqlens_k=cu_seqlens,
|
cu_seqlens_k=cu_seqlens,
|
||||||
max_seqlen_q=max_seqlen,
|
max_seqlen_q=max_seqlen,
|
||||||
max_seqlen_k=max_seqlen,
|
max_seqlen_k=max_seqlen,
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
causal=False)
|
causal=False)
|
||||||
|
|
||||||
context_layer = rearrange(output,
|
context_layer = rearrange(output,
|
||||||
"(b s) h d -> s b (h d)",
|
"(b s) h d -> s b (h d)",
|
||||||
|
|||||||
@@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
head_size=head_dim, dtype=torch.get_default_dtype())
|
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||||
use_upstream_fa = False
|
use_upstream_fa = False
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
self.attn_backend != _Backend.ROCM_AITER_FA and \
|
||||||
check_upstream_fa_availability(
|
check_upstream_fa_availability(
|
||||||
torch.get_default_dtype()):
|
torch.get_default_dtype()):
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
@@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
) -> tuple[Optional[int], Optional[list[int]]]:
|
) -> tuple[Optional[int], Optional[list[int]]]:
|
||||||
max_seqlen, seqlens = None, None
|
max_seqlen, seqlens = None, None
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if (self.attn_backend == _Backend.FLASH_ATTN
|
||||||
|
or self.attn_backend == _Backend.ROCM_AITER_FA):
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from transformers import Siglip2VisionConfig
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@@ -240,11 +240,12 @@ class Siglip2Attention(nn.Module):
|
|||||||
self.attn_backend = get_vit_attn_backend(
|
self.attn_backend = get_vit_attn_backend(
|
||||||
head_size=self.head_dim, dtype=torch.get_default_dtype())
|
head_size=self.head_dim, dtype=torch.get_default_dtype())
|
||||||
self.use_upstream_fa = False
|
self.use_upstream_fa = False
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
|
||||||
check_upstream_fa_availability(
|
self.attn_backend, self.flash_attn_varlen_func \
|
||||||
torch.get_default_dtype()):
|
= maybe_get_vit_flash_attn_backend(
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
self.attn_backend,
|
||||||
self.use_upstream_fa = True
|
self.use_upstream_fa,
|
||||||
|
)
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
|
||||||
@@ -286,14 +287,7 @@ class Siglip2Attention(nn.Module):
|
|||||||
|
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
if self.is_flash_attn_backend:
|
if self.is_flash_attn_backend:
|
||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
attn_output = self.flash_attn_varlen_func(
|
||||||
from aiter import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
|
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
|
||||||
max_seqlen).reshape(seq_length, -1)
|
max_seqlen).reshape(seq_length, -1)
|
||||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||||
|
|||||||
@@ -189,8 +189,6 @@ class RocmPlatform(Platform):
|
|||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
|
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
|
||||||
and on_gfx9()):
|
and on_gfx9()):
|
||||||
# Note: AITER FA is only supported for Qwen-VL models.
|
|
||||||
# TODO: Add support for other VL models in their model class.
|
|
||||||
return _Backend.ROCM_AITER_FA
|
return _Backend.ROCM_AITER_FA
|
||||||
if on_gfx9():
|
if on_gfx9():
|
||||||
return _Backend.FLASH_ATTN
|
return _Backend.FLASH_ATTN
|
||||||
|
|||||||
Reference in New Issue
Block a user