[Kernel] Move attn_type to Attention.__init__() (#11690)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType
|
||||
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
+ from vllm.platforms.interface import _Backend
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
||||
|
||||
@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
|
||||
|
||||
|
||||
def make_test_metadata(
|
||||
- attn_backend: AttentionBackend,
|
||||
+ attn_backend: _Backend,
|
||||
is_prompt: bool,
|
||||
seq_lens: Optional[List[int]],
|
||||
decoder_test_params: Optional[PhaseTestParameters],
|
||||
@@ -815,7 +816,7 @@ def make_test_metadata(
|
||||
|
||||
Arguments:
|
||||
|
||||
- * attn_backend: Backend for sourcing attention kernels
|
||||
+ * attn_backend_name: Backend for sourcing attention kernels
|
||||
* is_prompt: prefill if True, o/w decode
|
||||
* seq_lens: list of token counts for each sequence
|
||||
* decoder_test_params: decoder self-attention test params;
|
||||
@@ -882,6 +883,8 @@ def make_test_metadata(
|
||||
# (kv_mmap)
|
||||
cross_kv_mmap = cross_test_params.kv_mmap
|
||||
|
||||
+ attn_backend_obj = make_backend(attn_backend.name)
|
||||
|
||||
if is_prompt:
|
||||
# Prefill-phase scenario
|
||||
|
||||
@@ -902,8 +905,7 @@ def make_test_metadata(
|
||||
context_lens,
|
||||
encoder_seq_lens,
|
||||
device=device)
|
||||
|
||||
- return attn_backend.make_metadata(
|
||||
+ return attn_backend_obj.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
@@ -952,7 +954,7 @@ def make_test_metadata(
|
||||
encoder_seq_lens,
|
||||
device=device)
|
||||
|
||||
- return attn_backend.make_metadata(
|
||||
+ return attn_backend_obj.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=kv_mmap.slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
|
||||
Reference in New Issue
Block a user