[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)

This commit is contained in:
sroy745
2024-11-01 23:22:49 -07:00
committed by GitHub
parent d522034c85
commit a78dd3303e
11 changed files with 715 additions and 316 deletions

View File

@@ -16,13 +16,13 @@ from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend,
from vllm.attention.selector import (_Backend, get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
# List of supported backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
@@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE)
device=CUDA_DEVICE,
backend=test_pt.backend_name)
return TestResources(scale, attn_backend, attn, kv_cache)
@@ -592,6 +593,7 @@ def _run_encoder_attention_test(
attn: Attention,
encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor:
'''
Run encoder attention.
@@ -610,6 +612,8 @@ def _run_encoder_attention_test(
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed {query,key,value} and
@@ -619,20 +623,31 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
with set_forward_context(attn_metadata):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
def _run_decoder_self_attention_test(
test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor:
'''
Run decoder self-attention test.
@@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
@@ -660,12 +677,22 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
with set_forward_context(attn_metadata):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
def _run_encoder_decoder_cross_attention_test(
@@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
decoder_test_params: PhaseTestParameters,
cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor:
'''
Run encoder/decoder cross-attention test.
@@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
(number_of_tokens x num_heads x head_size)
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
@@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
with set_forward_context(attn_metadata):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
@pytest.fixture(autouse=True)
def set_reset_environment(attn_backend):
    '''
    Prepare the torch default dtype and the attention-backend selector
    for a single test, then restore the default dtype afterwards.

    Arguments:

    * attn_backend: the backend under test; only its `name` attribute
                    is read here

    Yields:

    * Nothing; the test body runs while the fixture is suspended at
      the yield
    '''
    # Remember the dtype in effect before the test so it can be put back.
    default_dtype = torch.get_default_dtype()
    try:
        # Set the default torch datatype to bfloat16 to enable
        # testing of the Flash Attention backend.
        if attn_backend.name == 'FLASH_ATTN':
            torch.set_default_dtype(torch.bfloat16)
        # Also clear the cached value of the backend, regardless of
        # which backend this test uses.
        get_attn_backend.cache_clear()
        yield
    finally:
        # Reset the torch datatype to what it was before the test
        # so as not to impact the remaining tests. Doing this in a
        # finally clause also covers a failure during the setup steps
        # above, after the default dtype has already been changed.
        torch.set_default_dtype(default_dtype)
@pytest.mark.skipif(current_platform.is_rocm(),
@@ -773,10 +828,8 @@ def test_encoder_only(
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
# Force Attention wrapper backend
with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
@@ -807,10 +860,14 @@ def test_encoder_only(
# PREFILL: encoder attention
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt))
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
attn_backend.name)
@pytest.mark.skipif(current_platform.is_rocm(),
@@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
# Force Attention wrapper backend
with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
@@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata)
prephase_attn_metadata,
test_pt=test_pt)
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
attn_backend.name)
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt)
# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params,
prephase_dec_pckd_act_out)
prephase_dec_pckd_act_out,
attn_backend.name)
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
prephase_attn_metadata)
test_rsrcs,
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt)
# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params,
prephase_cross_pckd_act_out)
prephase_cross_pckd_act_out,
attn_backend.name)
# DECODE: build decode-phase attention metadata
@@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt)
# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params,
decphase_dec_pckd_act_out)
decphase_dec_pckd_act_out,
attn_backend.name)
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
test_rsrcs,
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt)
# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params,
decphase_cross_pckd_act_out)
decphase_cross_pckd_act_out,
attn_backend.name)