[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
This commit is contained in:
@@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
|
||||
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
|
||||
make_tensor_with_pad)
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
||||
|
||||
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
||||
# bugs related to this test in PyTorch 2.4.
|
||||
@@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
|
||||
if backend_name == STR_XFORMERS_ATTN_VAL:
|
||||
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
|
||||
from vllm.attention.backends.xformers import XFormersBackend
|
||||
|
||||
return XFormersBackend()
|
||||
elif backend_name == STR_FLASH_ATTN_VAL:
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionBackend
|
||||
return FlashAttentionBackend()
|
||||
|
||||
raise AssertionError(
|
||||
f"Unrecognized backend_name {backend_name} for unit test")
|
||||
|
||||
|
||||
def _make_metadata_tensors(
|
||||
seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
|
||||
encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
|
||||
torch.Tensor, Optional[int]]:
|
||||
seq_lens: Optional[List[int]],
|
||||
context_lens: Optional[List[int]],
|
||||
encoder_seq_lens: Optional[List[int]],
|
||||
device: Union[torch.device, str],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
|
||||
torch.Tensor, torch.Tensor, Optional[int]]:
|
||||
'''
|
||||
Build scalar & tensor values required to build attention metadata structure.
|
||||
|
||||
@@ -553,6 +558,8 @@ def _make_metadata_tensors(
|
||||
* max_context_len: max(context_lens)
|
||||
* max_seq_len: max(seq_lens)
|
||||
* seq_start_loc: start idx of each sequence
|
||||
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
|
||||
* encoder_seq_start_loc: start idx of each encoder sequence
|
||||
* max_encoder_seq_len: encoder seq_lens list, as tensor
|
||||
'''
|
||||
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
|
||||
@@ -566,8 +573,26 @@ def _make_metadata_tensors(
|
||||
|
||||
seq_start_loc = None
|
||||
|
||||
if seq_lens_tensor is not None:
|
||||
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=seq_lens_tensor.device)
|
||||
torch.cumsum(seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=seq_start_loc.dtype,
|
||||
out=seq_start_loc[1:])
|
||||
|
||||
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=encoder_seq_lens_tensor.device)
|
||||
torch.cumsum(encoder_seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=encoder_seq_start_loc.dtype,
|
||||
out=encoder_seq_start_loc[1:])
|
||||
|
||||
return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
|
||||
seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
|
||||
seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
|
||||
max_encoder_seq_len)
|
||||
|
||||
|
||||
def make_kv_cache(num_blocks: int,
|
||||
@@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
device: Union[torch.device, str],
|
||||
backend: str,
|
||||
default_val: float = 0.0) -> torch.Tensor:
|
||||
'''
|
||||
Create a fake KV cache.
|
||||
@@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
|
||||
Returns:
|
||||
|
||||
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
|
||||
* for backend 'XFORMERS'
|
||||
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
|
||||
* for backend 'FLASH_ATTN'
|
||||
'''
|
||||
|
||||
kv_cache = torch.rand(
|
||||
(2, num_blocks, block_size * num_heads * head_size)).to(device)
|
||||
if backend == 'XFORMERS':
|
||||
kv_cache = torch.rand(
|
||||
(2, num_blocks, block_size * num_heads * head_size)).to(device)
|
||||
elif backend == 'FLASH_ATTN':
|
||||
kv_cache = torch.rand(
|
||||
(2, num_blocks, block_size, num_heads, head_size)).to(device)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
|
||||
f"'FLASH_ATTN'.")
|
||||
if default_val is not None:
|
||||
kv_cache[:, :, :] = default_val
|
||||
return kv_cache
|
||||
@@ -858,8 +894,9 @@ def make_test_metadata(
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
seq_start_loc,
|
||||
encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(seq_lens,
|
||||
context_lens,
|
||||
@@ -874,6 +911,7 @@ def make_test_metadata(
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
seq_start_loc=seq_start_loc,
|
||||
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
|
||||
max_decode_seq_len=0,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
@@ -882,6 +920,7 @@ def make_test_metadata(
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.slot_mapping),
|
||||
@@ -904,8 +943,9 @@ def make_test_metadata(
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
seq_start_loc,
|
||||
encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(seq_lens,
|
||||
context_lens,
|
||||
@@ -920,14 +960,17 @@ def make_test_metadata(
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
seq_start_loc=seq_start_loc,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=max(seq_lens),
|
||||
max_decode_query_len=1,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=kv_mmap.block_tables,
|
||||
use_cuda_graph=False,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.slot_mapping),
|
||||
@@ -936,7 +979,8 @@ def make_test_metadata(
|
||||
|
||||
|
||||
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||
output_under_test: torch.Tensor) -> None:
|
||||
output_under_test: torch.Tensor,
|
||||
backend: str) -> None:
|
||||
'''
|
||||
Assert that observed output matches the ideal output
|
||||
contained in the test parameters data structure.
|
||||
@@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||
* output_under_test: actually observed output value
|
||||
'''
|
||||
ideal_output = test_params.packed_qkvo.ideal_output
|
||||
torch.testing.assert_close(ideal_output,
|
||||
output_under_test.view_as(ideal_output))
|
||||
if backend == 'XFORMERS':
|
||||
torch.testing.assert_close(ideal_output,
|
||||
output_under_test.view_as(ideal_output))
|
||||
|
||||
elif backend == 'FLASH_ATTN':
|
||||
# For FlashAttention override the accuracy thresholds to non default
|
||||
# values since we notice a higher difference between the ideal and
|
||||
# actual output.
|
||||
torch.testing.assert_close(ideal_output,
|
||||
output_under_test.view_as(ideal_output),
|
||||
atol=0.01,
|
||||
rtol=0.016)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
|
||||
f"'FLASH_ATTN'.")
|
||||
|
||||
|
||||
# Copied/modified from torch._refs.__init__.py
|
||||
|
||||
Reference in New Issue
Block a user