[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) (#4888)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
953
tests/kernels/test_encoder_decoder_attn.py
Normal file
953
tests/kernels/test_encoder_decoder_attn.py
Normal file
@@ -0,0 +1,953 @@
|
||||
"""
|
||||
Tests:
|
||||
|
||||
* E2E test of Encoder attention + Decoder self-attention +
|
||||
Encoder/decoder cross-attention (collectively
|
||||
"encoder/decoder attention")
|
||||
* Confirm enc/dec models will fail for chunked prefill
|
||||
* Confirm enc/dec models will fail for prefix caching
|
||||
|
||||
"""
|
||||
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import *
|
||||
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.utils import is_hip
|
||||
|
||||
HEAD_SIZES = [64, 256]
|
||||
|
||||
NUM_HEADS = [1, 16]
|
||||
|
||||
BATCH_SIZES = [1, 16]
|
||||
BLOCK_SIZES = [16]
|
||||
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
|
||||
CUDA_DEVICE = "cuda:0"
|
||||
|
||||
MAX_DEC_SEQ_LENS = [128]
|
||||
MAX_ENC_SEQ_LENS = [128]
|
||||
|
||||
# Narrow test-cases for unsupported-scenario
|
||||
# tests
|
||||
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]
|
||||
|
||||
|
||||
class TestPoint(NamedTuple):
    """
    Encapsulates the attributes which define a single invocation
    of the test_e2e_enc_dec_attn() test

    Attributes:
        num_heads: The number of heads in the model.
        head_size: Head dimension
        backend_name: Name of the backend framework used.
        batch_size: Number of samples per batch.
        block_size: Size of each block of data processed.
        max_dec_seq_len: Maximum sequence length for the decoder.
        max_enc_seq_len: Maximum sequence length for the encoder.
        num_blocks: Number of blocks in the model.
    """

    num_heads: int
    head_size: int
    backend_name: str
    batch_size: int
    block_size: int
    max_dec_seq_len: int
    max_enc_seq_len: int
    num_blocks: int
|
||||
|
||||
class TestResources(NamedTuple):
    '''
    Encapsulates key components for performing an
    encoder/decoder attention test

    Note that
    (1) attn automatically selects an attention backend
        based on platform info & a set of canned
        heuristics
    (2) attn_backend is thus *not the same backend
        instance* used by attn, but rather it is
        intended to be a
        *different instance* of the *same backend class*;
        it is assumed that the user of TestResources
        will leverage attn_backend for the purpose of
        constructing backend-compatible attention
        metadata instances

    Attributes:

    * scale: 1/sqrt(d) scale factor for attn
    * attn_backend: implementation of abstract
                    attention interface using
                    a particular kernel library
                    i.e. XFormers
    * attn: Attention layer instance
    * kv_cache: shared key/value cache for all attention
    '''

    scale: float
    attn_backend: AttentionBackend
    attn: Attention
    kv_cache: torch.Tensor
|
||||
|
||||
|
||||
def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
    '''
    Build key components for performing encoder/decoder attention test.

    Note that
    (1) The Attention instance constructed here automatically selects
        an attention backend class based on platform info & a set of canned
        heuristics, so
    (2) The attention backend instance constructed here is thus *not
        the same backend instance* used by attn, but rather it is
        intended to be a *different instance* of the *same backend class*;
        therefore,
    (3) This function requires that test_pt.backend_name matches the backend
        class that Attention will automatically select when it is constructed.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: num_heads, head_size, num_blocks,
               block_size, backend_name

    Returns:

    * TestResources data structure.
    '''

    # 1/sqrt(d) softmax scale factor
    scale = 1.0 / (test_pt.head_size**0.5)
    backend = make_backend(test_pt.backend_name)
    attn_layer = Attention(
        test_pt.num_heads,
        test_pt.head_size,
        scale=scale,
    )

    # A KV cache is only built when the caller supplied both a block count
    # and a head count; otherwise the resources are returned without one.
    needs_kv_cache = (test_pt.num_blocks is not None
                      and test_pt.num_heads is not None)
    kv_cache = (make_kv_cache(test_pt.num_blocks,
                              test_pt.num_heads,
                              test_pt.head_size,
                              test_pt.block_size,
                              device=CUDA_DEVICE) if needs_kv_cache else None)
    return TestResources(scale, backend, attn_layer, kv_cache)
|
||||
|
||||
|
||||
def _encoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
) -> PhaseTestParameters:
    '''
    Set up test vectors & data structures for encoder attention test.

    A triplet of synthetic query/key/value tensors are constructed.
    Given this is an encoder attention test, the key & value
    sequences will have the same length as the corresponding queries.

    The query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate an ideal output tensor.

    Encoder inference does not populate the KV cache, therefore
    no KV cache memory mapping is constructed

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               max_enc_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field

    Returns:

    * PhaseTestParameters data structure comprising (1) packed
      query/key/value tensors, (2) the ideal output of attention computed
      using a naive implementation, and (3) KVCache field set to None
    '''

    num_heads = test_pt.num_heads
    head_size = test_pt.head_size
    batch_size = test_pt.batch_size
    # For encoder attention the key/value sequences are exactly the query
    # sequences, so a single max length serves for both.
    max_seq_len = test_pt.max_enc_seq_len

    scale = test_rsrcs.scale

    # Synthesize encoder Q/K/V; prefill/decode splits are irrelevant here
    # so they are discarded.
    qkv_in, _, _ = make_qkv(batch_size,
                            max_seq_len,
                            max_seq_len,
                            num_heads,
                            head_size,
                            attn_type=AttentionType.ENCODER,
                            device=CUDA_DEVICE)

    # Ideal answer from a naive, non-causal reference implementation
    ideal_output = ref_masked_attention(qkv_in.query,
                                        qkv_in.key,
                                        qkv_in.value,
                                        scale=scale,
                                        q_seq_lens=qkv_in.q_seq_lens,
                                        kv_seq_lens=qkv_in.kv_seq_lens)

    packed_ideal_output, _ = pack_tensor(ideal_output,
                                         qkv_in.q_seq_lens,
                                         device=CUDA_DEVICE)

    packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)

    # Encoder attention never touches the KV cache -> no memory map
    return PhaseTestParameters(PackedQKVO(packed_qkv, packed_ideal_output),
                               None)
|
||||
|
||||
|
||||
def _decoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
    '''
    Set up test vectors & data structures for self-attention test.

    A triplet of synthetic query/key/value tensors are constructed ("baseline"
    query/key/value). Given this is a self-attention test, the key & value
    sequences will have the same length as the corresponding queries.

    "Prefill" query/key/value tensors are derived by masking out the last
    value in each baseline query/key/value. These tensors are used to test
    prefill & populate KV cache for a subsequent decode test.

    "Decode" query/key/value tensors are derived by extracting *only* the
    last value from each baseline query/key/value (i.e. complement of the
    prefill tensors.) These tensors are used to test decode, conditional on
    the kv cache being populated during the prefill test.

    The baseline query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate a "Baseline" ideal output
    tensor. This tensor is split into the "Prefill" ideal output tensor (all
    but the last element of each output sequence) and the "Decode" ideal
    output tensor (*only* the last element of each output sequence); the
    "Prefill" and "Decode" ideal output tensors can be used to validate the
    prefill and decode test results, respectively.

    This function also constructs the self-attention KV cache memory mapping
    (slot mapping and block table), ensuring that the block table starts at
    block_base_addr

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               block_size, max_q_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field
    * block_base_addr: decoder self-attention block-table base address

    Returns:
    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x
           head_size) query/key/value tensors
    * Prefill-phase decoder self-attention PhaseTestParameters data
      structure, including packed QKV, ideal output, and prefill-phase
      memory-mapping data structures.
    * Decode-phase decoder self-attention PhaseTestParameters data
      structure, including packed QKV, ideal output, and decode-phase
      memory-mapping data structures.
    * max_block_idx: max physical address in decoder self-attention
                     block-table (intended to be used as the base address
                     for the encoder/decoder cross-attention block-table,
                     which is not constructed in this function)
    '''

    # Positional unpack of TestPoint; underscores skip backend_name,
    # max_enc_seq_len and num_blocks, which this function does not need.
    (
        num_heads,
        head_size,
        _,
        batch_size,
        block_size,
        max_q_seq_len,
        _,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    # Self-attention: K/V sequences are the same length as the queries
    max_kv_seq_len = max_q_seq_len

    # Build test tensors

    (
        qkv,
        prefill_qkv,
        decode_qkv,
    ) = make_qkv(batch_size,
                 max_q_seq_len,
                 max_kv_seq_len,
                 num_heads,
                 head_size,
                 attn_type=AttentionType.DECODER,
                 device=CUDA_DEVICE)

    # Compute correct answer using naive attention implementation
    # with causal attention mask

    causal_mask = make_causal_mask(max_q_seq_len,
                                   max_kv_seq_len).to(CUDA_DEVICE)

    ideal_output = ref_masked_attention(qkv.query,
                                        qkv.key,
                                        qkv.value,
                                        scale=scale,
                                        custom_mask=causal_mask,
                                        q_seq_lens=qkv.q_seq_lens,
                                        kv_seq_lens=qkv.kv_seq_lens)

    # Split out the prefill- & decode-phase ideal answers & pack them

    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
        # Prefill: all but the last token of each sequence
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        # Decode: only the last token of each sequence
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_qkv.q_seq_lens,
                                                 device=CUDA_DEVICE)
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1 for _ in range(batch_size)],
                                                device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for decoder self-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with entries for prompt tokens
    #
    # Decode-phase:
    # * Block-tables tensor with minimum number of blocks
    #   required by total num. tokens in the entirety of all sequences
    #   (including both prefill & decode)
    # * Slot-mapping with entries for tokens that will be decoded in the
    #   current decode iteration
    #
    # Note: the format described above is simply mirroring what ModelRunner
    # produces

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        slot_mapping_list,
        max_block_idx,
    ) = make_block_tables_slot_mapping(block_size,
                                       qkv.q_seq_lens,
                                       device=CUDA_DEVICE,
                                       block_base_addr=block_base_addr)

    # Divide the single slot-mapping list into the prefill portion
    # (prompt tokens) and decode portion (the final token of each seq)
    (
        prefill_slot_mapping,
        decode_slot_mapping,
    ) = split_slot_mapping(slot_mapping_list,
                           qkv.q_seq_lens,
                           device=CUDA_DEVICE)

    prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)

    decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)

    return (
        qkv,
        PhaseTestParameters(  # Prefill test params
            PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode test params
            PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)),
        max_block_idx)
|
||||
|
||||
|
||||
def _enc_dec_cross_attn_setup_reuses_query(
    decoder_qkv: QKVInputs,
    encoder_test_params: PhaseTestParameters,
    prefill_decoder_phase_test_params: PhaseTestParameters,
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
    '''
    Set up test vectors & data structures for cross-attention test.

    A triplet of synthetic cross-attention key/value tensors are constructed
    ("baseline" key/value). Given this is a cross-attention test, we assume
    query tensors were already synthesized for a prior self-attention test
    and will be reused for cross-attention. The key & value sequences
    generated here may have a different length than the corresponding
    queries (as is often the case for cross-attention between decoder and
    encoder sequences.)

    Cross attention key & value tensors do not grow during autoregressive
    inference; thus this function obtains a single key/value pair suitable
    for both prefill and decode.

    The "baseline" query tensor is received as an argument. The "baseline"
    query/key/value tensors are passed to an ideal reference cross-attention
    implementation to generate a "baseline" ideal output tensor. This tensor
    is split into the "Prefill" ideal output tensor (all but the last
    element of each output sequence) and the "Decode" ideal output tensor
    (*only* the last element of each output sequence); the "Prefill" and
    "Decode" ideal output tensors can be used to validate the prefill and
    decode test results, respectively.

    This function also constructs the cross-attention KV cache memory
    mapping (slot mapping and block table), ensuring that the block table
    starts at block_base_addr.

    Arguments:

    * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
                   num_heads x head_size) decoder self-attention inputs;
                   this function relies on the query and q_seq_lens
                   fields
    * encoder_test_params: PhaseTestParameters data structure which was
                           used for encoder inference; KV cache field
                           is not used by this function
    * prefill_decoder_phase_test_params: PhaseTestParameters data structure
                                         used for prefill-phase decoder
                                         self-attention; all fields
                                         including KV cache required
    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               block_size, max_q_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field
    * block_base_addr: decoder self-attention block-table base address

    Returns:

    * Prefill-phase encoder/decoder cross-attention PhaseTestParameters
      data structure, including packed QKV, ideal attention output, and
      memory-mapping data structures appropriate for prefill phase.
    * Decode-phase encoder/decoder cross-attention PhaseTestParameters
      data structure, including packed QKV, ideal attention output, and
      memory-mapping data structures appropriate for decode phase.
    '''

    assert encoder_test_params.packed_qkvo.packed_qkv is not None
    assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None

    # Positional unpack of TestPoint; underscores skip backend_name,
    # and num_blocks, which this function does not need.
    (
        num_heads,
        head_size,
        _,
        batch_size,
        block_size,
        max_decoder_seq_len,
        max_encoder_seq_len,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    # The decoder query is reused; only K/V are synthesized here
    decoder_query = decoder_qkv.query
    decoder_seq_lens = decoder_qkv.q_seq_lens
    encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
    prefill_q_seq_lens = (
        prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)

    assert prefill_q_seq_lens is not None

    # K/V sequence lengths are forced to match the encoder sequence
    # lengths, as required for cross-attention
    (
        cross_kv,
        _,
        _,
    ) = make_qkv(batch_size,
                 max_decoder_seq_len,
                 max_encoder_seq_len,
                 num_heads,
                 head_size,
                 force_kv_seq_lens=encoder_seq_lens,
                 attn_type=AttentionType.ENCODER_DECODER,
                 device=CUDA_DEVICE)

    # Ideal answer: naive (non-causal) attention of decoder queries
    # against encoder-derived K/V
    ideal_output = ref_masked_attention(decoder_query,
                                        cross_kv.key,
                                        cross_kv.value,
                                        scale=scale,
                                        q_seq_lens=decoder_seq_lens,
                                        kv_seq_lens=cross_kv.kv_seq_lens)

    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
        # Prefill: all but the last token of each sequence
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        # Decode: only the last token of each sequence
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_q_seq_lens,
                                                 device=CUDA_DEVICE)
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1 for _ in range(batch_size)],
                                                device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for encoder/decoder cross-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Whereas decoder self-attention extracts relationships between
    # equal-length Q/K/V sequences, which mutually grow in length
    # with each decoded token, cross-attention relates the Q sequence
    # - which grows with each new decoded token - to fixed-length
    # K and V sequences derived from the encoder hidden states.
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with as many entries as there are tokens in the encoder
    #   prompt.
    #
    # Decode-phase:
    # * Block-tables tensor with minimum number of blocks to
    #   accommodate K & V tensors which are equal in length
    #   to the encoder prompt length
    # * Empty slot-mapping tensor (since K & V are fixed in size,
    #   new decoded tokens are not KV-cached and require no slot-
    #   mapping)
    #
    # Note: the format above is simply an extension of what ModelRunner
    # produces for decoder-only models

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
    decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        prefill_slot_mapping_list,
        _,
    ) = make_block_tables_slot_mapping(block_size,
                                       cross_kv.kv_seq_lens,
                                       block_base_addr=block_base_addr,
                                       device=CUDA_DEVICE)

    prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
                                                  device=CUDA_DEVICE)

    # Packed key/value (query is already provided)
    packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)

    return (
        PhaseTestParameters(  # Prefill-phase test params
            PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode-phase test params
            PackedQKVO(None, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)))
|
||||
|
||||
|
||||
def _run_encoder_attention_test(
    attn: Attention,
    encoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run encoder attention.

    attn.forward() is passed attn_type=AttentionType.ENCODER in order
    to configure the kernel invocation for encoder attention

    Requires attn_metadata.num_decode_tokens == 0
    (There is no encoder execution in the decode-phase)

    Arguments:

    * attn: Attention wrapper instance
    * encoder_test_params: encoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query/key/value fields
    * attn_metadata: attention metadata for encoder/decoder-self attention

    Returns:
    * Attention.forward() applied to packed {query,key,value} and
      & attn_metadata
    '''
    # Encoder attention happens only during prefill; the metadata must
    # therefore carry no decode-phase tokens.
    assert attn_metadata.num_decode_tokens == 0

    qkv = encoder_test_params.packed_qkvo.packed_qkv
    assert qkv is not None
    # No KV cache argument (None): encoder attention does not use one.
    return attn.forward(qkv.query,
                        qkv.key,
                        qkv.value,
                        None,
                        attn_metadata,
                        attn_type=AttentionType.ENCODER)
|
||||
|
||||
|
||||
def _run_decoder_self_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run decoder self-attention test.

    attn.forward() is passed attn_type=AttentionType.DECODER
    in order to configure the kernel invocation for decoder self-attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the
                  kv_cache and attn (Attention wrapper instance) fields
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query/key/value fields
    * attn_metadata: attention metadata for decoder-self attention
                     (contains KV cache memory-mapping)

    Returns:
    * Attention.forward() applied to packed_{query,key,value}, kv_cache
      & attn_metadata
    '''
    qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert qkv is not None
    # Forward through the wrapper with the shared KV cache from the
    # test resources; DECODER selects the self-attention kernel path.
    return test_rsrcs.attn.forward(qkv.query,
                                   qkv.key,
                                   qkv.value,
                                   test_rsrcs.kv_cache,
                                   attn_metadata,
                                   attn_type=AttentionType.DECODER)
|
||||
|
||||
|
||||
def _run_encoder_decoder_cross_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    cross_test_params: Optional[PhaseTestParameters],
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run encoder/decoder cross-attention test.

    Via PhaseTestParameters data structures, consumes the same query
    utilized for decoder self-attention, plus a key/value specific to
    cross-attention.

    if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv
    is None, this reflects that in decode-phase cross attention there
    is no growth in the key and value tensors.

    attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER
    in order to configure the kernel invocation for encoder/decoder cross-
    attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the
                  kv_cache and attn (Attention wrapper instance) fields
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query field
    * cross_test_params: encoder/decoder PhaseTestParameters data structure;
                         this function relies on the packed
                         (number_of_tokens x num_heads x head_size)
                         key/value fields
    * attn_metadata: attention metadata for encoder/decoder-self attention

    Returns:
    * Attention.forward() applied to packed_{query,key,value}, kv_cache
      & attn_metadata
    '''
    assert decoder_test_params.packed_qkvo.packed_qkv is not None

    # Resolve the cross-attention K/V; either may legitimately be absent
    # (decode phase: encoder-derived K/V do not grow, so none are passed).
    cross_qkv = (None if cross_test_params is None else
                 cross_test_params.packed_qkvo.packed_qkv)
    key = None if cross_qkv is None else cross_qkv.key
    value = None if cross_qkv is None else cross_qkv.value

    return test_rsrcs.attn.forward(
        decoder_test_params.packed_qkvo.packed_qkv.query,
        key,
        value,
        test_rsrcs.kv_cache,
        attn_metadata,
        attn_type=AttentionType.ENCODER_DECODER)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
                      batch_size: int, block_size: int, max_dec_seq_len: int,
                      max_enc_seq_len: int, monkeypatch):
    '''
    Encoder-only (prefill-phase) attention test:

    * Construct fake test vectors for encoder attention
    * Construct a prefill-phase attention metadata structure with
      encoder-attention attributes only (no decoder self- or
      cross-attention test params)
    * Run encoder attention & validate output against an ideal
      reference attention implementation

    Uses PyTest monkey patching to force the attention backend via an
    environment variable; skipped under ROCm/HIP, where encoder/decoder
    attention is not implemented.
    '''

    # Force Attention wrapper backend
    override_backend_env_variable(monkeypatch, backend_name)

    # Note: KV cache size of 4096 is arbitrary & chosen intentionally
    # to be more than necessary, since exceeding the kv cache size
    # is not part of this test
    test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
                        block_size, max_dec_seq_len, max_enc_seq_len, 4096)

    # Attention scale factor, attention backend instance, attention wrapper
    # instance, KV cache init
    test_rsrcs = _make_test_resources(test_pt)

    # Construct encoder attention test params (only used
    # during prefill)

    enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

    # Shared prefill metadata structure
    # (positional args: backend, is_prompt=True presumably, then
    # seq lens; NOTE(review): confirm against make_test_metadata signature)

    prephase_attn_metadata: AttentionMetadata = make_test_metadata(
        test_rsrcs.attn_backend,
        True,
        None,
        decoder_test_params=None,
        encoder_test_params=enc_test_params,
        cross_test_params=None,
        device=CUDA_DEVICE)

    # PREFILL: encoder attention

    enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
        test_rsrcs.attn, enc_test_params, prephase_attn_metadata))

    # - Is encoder attention result correct?
    assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
    num_heads: int,
    head_size: int,
    backend_name: str,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
    monkeypatch,
) -> None:
    '''
    End-to-end encoder/decoder attention test.

    Builds synthetic test vectors for (1) encoder attention, (2) decoder
    self-attention and (3) encoder/decoder cross-attention, then runs the
    attention steps in the order a real model would:

    * Encoder attention (prefill only)
    * Prefill decoder self-attention
    * Prefill cross-attention
    * Decode decoder self-attention
    * Decode cross-attention

    Each step's packed output is validated against an ideal reference
    attention implementation. Running self- and cross-attention back to
    back also exposes any accidental overlap between their block tables;
    the block tables are constructed so the cross-attention KV cache lives
    in a higher, non-intersecting address range than the self-attention
    KV cache.

    Self- and cross-attention share a query tensor but not K/V tensors:
    self-attention K/V seq lens match Q, while cross-attention K/V seq
    lens may differ (as is typical for cross-attention.)

    A single attention-metadata structure is shared by all prefill-phase
    operations (encoder, decoder self, cross) and another single structure
    by all decode-phase operations (decoder self, cross), mirroring how
    ModelRunner builds one metadata structure per prefill or decode run;
    the backend-under-test must pick the right fields based on
    attn_metadata.attention_type.

    The attention backend is forced via an environment variable using
    PyTest monkey patching. On ROCm/HIP the test is skipped outright,
    since encoder/decoder models are not yet supported on AMD GPUs.
    '''

    # Pin the Attention wrapper backend for this test invocation.
    override_backend_env_variable(monkeypatch, backend_name)

    # KV cache size of 4096 is arbitrary and intentionally generous;
    # running out of KV cache is not part of this test.
    point = TestPoint(num_heads, head_size, backend_name, batch_size,
                      block_size, max_dec_seq_len, max_enc_seq_len, 4096)

    # Attention scale factor, backend instance, Attention wrapper
    # instance, and KV cache initialization.
    rsrcs = _make_test_resources(point)

    # Encoder attention test params (consumed during prefill only.)
    enc_params = _encoder_attn_setup(point, rsrcs)

    # Decoder self-attention prefill- & decode-phase test params
    # (query/key/value tensors plus the self-attention memory mapping.)
    # cross_block_base_addr is the top of the self-attention block table;
    # the cross-attention block table is built downward from it.
    (
        dec_qkv,
        prefill_dec_params,
        decode_dec_params,
        cross_block_base_addr,
    ) = _decoder_attn_setup(point, rsrcs)

    # Cross-attention prefill- & decode-phase test params (key/value
    # tensors plus the cross-attention memory mapping.)
    (
        prefill_cross_params,
        decode_cross_params,
    ) = _enc_dec_cross_attn_setup_reuses_query(
        dec_qkv,
        enc_params,
        prefill_dec_params,
        point,
        rsrcs,
        block_base_addr=cross_block_base_addr)

    # Single metadata structure shared by every prefill-phase step.
    assert prefill_dec_params.packed_qkvo.packed_qkv is not None
    prefill_metadata: AttentionMetadata = make_test_metadata(
        rsrcs.attn_backend,
        True,
        prefill_dec_params.packed_qkvo.packed_qkv.q_seq_lens,
        decoder_test_params=prefill_dec_params,
        encoder_test_params=enc_params,
        cross_test_params=prefill_cross_params,
        device=CUDA_DEVICE)

    # PREFILL: encoder attention; validate against ideal output.
    enc_actual = _run_encoder_attention_test(rsrcs.attn, enc_params,
                                             prefill_metadata)
    assert_actual_matches_ideal(enc_params, enc_actual)

    # PREFILL: decoder self-attention; validate against ideal output.
    prefill_dec_actual = _run_decoder_self_attention_test(
        rsrcs, prefill_dec_params, prefill_metadata)
    assert_actual_matches_ideal(prefill_dec_params, prefill_dec_actual)

    # PREFILL: encoder/decoder cross-attention; validate against ideal
    # output.
    prefill_cross_actual = _run_encoder_decoder_cross_attention_test(
        rsrcs, prefill_dec_params, prefill_cross_params, prefill_metadata)
    assert_actual_matches_ideal(prefill_cross_params, prefill_cross_actual)

    # Single metadata structure shared by every decode-phase step.
    decode_metadata: AttentionMetadata = make_test_metadata(
        rsrcs.attn_backend,
        False,
        dec_qkv.q_seq_lens,
        decoder_test_params=decode_dec_params,
        encoder_test_params=enc_params,
        cross_test_params=decode_cross_params,
        device=CUDA_DEVICE)

    # DECODE: decoder self-attention; validate against ideal output.
    decode_dec_actual = _run_decoder_self_attention_test(
        rsrcs, decode_dec_params, decode_metadata)
    assert_actual_matches_ideal(decode_dec_params, decode_dec_actual)

    # DECODE: encoder/decoder cross-attention; the cross test-params
    # argument is None because no new cross K/Vs are supplied at decode
    # time. Validate against ideal output.
    decode_cross_actual = _run_encoder_decoder_cross_attention_test(
        rsrcs, decode_dec_params, None, decode_metadata)
    assert_actual_matches_ideal(decode_cross_params, decode_cross_actual)
Reference in New Issue
Block a user