[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -4,8 +4,6 @@ Tests:
|
||||
* E2E test of Encoder attention + Decoder self-attention +
|
||||
Encoder/decoder cross-attention (collectively
|
||||
"encoder/decoder attention")
|
||||
* Confirm enc/dec models will fail for chunked prefill
|
||||
* Confirm enc/dec models will fail for prefix caching
|
||||
|
||||
"""
|
||||
|
||||
@@ -15,19 +13,22 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import *
|
||||
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.attention.selector import (_Backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.utils import is_hip
|
||||
|
||||
# List of supported backends for encoder/decoder models
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
|
||||
|
||||
HEAD_SIZES = [64, 256]
|
||||
|
||||
NUM_HEADS = [1, 16]
|
||||
|
||||
BATCH_SIZES = [1, 16]
|
||||
BLOCK_SIZES = [16]
|
||||
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
|
||||
CUDA_DEVICE = "cuda:0"
|
||||
|
||||
MAX_DEC_SEQ_LENS = [128]
|
||||
@@ -724,57 +725,92 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
|
||||
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
|
||||
batch_size: int, block_size: int, max_dec_seq_len: int,
|
||||
max_enc_seq_len: int, monkeypatch):
|
||||
def test_encoder_only(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
attn_backend: _Backend,
|
||||
batch_size: int,
|
||||
block_size: int,
|
||||
max_dec_seq_len: int,
|
||||
max_enc_seq_len: int,
|
||||
):
|
||||
'''
|
||||
End-to-end encoder-only attention test:
|
||||
|
||||
* Construct fake test vectors for (1) encoder attention
|
||||
* Construct (1) attention metadata structure with prefill-phase
|
||||
encoder attention, and (2) an analogous attention metadata
|
||||
structure but for decode-phase
|
||||
* Test & validate encoder attention against ideal output
|
||||
|
||||
No KV cache is required for encoder-only attention.
|
||||
|
||||
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||
AMD GPUs; therefore, this test is simply skipped if is_hip().
|
||||
|
||||
This test globally forces an override of the usual backend
|
||||
auto-selection process, forcing the specific backend-under-test
|
||||
to be utilized.
|
||||
|
||||
Arguments:
|
||||
|
||||
* num_heads
|
||||
* head_size
|
||||
* attn_backend: The attention backend to employ for testing
|
||||
* batch_size
|
||||
* block_size: KV cache block size
|
||||
* max_dec_seq_len: max length of decoder input sequences
|
||||
* max_enc_seq_len: max length of encoder input sequences
|
||||
'''
|
||||
|
||||
# Force Attention wrapper backend
|
||||
override_backend_env_variable(monkeypatch, backend_name)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Shared prefill metadata structure
|
||||
# Shared prefill metadata structure
|
||||
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
None,
|
||||
decoder_test_params=None,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=None,
|
||||
device=CUDA_DEVICE)
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
None,
|
||||
decoder_test_params=None,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=None,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# PREFILL: encoder attention
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
||||
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||
@@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
|
||||
def test_e2e_enc_dec_attn(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
backend_name: str,
|
||||
attn_backend: _Backend,
|
||||
batch_size: int,
|
||||
block_size: int,
|
||||
max_dec_seq_len: int,
|
||||
max_enc_seq_len: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
'''
|
||||
End-to-end encoder/decoder test:
|
||||
@@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn(
|
||||
cross-attention K/Vs are allowed to differ in seq len, as is often the case
|
||||
for cross-attention.
|
||||
|
||||
This test utilizes PyTest monkey patching to force the attention backend
|
||||
via an environment variable.
|
||||
This test globally forces an override of the usual backend
|
||||
auto-selection process, forcing the specific backend-under-test
|
||||
to be utilized.
|
||||
|
||||
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||
AMD GPUs; therefore, this test is simply skipped if is_hip().
|
||||
@@ -830,124 +866,136 @@ def test_e2e_enc_dec_attn(
|
||||
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
|
||||
and a single one shared by all decode-phase attention operations
|
||||
(decoder & enc/dec cross.) This is intended to reflect the behavior
|
||||
of ModelRunner, which constructs a single attention metadata structure for
|
||||
each prefill or decode run. A realistic scenario would rely on the
|
||||
attention backend to utilize the appropriate attention metadata fields
|
||||
according to the value of attn_metadata.attention_type. Thus, this test is
|
||||
organized so as to confirm that the backend-under-test can handle a
|
||||
shared prefill attention metadata structure & a shared decode attention
|
||||
metadata structure.
|
||||
of EncoderDecoderModelRunner, which constructs a single attention metadata
|
||||
structure for each prefill or decode run. A realistic scenario would rely
|
||||
on the attention backend to utilize the appropriate attention metadata
|
||||
fields according to the value of attn_metadata.attention_type. Thus,
|
||||
this test is organized so as to confirm that the backend-under-test can
|
||||
handle a shared prefill attention metadata structure & a shared decode
|
||||
attention metadata structure.
|
||||
|
||||
Arguments:
|
||||
|
||||
* num_heads
|
||||
* head_size
|
||||
* attn_backend: The attention backend to employ for testing
|
||||
* batch_size
|
||||
* block_size: KV cache block size
|
||||
* max_dec_seq_len: max length of decoder input sequences
|
||||
* max_enc_seq_len: max length of encoder input sequences
|
||||
'''
|
||||
|
||||
# Force Attention wrapper backend
|
||||
override_backend_env_variable(monkeypatch, backend_name)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Construct Decoder self-attention prefill-phase & decode-phase
|
||||
# test params, including query/key/value tensors, decoder self-attention
|
||||
# memory-mapping. cross_block_base_addr is the uppermost address in the
|
||||
# decoder self-attention block-table, i.e. a base address which the
|
||||
# encoder/decoder cross-attention block-table may build downward toward.
|
||||
# Construct Decoder self-attention prefill-phase & decode-phase
|
||||
# test params, including query/key/value tensors, decoder self-attention
|
||||
# memory-mapping. cross_block_base_addr is the uppermost address in the
|
||||
# decoder self-attention block-table, i.e. a base address which the
|
||||
# encoder/decoder cross-attention block-table may build downward toward.
|
||||
|
||||
(
|
||||
dec_qkv,
|
||||
prephase_dec_test_params,
|
||||
decphase_dec_test_params,
|
||||
cross_block_base_addr,
|
||||
) = _decoder_attn_setup(test_pt, test_rsrcs)
|
||||
(
|
||||
dec_qkv,
|
||||
prephase_dec_test_params,
|
||||
decphase_dec_test_params,
|
||||
cross_block_base_addr,
|
||||
) = _decoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Construct encoder/decoder cross-attention prefill-phase & decode-phase
|
||||
# test params, including key/value tensors, cross-attention memory-mapping
|
||||
# Construct encoder/decoder cross-attention prefill-phase
|
||||
# & decode-phase test params, including key/value tensors,
|
||||
# cross-attention memory-mapping
|
||||
|
||||
(
|
||||
prephase_cross_test_params,
|
||||
decphase_cross_test_params,
|
||||
) = _enc_dec_cross_attn_setup_reuses_query(
|
||||
dec_qkv,
|
||||
enc_test_params,
|
||||
prephase_dec_test_params,
|
||||
test_pt,
|
||||
test_rsrcs,
|
||||
block_base_addr=cross_block_base_addr)
|
||||
(
|
||||
prephase_cross_test_params,
|
||||
decphase_cross_test_params,
|
||||
) = _enc_dec_cross_attn_setup_reuses_query(
|
||||
dec_qkv,
|
||||
enc_test_params,
|
||||
prephase_dec_test_params,
|
||||
test_pt,
|
||||
test_rsrcs,
|
||||
block_base_addr=cross_block_base_addr)
|
||||
|
||||
# Shared prefill metadata structure
|
||||
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
|
||||
decoder_test_params=prephase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=prephase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
# Shared prefill metadata structure
|
||||
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
|
||||
decoder_test_params=prephase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=prephase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# PREFILL: encoder attention
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata)
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata)
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
# PREFILL: decoder self-attention test
|
||||
# PREFILL: decoder self-attention test
|
||||
|
||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
||||
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
prephase_dec_pckd_act_out)
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
prephase_dec_pckd_act_out)
|
||||
|
||||
# PREFILL: encoder/decoder cross-attention test
|
||||
# PREFILL: encoder/decoder cross-attention test
|
||||
|
||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
||||
prephase_attn_metadata)
|
||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
||||
prephase_attn_metadata)
|
||||
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
prephase_cross_pckd_act_out)
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
prephase_cross_pckd_act_out)
|
||||
|
||||
# DECODE: build decode-phase attention metadata
|
||||
# DECODE: build decode-phase attention metadata
|
||||
|
||||
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
False,
|
||||
dec_qkv.q_seq_lens,
|
||||
decoder_test_params=decphase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=decphase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
False,
|
||||
dec_qkv.q_seq_lens,
|
||||
decoder_test_params=decphase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=decphase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# DECODE: decoder self-attention test
|
||||
# DECODE: decoder self-attention test
|
||||
|
||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
||||
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
decphase_dec_pckd_act_out)
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
decphase_dec_pckd_act_out)
|
||||
|
||||
# DECODE: encoder/decoder cross-attention test
|
||||
# DECODE: encoder/decoder cross-attention test
|
||||
|
||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
||||
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
decphase_cross_pckd_act_out)
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
decphase_cross_pckd_act_out)
|
||||
|
||||
Reference in New Issue
Block a user