[torch.compile] support all attention backends (#10558)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.plugins import set_current_vllm_config
|
||||
|
||||
# List of support backends for encoder/decoder models
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
|
||||
@@ -594,6 +596,7 @@ def _run_encoder_attention_test(
|
||||
encoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
vllm_config: VllmConfig,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder attention.
|
||||
@@ -623,7 +626,7 @@ def _run_encoder_attention_test(
|
||||
attn_type = AttentionType.ENCODER
|
||||
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
with set_forward_context(attn_metadata):
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However
|
||||
# the attention backend expect the shape to be
|
||||
@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
|
||||
decoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
vllm_config: VllmConfig,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run decoder self-attention test.
|
||||
@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
with set_forward_context(attn_metadata):
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However
|
||||
# the attention backend expect the shape to be
|
||||
@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
cross_test_params: Optional[PhaseTestParameters],
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
vllm_config: VllmConfig,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder/decoder cross-attention test.
|
||||
@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
||||
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
||||
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
||||
with set_forward_context(attn_metadata):
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However
|
||||
# the attention backend expect the shape to be
|
||||
@@ -839,7 +844,9 @@ def test_encoder_only(
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
@@ -863,7 +870,8 @@ def test_encoder_only(
|
||||
test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt))
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config))
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||
@@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||
@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
|
||||
test_rsrcs,
|
||||
prephase_dec_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
|
||||
prephase_dec_test_params,
|
||||
prephase_cross_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
|
||||
test_rsrcs,
|
||||
decphase_dec_test_params,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
|
||||
decphase_dec_test_params,
|
||||
None,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
|
||||
Reference in New Issue
Block a user