[Kernel] Move attn_type to Attention.__init__() (#11690)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -13,8 +13,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import *
|
||||
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
|
||||
max_dec_seq_len: int
|
||||
max_enc_seq_len: int
|
||||
num_blocks: int
|
||||
attn_type: AttentionType
|
||||
|
||||
|
||||
class TestResources(NamedTuple):
|
||||
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
|
||||
'''
|
||||
|
||||
scale: float
|
||||
attn_backend: AttentionBackend
|
||||
attn: Attention
|
||||
kv_cache: torch.Tensor
|
||||
|
||||
@@ -129,16 +128,17 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
|
||||
'''
|
||||
|
||||
scale = float(1.0 / (test_pt.head_size**0.5))
|
||||
attn_backend = make_backend(test_pt.backend_name)
|
||||
attn = Attention(
|
||||
test_pt.num_heads,
|
||||
test_pt.head_size,
|
||||
scale=scale,
|
||||
prefix=f"{test_pt.attn_type}",
|
||||
attn_type=test_pt.attn_type,
|
||||
)
|
||||
if test_pt.num_blocks is None or test_pt.num_heads is None:
|
||||
# Caller does not require a KV cache
|
||||
return TestResources(
|
||||
scale, attn_backend, attn,
|
||||
scale, attn,
|
||||
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
|
||||
|
||||
# Construct KV cache
|
||||
@@ -148,7 +148,7 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
|
||||
test_pt.block_size,
|
||||
device=CUDA_DEVICE,
|
||||
backend=test_pt.backend_name)
|
||||
return TestResources(scale, attn_backend, attn, kv_cache)
|
||||
return TestResources(scale, attn, kv_cache)
|
||||
|
||||
|
||||
def _encoder_attn_setup(
|
||||
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
|
||||
_,
|
||||
max_q_seq_len,
|
||||
_,
|
||||
_,
|
||||
) = test_pt
|
||||
|
||||
scale = test_rsrcs.scale
|
||||
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
|
||||
max_q_seq_len,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = test_pt
|
||||
|
||||
scale = test_rsrcs.scale
|
||||
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
|
||||
max_decoder_seq_len,
|
||||
max_encoder_seq_len,
|
||||
_,
|
||||
_,
|
||||
) = test_pt
|
||||
|
||||
scale = test_rsrcs.scale
|
||||
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
|
||||
& attn_metadata
|
||||
'''
|
||||
assert attn_metadata.num_decode_tokens == 0
|
||||
attn_type = AttentionType.ENCODER
|
||||
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
@@ -635,14 +637,11 @@ def _run_encoder_attention_test(
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
torch.tensor([],
|
||||
dtype=torch.float32,
|
||||
device=packed_qkv.query.device),
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
return attn.forward(
|
||||
reshaped_query, packed_qkv.key, packed_qkv.value,
|
||||
torch.tensor([],
|
||||
dtype=torch.float32,
|
||||
device=packed_qkv.query.device), attn_metadata)
|
||||
|
||||
|
||||
def _run_decoder_self_attention_test(
|
||||
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
|
||||
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||
& attn_metadata
|
||||
'''
|
||||
attn_type = AttentionType.DECODER
|
||||
attn = test_rsrcs.attn
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||
@@ -690,12 +688,8 @@ def _run_decoder_self_attention_test(
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
|
||||
kv_cache, attn_metadata)
|
||||
|
||||
|
||||
def _run_encoder_decoder_cross_attention_test(
|
||||
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
'''
|
||||
assert decoder_test_params.packed_qkvo.packed_qkv is not None
|
||||
|
||||
attn_type = AttentionType.ENCODER_DECODER
|
||||
attn = test_rsrcs.attn
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
if cross_test_params is None:
|
||||
@@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
return attn.forward(reshaped_query, key, value, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -839,7 +828,7 @@ def test_encoder_only(
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096)
|
||||
max_enc_seq_len, 4096, AttentionType.ENCODER)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
@@ -855,7 +844,7 @@ def test_encoder_only(
|
||||
# Shared prefill metadata structure
|
||||
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
attn_backend,
|
||||
True,
|
||||
None,
|
||||
decoder_test_params=None,
|
||||
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096)
|
||||
enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096, AttentionType.ENCODER)
|
||||
enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096,
|
||||
AttentionType.ENCODER_DECODER)
|
||||
dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096, AttentionType.DECODER)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
enc_test_rsrcs = _make_test_resources(enc_test_pt)
|
||||
enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
|
||||
dec_test_rsrcs = _make_test_resources(dec_test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)
|
||||
|
||||
# Construct Decoder self-attention prefill-phase & decode-phase
|
||||
# test params, including query/key/value tensors, decoder self-attention
|
||||
@@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn(
|
||||
prephase_dec_test_params,
|
||||
decphase_dec_test_params,
|
||||
cross_block_base_addr,
|
||||
) = _decoder_attn_setup(test_pt, test_rsrcs)
|
||||
) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)
|
||||
|
||||
# Construct encoder/decoder cross-attention prefill-phase
|
||||
# & decode-phase test params, including key/value tensors,
|
||||
@@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn(
|
||||
dec_qkv,
|
||||
enc_test_params,
|
||||
prephase_dec_test_params,
|
||||
test_pt,
|
||||
test_rsrcs,
|
||||
enc_dec_test_pt,
|
||||
enc_dec_test_rsrcs,
|
||||
block_base_addr=cross_block_base_addr)
|
||||
|
||||
# Shared prefill metadata structure
|
||||
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
attn_backend,
|
||||
True,
|
||||
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
|
||||
decoder_test_params=prephase_dec_test_params,
|
||||
@@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(
|
||||
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt,
|
||||
test_pt=enc_test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
@@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn(
|
||||
# PREFILL: decoder self-attention test
|
||||
|
||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs,
|
||||
dec_test_rsrcs,
|
||||
prephase_dec_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt,
|
||||
test_pt=dec_test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is prefill decoder self-attention correct?
|
||||
@@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn(
|
||||
# PREFILL: encoder/decoder cross-attention test
|
||||
|
||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs,
|
||||
enc_dec_test_rsrcs,
|
||||
prephase_dec_test_params,
|
||||
prephase_cross_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt,
|
||||
test_pt=enc_dec_test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
@@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn(
|
||||
# DECODE: build decode-phase attention metadata
|
||||
|
||||
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
attn_backend,
|
||||
False,
|
||||
dec_qkv.q_seq_lens,
|
||||
decoder_test_params=decphase_dec_test_params,
|
||||
@@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn(
|
||||
# DECODE: decoder self-attention test
|
||||
|
||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs,
|
||||
dec_test_rsrcs,
|
||||
decphase_dec_test_params,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt,
|
||||
test_pt=dec_test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
@@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn(
|
||||
# DECODE: encoder/decoder cross-attention test
|
||||
|
||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs,
|
||||
enc_dec_test_rsrcs,
|
||||
decphase_dec_test_params,
|
||||
None,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt,
|
||||
test_pt=enc_dec_test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
|
||||
Reference in New Issue
Block a user