[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)

Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Author: afeldman-nm
Date: 2024-08-06 16:51:47 -04:00
Committed-by: GitHub
Parent: 660470e5a3
Commit: fd95e026e0
33 changed files with 3957 additions and 333 deletions


@@ -4,8 +4,6 @@ Tests:
 * E2E test of Encoder attention + Decoder self-attention +
   Encoder/decoder cross-attention (collectively
   "encoder/decoder attention")
-* Confirm enc/dec models will fail for chunked prefill
-* Confirm enc/dec models will fail for prefix caching
 """
@@ -15,19 +13,22 @@ import pytest
 import torch
 from tests.kernels.utils import *
+from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
-from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.backends.abstract import AttentionBackend, AttentionType
+from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
+                            AttentionType)
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.utils import is_hip
+# List of supported backends for encoder/decoder models
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
 HEAD_SIZES = [64, 256]
 NUM_HEADS = [1, 16]
 BATCH_SIZES = [1, 16]
 BLOCK_SIZES = [16]
-BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
 CUDA_DEVICE = "cuda:0"
 MAX_DEC_SEQ_LENS = [128]
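
A note on the hunk above: the string-valued BACKEND_NAMES list is replaced by LIST_ENC_DEC_SUPPORTED_BACKENDS, a list of _Backend enum members, so each parametrized test case receives a typed backend identifier rather than a raw string. A minimal self-contained sketch of the same enum-based pytest parametrization pattern (the _Backend subset and test body here are illustrative stand-ins, not the actual vLLM definitions):

import enum

import pytest

class _Backend(enum.Enum):
    # Illustrative subset; the real enum lives in vllm.attention.selector.
    XFORMERS = enum.auto()
    FLASH_ATTN = enum.auto()

# Only backends known to support encoder/decoder models are parametrized.
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]

@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_backend_is_enum_member(attn_backend: _Backend) -> None:
    # Each case receives an enum member, so .name yields the string
    # form where one is still required (as in the TestPoint constructed
    # later in this diff).
    assert attn_backend.name == "XFORMERS"
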
@@ -724,57 +725,92 @@ def _run_encoder_decoder_cross_attention_test(
 @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
 @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
-def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
-                      batch_size: int, block_size: int, max_dec_seq_len: int,
-                      max_enc_seq_len: int, monkeypatch):
+def test_encoder_only(
+    num_heads: int,
+    head_size: int,
+    attn_backend: _Backend,
+    batch_size: int,
+    block_size: int,
+    max_dec_seq_len: int,
+    max_enc_seq_len: int,
+):
     '''
     End-to-end encoder-only attention test:
     * Construct fake test vectors for (1) encoder attention
     * Construct (1) attention metadata structure with prefill-phase
       encoder attention, and (2) an analogous attention metadata
       structure but for decode-phase
     * Test & validate encoder attention against ideal output
     No KV cache is required for encoder-only attention.
     Note on ROCm/HIP: currently encoder/decoder models are not supported on
     AMD GPUs, therefore this test is simply skipped if is_hip().
+    This test globally forces an override of the usual backend
+    auto-selection process, forcing the specific backend-under-test
+    to be utilized.
     Arguments:
     * num_heads
     * head_size
+    * attn_backend: The attention backend to employ for testing
     * batch_size
     * block_size: KV cache block size
     * max_dec_seq_len: max length of decoder input sequences
     * max_enc_seq_len: max length of encoder input sequences
     '''
-    # Force Attention wrapper backend
-    override_backend_env_variable(monkeypatch, backend_name)
+    with global_force_attn_backend_context_manager(attn_backend):
-    # Note: KV cache size of 4096 is arbitrary & chosen intentionally
-    # to be more than necessary, since exceeding the kv cache size
-    # is not part of this test
-    test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
-                        block_size, max_dec_seq_len, max_enc_seq_len, 4096)
+        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
+        # to be more than necessary, since exceeding the kv cache size
+        # is not part of this test
+        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                            batch_size, block_size, max_dec_seq_len,
+                            max_enc_seq_len, 4096)
-    # Attention scale factor, attention backend instance, attention wrapper
-    # instance, KV cache init
-    test_rsrcs = _make_test_resources(test_pt)
+        # Attention scale factor, attention backend instance, attention wrapper
+        # instance, KV cache init
+        test_rsrcs = _make_test_resources(test_pt)
-    # Construct encoder attention test params (only used
-    # during prefill)
+        # Construct encoder attention test params (only used
+        # during prefill)
-    enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
+        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
-    # Shared prefill metadata structure
+        # Shared prefill metadata structure
-    prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-        test_rsrcs.attn_backend,
-        True,
-        None,
-        decoder_test_params=None,
-        encoder_test_params=enc_test_params,
-        cross_test_params=None,
-        device=CUDA_DEVICE)
+        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
+            test_rsrcs.attn_backend,
+            True,
+            None,
+            decoder_test_params=None,
+            encoder_test_params=enc_test_params,
+            cross_test_params=None,
+            device=CUDA_DEVICE)
-    # PREFILL: encoder attention
+        # PREFILL: encoder attention
-    enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
-        test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
+        enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
+            test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
-    # - Is encoder attention result correct?
-    assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
+        # - Is encoder attention result correct?
+        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
 @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
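
The hunk above (and the test body it modifies) swaps the monkeypatched environment variable for global_force_attn_backend_context_manager, which forces backend selection for the duration of a with-block and restores the prior state on exit, even if the test raises. A minimal sketch of that save/override/restore pattern (the force_backend name and module-level variable are illustrative assumptions, not the actual vLLM implementation):

from contextlib import contextmanager
from typing import Iterator, Optional

# Hypothetical module-level override consulted by a backend selector.
_forced_backend: Optional[str] = None

@contextmanager
def force_backend(backend_name: str) -> Iterator[None]:
    # Save the prior override, force the requested backend, and restore
    # the saved value on exit so one test cannot leak its forced backend
    # into the rest of the session.
    global _forced_backend
    saved = _forced_backend
    _forced_backend = backend_name
    try:
        yield
    finally:
        _forced_backend = saved

# Usage mirroring the rewritten test bodies:
with force_backend("XFORMERS"):
    assert _forced_backend == "XFORMERS"
assert _forced_backend is None
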
@@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
 def test_e2e_enc_dec_attn(
     num_heads: int,
     head_size: int,
-    backend_name: str,
+    attn_backend: _Backend,
     batch_size: int,
     block_size: int,
     max_dec_seq_len: int,
     max_enc_seq_len: int,
-    monkeypatch,
 ) -> None:
     '''
     End-to-end encoder/decoder test:
@@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn(
     cross-attention K/Vs are allowed to differ in seq len, as is often the case
     for cross-attention.
-    This test utilizes PyTest monkey patching to force the attention backend
-    via an environment variable.
+    This test globally forces an override of the usual backend
+    auto-selection process, forcing the specific backend-under-test
+    to be utilized.
     Note on ROCm/HIP: currently encoder/decoder models are not supported on
     AMD GPUs, therefore this test is simply skipped if is_hip().
@@ -830,124 +866,136 @@
     all prefill-phase attention operations (encoder, decoder, enc/dec cross),
     and a single one shared by all decode-phase attention operations
     (decoder & enc/dec cross.) This is intended to reflect the behavior
-    of ModelRunner, which constructs a single attention metadata structure for
-    each prefill or decode run. A realistic scenario would rely on the
-    attention backend to utilize the appropriate attention metadata fields
-    according to the value of attn_metadata.attention_type. Thus, this test is
-    organized so as to confirm that the backend-under-test can handle a
-    shared prefill attention metadata structure & a shared decode attention
-    metadata structure.
+    of EncoderDecoderModelRunner, which constructs a single attention metadata
+    structure for each prefill or decode run. A realistic scenario would rely
+    on the attention backend to utilize the appropriate attention metadata
+    fields according to the value of attn_metadata.attention_type. Thus,
+    this test is organized so as to confirm that the backend-under-test can
+    handle a shared prefill attention metadata structure & a shared decode
+    attention metadata structure.
     Arguments:
     * num_heads
     * head_size
+    * attn_backend: The attention backend to employ for testing
     * batch_size
     * block_size: KV cache block size
     * max_dec_seq_len: max length of decoder input sequences
     * max_enc_seq_len: max length of encoder input sequences
     '''
-    # Force Attention wrapper backend
-    override_backend_env_variable(monkeypatch, backend_name)
+    with global_force_attn_backend_context_manager(attn_backend):
-    # Note: KV cache size of 4096 is arbitrary & chosen intentionally
-    # to be more than necessary, since exceeding the kv cache size
-    # is not part of this test
-    test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
-                        block_size, max_dec_seq_len, max_enc_seq_len, 4096)
+        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
+        # to be more than necessary, since exceeding the kv cache size
+        # is not part of this test
+        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                            batch_size, block_size, max_dec_seq_len,
+                            max_enc_seq_len, 4096)
-    # Attention scale factor, attention backend instance, attention wrapper
-    # instance, KV cache init
-    test_rsrcs = _make_test_resources(test_pt)
+        # Attention scale factor, attention backend instance, attention wrapper
+        # instance, KV cache init
+        test_rsrcs = _make_test_resources(test_pt)
-    # Construct encoder attention test params (only used
-    # during prefill)
+        # Construct encoder attention test params (only used
+        # during prefill)
-    enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
+        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
-    # Construct Decoder self-attention prefill-phase & decode-phase
-    # test params, including query/key/value tensors, decoder self-attention
-    # memory-mapping. cross_block_base_addr is the uppermost address in the
-    # decoder self-attention block-table, i.e. a base address which the
-    # encoder/decoder cross-attention block-table may build downward toward.
+        # Construct Decoder self-attention prefill-phase & decode-phase
+        # test params, including query/key/value tensors, decoder self-attention
+        # memory-mapping. cross_block_base_addr is the uppermost address in the
+        # decoder self-attention block-table, i.e. a base address which the
+        # encoder/decoder cross-attention block-table may build downward toward.
-    (
-        dec_qkv,
-        prephase_dec_test_params,
-        decphase_dec_test_params,
-        cross_block_base_addr,
-    ) = _decoder_attn_setup(test_pt, test_rsrcs)
+        (
+            dec_qkv,
+            prephase_dec_test_params,
+            decphase_dec_test_params,
+            cross_block_base_addr,
+        ) = _decoder_attn_setup(test_pt, test_rsrcs)
-    # Construct encoder/decoder cross-attention prefill-phase & decode-phase
-    # test params, including key/value tensors, cross-attention memory-mapping
+        # Construct encoder/decoder cross-attention prefill-phase
+        # & decode-phase test params, including key/value tensors,
+        # cross-attention memory-mapping
-    (
-        prephase_cross_test_params,
-        decphase_cross_test_params,
-    ) = _enc_dec_cross_attn_setup_reuses_query(
-        dec_qkv,
-        enc_test_params,
-        prephase_dec_test_params,
-        test_pt,
-        test_rsrcs,
-        block_base_addr=cross_block_base_addr)
+        (
+            prephase_cross_test_params,
+            decphase_cross_test_params,
+        ) = _enc_dec_cross_attn_setup_reuses_query(
+            dec_qkv,
+            enc_test_params,
+            prephase_dec_test_params,
+            test_pt,
+            test_rsrcs,
+            block_base_addr=cross_block_base_addr)
-    # Shared prefill metadata structure
-    assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
-    prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-        test_rsrcs.attn_backend,
-        True,
-        prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
-        decoder_test_params=prephase_dec_test_params,
-        encoder_test_params=enc_test_params,
-        cross_test_params=prephase_cross_test_params,
-        device=CUDA_DEVICE)
+        # Shared prefill metadata structure
+        assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
+        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
+            test_rsrcs.attn_backend,
+            True,
+            prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
+            decoder_test_params=prephase_dec_test_params,
+            encoder_test_params=enc_test_params,
+            cross_test_params=prephase_cross_test_params,
+            device=CUDA_DEVICE)
-    # PREFILL: encoder attention
+        # PREFILL: encoder attention
-    enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
-                                                   enc_test_params,
-                                                   prephase_attn_metadata)
+        enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
+                                                       enc_test_params,
+                                                       prephase_attn_metadata)
-    # - Is encoder attention result correct?
-    assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
+        # - Is encoder attention result correct?
+        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
-    # PREFILL: decoder self-attention test
+        # PREFILL: decoder self-attention test
-    prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
-        test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
+        prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
+            test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
-    # - Is prefill decoder self-attention correct?
-    assert_actual_matches_ideal(prephase_dec_test_params,
-                                prephase_dec_pckd_act_out)
+        # - Is prefill decoder self-attention correct?
+        assert_actual_matches_ideal(prephase_dec_test_params,
+                                    prephase_dec_pckd_act_out)
-    # PREFILL: encoder/decoder cross-attention test
+        # PREFILL: encoder/decoder cross-attention test
-    prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-        test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
-        prephase_attn_metadata)
+        prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
+            test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
+            prephase_attn_metadata)
-    # - Is prefill encoder/decoder cross-attention correct?
-    assert_actual_matches_ideal(prephase_cross_test_params,
-                                prephase_cross_pckd_act_out)
+        # - Is prefill encoder/decoder cross-attention correct?
+        assert_actual_matches_ideal(prephase_cross_test_params,
+                                    prephase_cross_pckd_act_out)
-    # DECODE: build decode-phase attention metadata
+        # DECODE: build decode-phase attention metadata
-    decphase_attn_metadata: AttentionMetadata = make_test_metadata(
-        test_rsrcs.attn_backend,
-        False,
-        dec_qkv.q_seq_lens,
-        decoder_test_params=decphase_dec_test_params,
-        encoder_test_params=enc_test_params,
-        cross_test_params=decphase_cross_test_params,
-        device=CUDA_DEVICE)
+        decphase_attn_metadata: AttentionMetadata = make_test_metadata(
+            test_rsrcs.attn_backend,
+            False,
+            dec_qkv.q_seq_lens,
+            decoder_test_params=decphase_dec_test_params,
+            encoder_test_params=enc_test_params,
+            cross_test_params=decphase_cross_test_params,
+            device=CUDA_DEVICE)
-    # DECODE: decoder self-attention test
+        # DECODE: decoder self-attention test
-    decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
-        test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
+        decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
+            test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
-    # - Is decode-phase decoder self-attention correct?
-    assert_actual_matches_ideal(decphase_dec_test_params,
-                                decphase_dec_pckd_act_out)
+        # - Is decode-phase decoder self-attention correct?
+        assert_actual_matches_ideal(decphase_dec_test_params,
+                                    decphase_dec_pckd_act_out)
-    # DECODE: encoder/decoder cross-attention test
+        # DECODE: encoder/decoder cross-attention test
-    decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-        test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
+        decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
+            test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
-    # - Is decode-phase encoder/decoder cross-attention correct?
-    assert_actual_matches_ideal(decphase_cross_test_params,
-                                decphase_cross_pckd_act_out)
+        # - Is decode-phase encoder/decoder cross-attention correct?
+        assert_actual_matches_ideal(decphase_cross_test_params,
+                                    decphase_cross_pckd_act_out)
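
As the docstring above notes, each phase of this test shares one attention metadata structure across encoder attention, decoder self-attention, and cross-attention, and the backend is expected to read the fields relevant to attn_metadata.attention_type. A minimal sketch of that dispatch pattern (the dataclass and field names are illustrative assumptions, not vLLM's actual AttentionMetadata):

import enum
from dataclasses import dataclass
from typing import List, Optional

class AttentionType(enum.Enum):
    ENCODER = enum.auto()
    DECODER = enum.auto()
    ENCODER_DECODER = enum.auto()

@dataclass
class SharedMetadata:
    # One structure carries fields for every attention type; a backend
    # consults only the fields relevant to the requested type.
    seq_lens: Optional[List[int]] = None          # decoder self-attention
    encoder_seq_lens: Optional[List[int]] = None  # encoder & cross-attention
    attention_type: AttentionType = AttentionType.DECODER

def select_seq_lens(meta: SharedMetadata) -> Optional[List[int]]:
    # Pick the sequence lengths appropriate to the attention type.
    if meta.attention_type == AttentionType.DECODER:
        return meta.seq_lens
    # Encoder attention and cross-attention both key on encoder lengths.
    return meta.encoder_seq_lens

meta = SharedMetadata(seq_lens=[5], encoder_seq_lens=[7],
                      attention_type=AttentionType.ENCODER)
assert select_seq_lens(meta) == [7]
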