[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
@@ -16,13 +16,13 @@ from tests.kernels.utils import *
 from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
                             AttentionType)
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
-from vllm.attention.selector import (_Backend,
+from vllm.attention.selector import (_Backend, get_attn_backend,
                                      global_force_attn_backend_context_manager)
+from vllm.forward_context import set_forward_context
 from vllm.platforms import current_platform

 # List of support backends for encoder/decoder models
-LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

 HEAD_SIZES = [64, 256]

 NUM_HEADS = [1, 16]
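The hunk above widens the list of supported encoder/decoder backends to include FLASH_ATTN. A minimal sketch of how such a list can drive one test run per backend is shown below; the parametrized fixture wiring is an assumption (only the fixture name `attn_backend` and the backend list appear in this diff):

import pytest

from vllm.attention.selector import _Backend

LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]


@pytest.fixture(params=LIST_ENC_DEC_SUPPORTED_BACKENDS)
def attn_backend(request):
    # One test invocation per supported encoder/decoder backend.
    return request.param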
@@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
                              test_pt.num_heads,
                              test_pt.head_size,
                              test_pt.block_size,
-                             device=CUDA_DEVICE)
+                             device=CUDA_DEVICE,
+                             backend=test_pt.backend_name)
     return TestResources(scale, attn_backend, attn, kv_cache)

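Passing `backend=test_pt.backend_name` suggests the KV-cache helper now lays the cache out per backend. A hypothetical sketch of that idea follows; the helper name, shapes, and layout choices are illustrative only, not vLLM internals:

import torch


def make_kv_cache(num_blocks: int, num_heads: int, head_size: int,
                  block_size: int, device: str, backend: str) -> torch.Tensor:
    if backend == "FLASH_ATTN":
        # Leading dimension 2 holds key and value blocks, kept unpacked.
        shape = (2, num_blocks, block_size, num_heads, head_size)
    else:
        # Packed per-block layout for the other backends in this sketch.
        shape = (2, num_blocks, block_size * num_heads * head_size)
    return torch.zeros(shape, device=device)


# Example call, on CPU just to show the signature.
cache = make_kv_cache(num_blocks=4, num_heads=2, head_size=8,
                      block_size=16, device="cpu", backend="FLASH_ATTN")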
@@ -592,6 +593,7 @@ def _run_encoder_attention_test(
         attn: Attention,
         encoder_test_params: PhaseTestParameters,
         attn_metadata: AttentionMetadata,
+        test_pt: TestPoint,
 ) -> torch.Tensor:
     '''
     Run encoder attention.
@@ -610,6 +612,8 @@ def _run_encoder_attention_test(
                            (number_of_tokens x num_heads x head_size)
                            query/key/value fields
     * attn_metadata: attention metadata for encoder/decoder-self attention
+    * test_pt: The TestPoint object containing test details like number of
+               model heads, head size, name of the backend being used etc.

     Returns:
     * Attention.forward() applied to packed {query,key,value} and
@@ -619,20 +623,31 @@ def _run_encoder_attention_test(
     attn_type = AttentionType.ENCODER
     packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
-    return attn.forward(packed_qkv.query,
-                        packed_qkv.key,
-                        packed_qkv.value,
-                        torch.tensor([],
-                                     dtype=torch.float32,
-                                     device=packed_qkv.query.device),
-                        attn_metadata,
-                        attn_type=attn_type)
+    with set_forward_context(attn_metadata):
+        # In the test setup the shape of the query is
+        # [batch_size, seq_len, num_heads, head_size]. However
+        # the attention backend expect the shape to be
+        # [num_tokens, hidden_size]. Hence reshape the query before
+        # invoking the forward method.
+        # TODO - Update the way we construct the query so that it
+        # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
+        reshaped_query = packed_qkv.query.view(
+            -1, test_pt.num_heads * test_pt.head_size)
+        return attn.forward(reshaped_query,
+                            packed_qkv.key,
+                            packed_qkv.value,
+                            torch.tensor([],
+                                         dtype=torch.float32,
+                                         device=packed_qkv.query.device),
+                            attn_metadata,
+                            attn_type=attn_type)


 def _run_decoder_self_attention_test(
         test_rsrcs: TestResources,
         decoder_test_params: PhaseTestParameters,
         attn_metadata: AttentionMetadata,
+        test_pt: TestPoint,
 ) -> torch.Tensor:
     '''
     Run decoder self-attention test.
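The reshape introduced above collapses the test's [batch_size, seq_len, num_heads, head_size] query into the [num_tokens, hidden_size] layout the backend expects. A standalone sketch of the same operation, with illustrative shapes only:

import torch

# Illustrative shapes, not taken from the test parameters.
batch_size, seq_len, num_heads, head_size = 2, 8, 16, 64
query = torch.randn(batch_size, seq_len, num_heads, head_size)

# Collapse to [num_tokens, hidden_size], where num_tokens = batch_size * seq_len.
reshaped_query = query.view(-1, num_heads * head_size)
assert reshaped_query.shape == (batch_size * seq_len, num_heads * head_size)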
@@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
                            query/key/value fields
     * attn_metadata: attention metadata for decoder-self attention
                      (contains KV cache memory-mapping)
+    * test_pt: The TestPoint object containing test details like number of
+               model heads, head size, name of the backend being used etc.

     Returns:
     * Attention.forward() applied to packed_{query,key,value}, kv_cache
@@ -660,12 +677,22 @@ def _run_decoder_self_attention_test(
     kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
-    return attn.forward(packed_qkv.query,
-                        packed_qkv.key,
-                        packed_qkv.value,
-                        kv_cache,
-                        attn_metadata,
-                        attn_type=attn_type)
+    with set_forward_context(attn_metadata):
+        # In the test setup the shape of the query is
+        # [batch_size, seq_len, num_heads, head_size]. However
+        # the attention backend expect the shape to be
+        # [num_tokens, hidden_size]. Hence reshape the query before
+        # invoking the forward method.
+        # TODO - Update the way we construct the query so that it
+        # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
+        reshaped_query = packed_qkv.query.view(
+            -1, test_pt.num_heads * test_pt.head_size)
+        return attn.forward(reshaped_query,
+                            packed_qkv.key,
+                            packed_qkv.value,
+                            kv_cache,
+                            attn_metadata,
+                            attn_type=attn_type)


 def _run_encoder_decoder_cross_attention_test(
@@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
         decoder_test_params: PhaseTestParameters,
         cross_test_params: Optional[PhaseTestParameters],
         attn_metadata: AttentionMetadata,
+        test_pt: TestPoint,
 ) -> torch.Tensor:
     '''
     Run encoder/decoder cross-attention test.
@@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
                            (number_of_tokens x num_heads x head_size)
                            key/value fields
     * attn_metadata: attention metadata for encoder/decoder-self attention
+    * test_pt: The TestPoint object containing test details like number of
+               model heads, head size, name of the backend being used etc.

     Returns:
     * Attention.forward() applied to packed_{query,key,value}, kv_cache
@@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
         cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
         key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
         value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
-    return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
-                        key,
-                        value,
-                        kv_cache,
-                        attn_metadata,
-                        attn_type=attn_type)
+    with set_forward_context(attn_metadata):
+        # In the test setup the shape of the query is
+        # [batch_size, seq_len, num_heads, head_size]. However
+        # the attention backend expect the shape to be
+        # [num_tokens, hidden_size]. Hence reshape the query before
+        # invoking the forward method.
+        # TODO - Update the way we construct the query so that it
+        # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
+        reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
+            -1, test_pt.num_heads * test_pt.head_size)
+        return attn.forward(reshaped_query,
+                            key,
+                            value,
+                            kv_cache,
+                            attn_metadata,
+                            attn_type=attn_type)


+@pytest.fixture(autouse=True)
+def set_reset_environment(attn_backend):
+    # Set the default torch datatype to bfloat16 to enable
+    # testing of the Flash Attention backend. Also clear the
+    # cached value of the backend.
+    default_dtype = torch.get_default_dtype()
+    if attn_backend.name == 'FLASH_ATTN':
+        torch.set_default_dtype(torch.bfloat16)
+    get_attn_backend.cache_clear()
+    yield
+    # Reset the torch datatype to what it was before the test
+    # so as not to impact the remaining tests.
+    torch.set_default_dtype(default_dtype)


 @pytest.mark.skipif(current_platform.is_rocm(),
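The `get_attn_backend.cache_clear()` call above suggests the backend selector is memoized (e.g. with functools.lru_cache), so the cached choice has to be dropped when the forced backend changes between parametrized runs, and the default dtype has to be restored after each test. A generic sketch of the same save/patch/restore pattern, with illustrative names standing in for the vLLM functions:

import functools

import pytest
import torch


@functools.lru_cache
def pick_backend(name: str) -> str:
    # Stand-in for a memoized selector such as get_attn_backend.
    return name.upper()


@pytest.fixture(autouse=True)
def reset_global_state():
    saved_dtype = torch.get_default_dtype()
    pick_backend.cache_clear()  # drop memoized results between test runs
    yield                       # the test body executes here
    torch.set_default_dtype(saved_dtype)  # restore the global default dtype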
@@ -773,10 +828,8 @@ def test_encoder_only(
     * max_dec_seq_len: max length of decoder input sequences
     * max_enc_seq_len: max length of encoder input sequences
     '''
-
     # Force Attention wrapper backend
     with global_force_attn_backend_context_manager(attn_backend):
-
         # Note: KV cache size of 4096 is arbitrary & chosen intentionally
         # to be more than necessary, since exceeding the kv cache size
         # is not part of this test
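A minimal usage sketch of the context manager used above to pin the attention backend for the duration of a test block (the backend value chosen here is arbitrary):

from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)

with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    ...  # attention layers built inside this block resolve to FLASH_ATTN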
@@ -807,10 +860,14 @@ def test_encoder_only(
         # PREFILL: encoder attention

         enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
-            test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
+            test_rsrcs.attn,
+            enc_test_params,
+            prephase_attn_metadata,
+            test_pt=test_pt))

         # - Is encoder attention result correct?
-        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
+        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
+                                    attn_backend.name)


 @pytest.mark.skipif(current_platform.is_rocm(),
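Passing `attn_backend.name` into `assert_actual_matches_ideal` makes the comparison backend-aware; a plausible reason is per-backend numeric tolerances, since bfloat16 flash-attn output is less precise than the fp32 reference path. The helper below is hypothetical, not vLLM's actual implementation:

import torch


def assert_close_for_backend(actual: torch.Tensor,
                             ideal: torch.Tensor,
                             backend_name: str) -> None:
    # Illustrative tolerances only: looser bounds for bfloat16 flash-attn runs.
    rtol, atol = (1e-2, 1e-2) if backend_name == "FLASH_ATTN" else (1e-4, 1e-4)
    torch.testing.assert_close(actual, ideal, rtol=rtol, atol=atol)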
@@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
     * max_dec_seq_len: max length of decoder input sequences
     * max_enc_seq_len: max length of encoder input sequences
     '''
-
     # Force Attention wrapper backend
     with global_force_attn_backend_context_manager(attn_backend):
-
         # Note: KV cache size of 4096 is arbitrary & chosen intentionally
         # to be more than necessary, since exceeding the kv cache size
         # is not part of this test
@@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(

         enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
                                                        enc_test_params,
-                                                       prephase_attn_metadata)
+                                                       prephase_attn_metadata,
+                                                       test_pt=test_pt)

         # - Is encoder attention result correct?
-        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
+        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
+                                    attn_backend.name)

         # PREFILL: decoder self-attention test

         prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
-            test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
+            test_rsrcs,
+            prephase_dec_test_params,
+            prephase_attn_metadata,
+            test_pt=test_pt)

         # - Is prefill decoder self-attention correct?
         assert_actual_matches_ideal(prephase_dec_test_params,
-                                    prephase_dec_pckd_act_out)
+                                    prephase_dec_pckd_act_out,
+                                    attn_backend.name)

         # PREFILL: encoder/decoder cross-attention test

         prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-            test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
-            prephase_attn_metadata)
+            test_rsrcs,
+            prephase_dec_test_params,
+            prephase_cross_test_params,
+            prephase_attn_metadata,
+            test_pt=test_pt)

         # - Is prefill encoder/decoder cross-attention correct?
         assert_actual_matches_ideal(prephase_cross_test_params,
-                                    prephase_cross_pckd_act_out)
+                                    prephase_cross_pckd_act_out,
+                                    attn_backend.name)

         # DECODE: build decode-phase attention metadata
@@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
         # DECODE: decoder self-attention test

         decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
-            test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
+            test_rsrcs,
+            decphase_dec_test_params,
+            decphase_attn_metadata,
+            test_pt=test_pt)

         # - Is decode-phase decoder self-attention correct?
         assert_actual_matches_ideal(decphase_dec_test_params,
-                                    decphase_dec_pckd_act_out)
+                                    decphase_dec_pckd_act_out,
+                                    attn_backend.name)

         # DECODE: encoder/decoder cross-attention test

         decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-            test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
+            test_rsrcs,
+            decphase_dec_test_params,
+            None,
+            decphase_attn_metadata,
+            test_pt=test_pt)

         # - Is decode-phase encoder/decoder cross-attention correct?
         assert_actual_matches_ideal(decphase_cross_test_params,
-                                    decphase_cross_pckd_act_out)
+                                    decphase_cross_pckd_act_out,
+                                    attn_backend.name)