[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)

Author: sroy745
Date: 2024-11-01 23:22:49 -07:00
Committed by: GitHub
Parent: d522034c85
Commit: a78dd3303e

11 changed files with 715 additions and 316 deletions


@@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
-                        make_tensor_with_pad)
+from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
+                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

 # For now, disable "test_aot_dispatch_dynamic" since there are some
 # bugs related to this test in PyTorch 2.4.
@@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
     if backend_name == STR_XFORMERS_ATTN_VAL:
         # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
         from vllm.attention.backends.xformers import XFormersBackend
         return XFormersBackend()
+    elif backend_name == STR_FLASH_ATTN_VAL:
+        from vllm.attention.backends.flash_attn import FlashAttentionBackend
+        return FlashAttentionBackend()
     raise AssertionError(
         f"Unrecognized backend_name {backend_name} for unit test")


 def _make_metadata_tensors(
-    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
-    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
-) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
-           torch.Tensor, Optional[int]]:
+    seq_lens: Optional[List[int]],
+    context_lens: Optional[List[int]],
+    encoder_seq_lens: Optional[List[int]],
+    device: Union[torch.device, str],
+) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
+           torch.Tensor, torch.Tensor, Optional[int]]:
     '''
     Build scalar & tensor values required to build attention metadata structure.
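With this change, tests select an attention backend by name. A minimal usage sketch of the dispatch above (assuming `make_backend` and the constants imported earlier in this file are in scope):

```python
from vllm.utils import STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL

# make_backend() is the test helper shown above; any other name raises.
xformers_backend = make_backend(STR_XFORMERS_ATTN_VAL)  # -> XFormersBackend()
flash_backend = make_backend(STR_FLASH_ATTN_VAL)        # -> FlashAttentionBackend()
```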
@@ -553,6 +558,8 @@ def _make_metadata_tensors(
     * max_context_len: max(context_lens)
     * max_seq_len: max(seq_lens)
     * seq_start_loc: start idx of each sequence
     * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
+    * encoder_seq_start_loc: start idx of each encoder sequence
+    * max_encoder_seq_len: max(encoder_seq_lens)
     '''

     seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
@@ -566,8 +573,26 @@ def _make_metadata_tensors(
     seq_start_loc = None
+    if seq_lens_tensor is not None:
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=seq_lens_tensor.device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+
+    encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
+                                        dtype=torch.int32,
+                                        device=encoder_seq_lens_tensor.device)
+    torch.cumsum(encoder_seq_lens_tensor,
+                 dim=0,
+                 dtype=encoder_seq_start_loc.dtype,
+                 out=encoder_seq_start_loc[1:])
+
     return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
-            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
+            seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
+            max_encoder_seq_len)
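Both start-location tensors are exclusive prefix sums over the per-sequence lengths: entry i is the offset where sequence i begins in a packed token buffer, and the final entry is the total token count. A standalone sketch of the same pattern:

```python
import torch

encoder_seq_lens = torch.tensor([2, 3, 1], dtype=torch.int32)

# One extra leading slot so that index 0 holds offset 0.
encoder_seq_start_loc = torch.zeros(encoder_seq_lens.shape[0] + 1,
                                    dtype=torch.int32)
torch.cumsum(encoder_seq_lens,
             dim=0,
             dtype=encoder_seq_start_loc.dtype,
             out=encoder_seq_start_loc[1:])

print(encoder_seq_start_loc)  # tensor([0, 2, 5, 6], dtype=torch.int32)
```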
@@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
                   head_size: int,
                   block_size: int,
                   device: Union[torch.device, str],
+                  backend: str,
                   default_val: float = 0.0) -> torch.Tensor:
     '''
     Create a fake KV cache.
@@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
     Returns:

     * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
+      for backend 'XFORMERS'
+    * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
+      for backend 'FLASH_ATTN'
     '''
-    kv_cache = torch.rand(
-        (2, num_blocks, block_size * num_heads * head_size)).to(device)
+    if backend == 'XFORMERS':
+        kv_cache = torch.rand(
+            (2, num_blocks, block_size * num_heads * head_size)).to(device)
+    elif backend == 'FLASH_ATTN':
+        kv_cache = torch.rand(
+            (2, num_blocks, block_size, num_heads, head_size)).to(device)
+    else:
+        raise ValueError(
+            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
+            f"'FLASH_ATTN'.")
     if default_val is not None:
         kv_cache[:, :, :] = default_val
     return kv_cache
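The two layouts carry the same data but factor it differently: XFORMERS flattens each block's tokens, heads, and head dimension into one axis, while FLASH_ATTN keeps block_size, num_heads, and head_size separate. A quick shape check with hypothetical sizes:

```python
import torch

# Hypothetical sizes, chosen only for illustration.
num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8

xformers_cache = torch.rand(2, num_blocks, block_size * num_heads * head_size)
flash_attn_cache = torch.rand(2, num_blocks, block_size, num_heads, head_size)

print(xformers_cache.shape)    # torch.Size([2, 4, 256])
print(flash_attn_cache.shape)  # torch.Size([2, 4, 16, 2, 8])

# Same element count per block, different factoring.
assert xformers_cache[0, 0].numel() == flash_attn_cache[0, 0].numel()
```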
@@ -858,8 +894,9 @@ def make_test_metadata(
         context_lens_tensor,
         _,
         _,
-        _,
+        seq_start_loc,
         encoder_seq_lens_tensor,
+        encoder_seq_start_loc,
         max_encoder_seq_len,
     ) = _make_metadata_tensors(seq_lens,
                                context_lens,
@@ -874,6 +911,7 @@ def make_test_metadata(
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
+            seq_start_loc=seq_start_loc,
             max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
             max_decode_seq_len=0,
             context_lens_tensor=context_lens_tensor,
@@ -882,6 +920,7 @@ def make_test_metadata(
             num_encoder_tokens=num_encoder_tokens,
             encoder_seq_lens=encoder_seq_lens,
             encoder_seq_lens_tensor=encoder_seq_lens_tensor,
+            encoder_seq_start_loc=encoder_seq_start_loc,
             max_encoder_seq_len=max_encoder_seq_len,
             cross_slot_mapping=(None if cross_kv_mmap is None else
                                 cross_kv_mmap.slot_mapping),
@@ -904,8 +943,9 @@ def make_test_metadata(
         context_lens_tensor,
         _,
         _,
-        _,
+        seq_start_loc,
         encoder_seq_lens_tensor,
+        encoder_seq_start_loc,
         max_encoder_seq_len,
     ) = _make_metadata_tensors(seq_lens,
                                context_lens,
@@ -920,14 +960,17 @@ def make_test_metadata(
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
+            seq_start_loc=seq_start_loc,
             max_prefill_seq_len=0,
             max_decode_seq_len=max(seq_lens),
+            max_decode_query_len=1,
             context_lens_tensor=context_lens_tensor,
             block_tables=kv_mmap.block_tables,
             use_cuda_graph=False,
             num_encoder_tokens=num_encoder_tokens,
             encoder_seq_lens=encoder_seq_lens,
             encoder_seq_lens_tensor=encoder_seq_lens_tensor,
+            encoder_seq_start_loc=encoder_seq_start_loc,
             max_encoder_seq_len=max_encoder_seq_len,
             cross_slot_mapping=(None if cross_kv_mmap is None else
                                 cross_kv_mmap.slot_mapping),
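The new seq_start_loc / encoder_seq_start_loc metadata fields mirror the cu_seqlens_q / cu_seqlens_k arguments that FlashAttention's variable-length kernels consume in place of padded batches. A minimal sketch of that call shape (not the code under test; assumes the flash-attn package and a CUDA device):

```python
import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_size = 2, 64
seq_lens = torch.tensor([2, 3], dtype=torch.int32, device='cuda')

# Same exclusive-prefix-sum construction as seq_start_loc above: [0, 2, 5].
cu_seqlens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32,
                         device='cuda')
torch.cumsum(seq_lens, dim=0, dtype=torch.int32, out=cu_seqlens[1:])

# Tokens from all sequences are packed along dim 0 (5 tokens total).
q = torch.randn(5, num_heads, head_size, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_varlen_func(q, k, v,
                             cu_seqlens_q=cu_seqlens,
                             cu_seqlens_k=cu_seqlens,
                             max_seqlen_q=3,
                             max_seqlen_k=3)
print(out.shape)  # torch.Size([5, 2, 64])
```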
@@ -936,7 +979,8 @@ def make_test_metadata(
 def assert_actual_matches_ideal(test_params: PhaseTestParameters,
-                                output_under_test: torch.Tensor) -> None:
+                                output_under_test: torch.Tensor,
+                                backend: str) -> None:
     '''
     Assert that observed output matches the ideal output
     contained in the test parameters data structure.
@@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
     * output_under_test: actually observed output value
+    * backend: attention backend under test ('XFORMERS' or 'FLASH_ATTN')
     '''
     ideal_output = test_params.packed_qkvo.ideal_output
-    torch.testing.assert_close(ideal_output,
-                               output_under_test.view_as(ideal_output))
+    if backend == 'XFORMERS':
+        torch.testing.assert_close(ideal_output,
+                                   output_under_test.view_as(ideal_output))
+    elif backend == 'FLASH_ATTN':
+        # For FlashAttention, override the accuracy thresholds with
+        # non-default values, since we observe a larger difference between
+        # the ideal and actual output with this backend.
+        torch.testing.assert_close(ideal_output,
+                                   output_under_test.view_as(ideal_output),
+                                   atol=0.01,
+                                   rtol=0.016)
+    else:
+        raise ValueError(
+            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
+            f"'FLASH_ATTN'.")


 # Copied/modified from torch._refs.__init__.py
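For calibration of the relaxed FLASH_ATTN thresholds above: torch.testing.assert_close allows a per-element difference of atol + rtol * |expected|, so the override tolerates noticeably more drift than the float32 defaults (atol=1e-5, rtol=1.3e-6):

```python
import torch

expected = torch.tensor([1.0, 10.0])
actual = torch.tensor([1.02, 10.15])

# Allowed per-element difference under the override:
#   0.01 + 0.016 * 1.0  = 0.026  (observed: 0.02)
#   0.01 + 0.016 * 10.0 = 0.170  (observed: 0.15)
torch.testing.assert_close(actual, expected, atol=0.01, rtol=0.016)

# The default float32 thresholds reject the same outputs.
try:
    torch.testing.assert_close(actual, expected)
except AssertionError:
    print("default thresholds reject this difference")
```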